/*!
 * Copyright (c) 2017 by Contributors
 * \file graph_runtime.h
 * \brief Interface code for the TVM graph runtime.
 */
#ifndef NNVM_COMPILER_GRAPH_RUNTIME_H_
#define NNVM_COMPILER_GRAPH_RUNTIME_H_

#include <nnvm/graph.h>
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/node/memory.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
#include <string>

namespace nnvm {
namespace compiler {

/*!
 * \brief Magic number used to identify an NDArray list file.
 */
static constexpr uint64_t kTVMNDArrayListMagic{0xF7E58D4F05049CB7ULL};

/*!
 * \brief Parameters of a TVM op, declared through dmlc::Parameter.
 *
 * NOTE(review): presumably filled from a graph node's attribute dict by the
 * code that consumes it — confirm against the caller.
 */
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
  /*! \brief name of the function the op refers to; required (no default) */
  std::string func_name;
  /*! \brief number of inputs; defaults to 1 */
  uint32_t num_inputs;
  /*! \brief number of outputs; defaults to 1 */
  uint32_t num_outputs;
  /*!
   * \brief flag controlling data flattening; defaults to 0.
   *  TODO(review): confirm exact semantics where the flag is read.
   */
  uint32_t flatten_data;

  DMLC_DECLARE_PARAMETER(TVMOpParam) {
    // func_name has no set_default, so it must always be supplied;
    // the remaining fields fall back to the defaults given here.
    DMLC_DECLARE_FIELD(func_name);
    DMLC_DECLARE_FIELD(num_inputs).set_default(1);
    DMLC_DECLARE_FIELD(num_outputs).set_default(1);
    DMLC_DECLARE_FIELD(flatten_data).set_default(0);
  }
};

/*!
 * \brief wrapper node container for exchange.
 */
struct NDArrayWrapperNode : public ::tvm::Node {
  std::string name;
  tvm::runtime::NDArray array;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("array", &array);
  }

  static constexpr const char* _type_key = "NDArrayWrapper";
  TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node);
};

/*! \brief Reference type NDArrayWrapper for NDArrayWrapperNode, generated by TVM_DEFINE_NODE_REF. */
TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode);

}  // namespace compiler
}  // namespace nnvm

#endif   // NNVM_COMPILER_GRAPH_RUNTIME_H_