/*!
 *  Copyright (c) 2017 by Contributors
 * \file compile_engine.h
 * \brief Internal engine to compile a subgraph fragment and cache compilation.
 */
#ifndef NNVM_COMPILER_COMPILE_ENGINE_H_
#define NNVM_COMPILER_COMPILE_ENGINE_H_

#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include <string>
#include <utility>
#include "./graph_hash.h"

namespace nnvm {
namespace compiler {

/*! \brief A TVM Node to represent compiled graph function */
struct GraphFuncNode : public tvm::Node {
  /* \brief compiled target */
  std::string target;
  /*! \brief Function name */
  std::string func_name;
  /* \brief The inputs to the function */
  tvm::Array<Tensor> inputs;
  /* \brief The outputs to the function */
  tvm::Array<Tensor> outputs;
  /*! \brief The lowered functions */
  tvm::Array<tvm::LoweredFunc> funcs;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("target", &target);
    v->Visit("func_name", &func_name);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("funcs", &funcs);
  }

  static constexpr const char* _type_key = "GraphFunc";
  TVM_DECLARE_NODE_TYPE_INFO(GraphFuncNode, tvm::Node);
};

/*! \brief GraphFunc: the node reference type for GraphFuncNode (declared through the TVM macro) */
TVM_DEFINE_NODE_REF(GraphFunc, GraphFuncNode);

/*! \brief Cache Entry in the graph */
struct GraphCacheEntryNode : public tvm::Node {
  /*! \brief The graph function */
  GraphFunc graph_func;
  /*! \brief Usage statistics */
  int use_count{0};
59 60
  /*! \brief Index of the master node for calling schedule*/
  int master_idx;
61 62 63 64

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("graph_func", &graph_func);
    v->Visit("use_count", &use_count);
65
    v->Visit("master_idx", &master_idx);
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
  }
  static constexpr const char* _type_key = "GraphCacheEntry";
  TVM_DECLARE_NODE_TYPE_INFO(GraphCacheEntryNode, tvm::Node);
};

class GraphCacheEntry : public ::tvm::NodeRef {
 public:
  GraphCacheEntry() {}
  explicit GraphCacheEntry(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
  GraphCacheEntryNode* operator->() {
    return static_cast<GraphCacheEntryNode*>(node_.get());
  }
  using ContainerType = GraphCacheEntryNode;
};

/*!
 * \brief Call compile engine to lower a graph with given inputs.
 *
 * \param graph The graph to be compiled
 * \param inputs The input specification.
 * \param target The build target
 * \param master_idx The index of master node for calling schedule
 *
 * \return The GraphFunc produced by lowering the graph.
 */
GraphFunc GraphLower(Graph graph,
                     const Array<tvm::Tensor>& inputs,
                     const std::string& target,
                     int master_idx);

/*!
 * \brief Get type flag from TVM Type
 *
 * \param type the tvm type
 * \return the integer type flag corresponding to \p type
 */
int GetTypeFlag(tvm::Type type);

/*!
 * \brief Get TVM Type from type flag
 *
 * NOTE(review): presumably the inverse mapping of GetTypeFlag; confirm
 * against the implementation.
 *
 * \param type_flag the integer type flag
 * \return the TVM type corresponding to \p type_flag
 */
tvm::Type GetTVMType(int type_flag);

}  // namespace compiler
}  // namespace nnvm

#endif  // NNVM_COMPILER_COMPILE_ENGINE_H_