base.h 2.73 KB
Newer Older
tqchen committed
1 2 3 4 5 6 7 8 9 10
/*!
 *  Copyright (c) 2016 by Contributors
 * \file base.h
 * \brief Defines the base data structure
 */
#ifndef TVM_BASE_H_
#define TVM_BASE_H_

#include <dmlc/logging.h>
#include <dmlc/registry.h>
tqchen committed
11
#include <tvm/node.h>
tqchen committed
12 13 14
#include <string>
#include <memory>
#include <functional>
15
#include "./runtime/registry.h"
tqchen committed
16 17 18

namespace tvm {

19 20 21
// Using-declarations for the core node types referenced throughout this header.
// NOTE(review): these name entities that are already qualified with ::tvm while
// sitting inside namespace tvm; presumably kept from when the types lived in
// another namespace -- confirm whether they are still needed.
using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
tqchen committed
22

23 24
/*!
 * \brief Macro to make it easy to define a node reference type for a node.
 *
 *  The generated class derives from ::tvm::NodeRef, has a default
 *  constructor, can be explicitly constructed from a
 *  std::shared_ptr<::tvm::Node>, and exposes the held node through
 *  operator-> as a const NodeName*.
 *
 * \param TypeName The name of the reference class to generate.
 * \param NodeName The node type returned by operator-> and exported as
 *        ContainerType.
 *
 * \note operator-> uses static_cast with no runtime check, so the caller is
 *       responsible for constructing TypeName from a node that really is
 *       a NodeName.
 * \note The expansion already ends with the class's closing "};". The macro
 *       text must stay one unbroken run of continued lines, and the last
 *       line must not end with a backslash, otherwise the following source
 *       line is spliced into the macro.
 */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                          \
  class TypeName : public ::tvm::NodeRef {                               \
   public:                                                               \
    TypeName() {}                                                        \
    explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}    \
    const NodeName* operator->() const {                                 \
      return static_cast<const NodeName*>(node_.get());                  \
    }                                                                    \
    using ContainerType = NodeName;                                      \
  };


36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
/*!
 * \brief Save the node, together with all the nodes it depends on, as JSON.
 *  This can be used to serialize any TVM object.
 *
 * \param node The node to be saved.
 *
 * \return The string representation of the node.
 */
std::string SaveJSON(const NodeRef& node);

/*!
 * \brief Internal implementation of LoadJSON.
 *  Loads a TVM Node object from JSON and returns a shared_ptr of Node.
 *
 * \param json_str The JSON string to load from.
 *
 * \return The shared_ptr of the Node.
 *
 * \note NOTE(review): json_str is taken by value while the public LoadJSON
 *       wrapper takes a const reference, so each call copies the string;
 *       confirm against the definition whether the copy is needed.
 */
std::shared_ptr<Node> LoadJSON_(std::string json_str);

/*!
 * \brief Deserialize a node reference from its JSON string form.
 *  This can be used to load any TVM object: the node is obtained from
 *  LoadJSON_ and wrapped in the requested reference type.
 *
 * \param json_str The JSON string to load from.
 *
 * \tparam NodeType The reference type to produce. The function only
 *         participates in overload resolution when NodeType derives
 *         from NodeRef.
 *
 * \return A NodeType constructed from the loaded node.
 *
 * \code
 *  Expr e = LoadJSON<Expr>(json_str);
 * \endcode
 */
template <
    typename NodeType,
    typename = typename std::enable_if<
        std::is_base_of<NodeRef, NodeType>::value, void>::type>
inline NodeType LoadJSON(const std::string& json_str) {
  NodeType loaded(LoadJSON_(json_str));
  return loaded;
}

tqchen committed
71 72 73
/*! \brief Factory function type: takes no arguments and returns a std::shared_ptr<Node>. */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*!
 * \brief Registry entry for NodeFactory.
 *  Derives from dmlc::FunctionRegEntryBase with NodeFactory as the
 *  registered function type and adds no members of its own.
 */
struct NodeFactoryReg
    : public dmlc::FunctionRegEntryBase<NodeFactoryReg,
                                        NodeFactory> {
};

/*!
 * \brief Register a factory for a node type in the NodeFactoryReg registry.
 *
 *  Expands to a static reference to the registry entry registered under
 *  TypeName::_type_key, whose body creates the node with
 *  std::make_shared<TypeName>().
 *
 * \param TypeName The node type to register. It must provide _type_key and
 *        be constructible by std::make_shared<TypeName>() with no arguments.
 *        The name is token-pasted into an identifier, so it must be a single
 *        unqualified identifier.
 *
 * \note The expansion has no trailing semicolon; the call site supplies it.
 * \note The macro text must stay one unbroken run of continued lines; any
 *       line inserted between them cuts the definition off from its body.
 */
#define TVM_REGISTER_NODE_TYPE(TypeName)                                \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
      ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \
      .set_body([]() { return std::make_shared<TypeName>(); })
tqchen committed
85 86 87

}  // namespace tvm
#endif  // TVM_BASE_H_