/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/base.h
 * \brief Defines the base data structures and the macros used to declare,
 *  register and serialize TVM nodes.
 */
#ifndef TVM_BASE_H_
#define TVM_BASE_H_

#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node.h>

#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "runtime/registry.h"

namespace tvm {

// Bring the core node types into scope for the declarations below.
// NOTE(review): we are already inside namespace tvm, so these using-declarations
// name entities of this same namespace and look redundant; kept for compatibility.
using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;

/*!
 * \brief Macro to make it easy to define a node reference type given its node.
 * \param TypeName The name of the reference class to define; it derives from
 *        ::tvm::NodeRef.
 * \param NodeName The node type the reference points to; it is exposed as
 *        TypeName::ContainerType and returned (as const) by operator->.
 * \note operator-> uses an unchecked static_cast, so the held node must really
 *       be a NodeName.
 */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                           \
  class TypeName : public ::tvm::NodeRef {                                \
   public:                                                                \
    TypeName() {}                                                         \
    /* Move the by-value pointer into the base to skip a refcount bump. */ \
    explicit TypeName(std::shared_ptr<::tvm::Node> n)                     \
        : NodeRef(std::move(n)) {}                                        \
    const NodeName* operator->() const {                                  \
      return static_cast<const NodeName*>(node_.get());                   \
    }                                                                     \
    using ContainerType = NodeName;                                       \
  };

/*!
 * \brief Save the node, as well as all the nodes it depends on, as JSON.
 *  This can be used to serialize any TVM object.
 *
 * \param node The node reference to serialize.
 * \return The string representation of the node.
 */
std::string SaveJSON(const NodeRef& node);

/*!
 * \brief Internal implementation of LoadJSON.
 *  Loads a tvm Node object from JSON and returns it as a shared_ptr of Node.
 *  The typed LoadJSON<NodeType> template wraps this result in a reference type.
 * \param json_str The JSON string to load from (taken by value).
 *
 * \return The shared_ptr of the loaded Node.
 */
std::shared_ptr<Node> LoadJSON_(std::string json_str);

/*!
 * \brief Deserialize a typed node reference from a JSON string.
 *  This can be used to deserialize any TVM object.
 *
 * \tparam NodeType The reference type to produce; overload resolution only
 *         considers this template when NodeType derives from NodeRef.
 * \param json_str The JSON string to load from.
 * \return A NodeType constructed from the loaded node.
 *
 * \code
 *  Expr e = LoadJSON<Expr>(json_str);
 * \endcode
 */
template <typename NodeType,
          typename = typename std::enable_if<
              std::is_base_of<NodeRef, NodeType>::value>::type>
inline NodeType LoadJSON(const std::string& json_str) {
  // LoadJSON_ returns an untyped node pointer; wrap it in the requested reference type.
  return NodeType(LoadJSON_(json_str));
}

tqchen committed
71
/*!
72 73 74 75 76 77 78 79
 * \brief Registry entry for NodeFactory.
 *
 *  There are two types of Nodes that can be serialized.
 *  The normal node requires a registration a creator function that
 *  constructs an empty Node of the corresponding type.
 *
 *  The global singleton(e.g. global operator) where only global_key need to be serialized,
 *  in this case, FGlobalKey need to be defined.
tqchen committed
80
 */
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
struct NodeFactoryReg {
  /*!
   * \brief creator function.
   * \param global_key Key that identifies a global single object.
   *        If this is not empty then FGlobalKey
   * \return The created function.
   */
  using FCreate = std::function<std::shared_ptr<Node>(const std::string& global_key)>;
  /*!
   * \brief Global key function, only needed by global objects.
   * \param node The node pointer.
   * \return node The global key to the node.
   */
  using FGlobalKey = std::function<std::string(const Node* node)>;
  /*! \brief registered name */
  std::string name;
  /*!
   * \brief The creator function
   */
  FCreate fcreator = nullptr;
  /*!
   * \brief The global key function.
   */
  FGlobalKey fglobal_key = nullptr;
  // setter of creator
  NodeFactoryReg& set_creator(FCreate f) {  // NOLINT(*)
    this->fcreator = f;
    return *this;
  }
  // setter of creator
  NodeFactoryReg& set_global_key(FGlobalKey f) {  // NOLINT(*)
    this->fglobal_key = f;
    return *this;
  }
  // global registry singleton
  TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
tqchen committed
117 118
};

/*!
 * \brief Register a Node type.
 * \note This is necessary to enable serialization of the Node.
 * \param TypeName The node class to register. It must expose a static
 *        `_type_key` (used as the registry name) and be default-constructible;
 *        the registered creator ignores the global_key argument.
 *
 * The expansion does not end in a semicolon, so the use site supplies one:
 * \code
 *  TVM_REGISTER_NODE_TYPE(MyNode);
 * \endcode
 */
#define TVM_REGISTER_NODE_TYPE(TypeName)                                   \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg &                     \
  __make_Node ## _ ## TypeName ## __ =                                     \
      ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
      .set_creator([](const std::string&) { return std::make_shared<TypeName>(); })


/*!
 * \brief Implementation detail of TVM_STRINGIZE: turns its argument into a
 *        string literal without macro-expanding it first.
 */
#define TVM_STRINGIZE_DETAIL(x) #x
/*!
 * \brief Turn the macro-expanded argument into a string literal,
 *        e.g. TVM_STRINGIZE(__LINE__) yields the current line number as a string.
 */
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
/*!
 * \brief Call describe(...) with the given arguments, appending the file and
 *        line of the use site to the last one.
 * \note The last argument must be a string literal, because it is joined with
 *       the file/line text by adjacent string-literal concatenation.
 */
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
/*!
 * \brief Macro to include current line as string
 */
#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)

}  // namespace tvm
#endif  // TVM_BASE_H_