Unverified Commit 78ca6fc8 by Tianqi Chen Committed by GitHub

[NODE][REFACTOR] Refactor reflection system in node. (#4189)

* [NODE][REFACTOR] Refactor reflection system in node.

- Removed the old Node, Node is now just an alias of runtime::Object
- Introduce ReflectionVTable, a new columnar dispatcher to support reflection
  - This allows us to remove vtable from most node objects
  - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE,
    they are no longer virtual.
- Consolidated serialization and reflection features into node.

* Explicit type qualification when calling destructor.

* Fix SPIRV, more comments
parent 324a9607
...@@ -58,7 +58,7 @@ class EnvFuncNode : public Node { ...@@ -58,7 +58,7 @@ class EnvFuncNode : public Node {
/*! \brief constructor */ /*! \brief constructor */
EnvFuncNode() {} EnvFuncNode() {}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
} }
......
...@@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node { ...@@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node {
int64_t min_value; int64_t min_value;
int64_t max_value; int64_t max_value;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value); v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value); v->Visit("max_value", &max_value);
} }
...@@ -162,7 +162,7 @@ class ModularSetNode : public Node { ...@@ -162,7 +162,7 @@ class ModularSetNode : public Node {
/*! \brief The base */ /*! \brief The base */
int64_t base; int64_t base;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coeff", &coeff); v->Visit("coeff", &coeff);
v->Visit("base", &base); v->Visit("base", &base);
} }
...@@ -351,7 +351,7 @@ enum SignType { ...@@ -351,7 +351,7 @@ enum SignType {
*/ */
struct IntSetNode : public Node { struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet"; static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
}; };
/*! /*!
......
...@@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node { ...@@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node {
/*! \brief detailed description of the type */ /*! \brief detailed description of the type */
std::string description; std::string description;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("type_info", &type_info); v->Visit("type_info", &type_info);
v->Visit("description", &description); v->Visit("description", &description);
...@@ -197,7 +197,7 @@ class AttrsHash { ...@@ -197,7 +197,7 @@ class AttrsHash {
size_t operator()(const std::string& value) const { size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value); return std::hash<std::string>()(value);
} }
size_t operator()(const Type& value) const { size_t operator()(const DataType& value) const {
return std::hash<int>()( return std::hash<int>()(
static_cast<int>(value.code()) | static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) | (static_cast<int>(value.bits()) << 8) |
...@@ -221,6 +221,8 @@ class BaseAttrsNode : public Node { ...@@ -221,6 +221,8 @@ class BaseAttrsNode : public Node {
public: public:
using TVMArgs = runtime::TVMArgs; using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue; using TVMRetValue = runtime::TVMRetValue;
// visit function
virtual void VisitAttrs(AttrVisitor* v) {}
/*! /*!
* \brief Initialize the attributes by sequence of arguments * \brief Initialize the attributes by sequence of arguments
* \param args The postional arguments in the form * \param args The postional arguments in the form
...@@ -753,12 +755,12 @@ class AttrNonDefaultVisitor { ...@@ -753,12 +755,12 @@ class AttrNonDefaultVisitor {
template<typename DerivedType> template<typename DerivedType>
class AttrsNode : public BaseAttrsNode { class AttrsNode : public BaseAttrsNode {
public: public:
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v); ::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis); self()->__VisitAttrs__(vis);
} }
void VisitNonDefaultAttrs(AttrVisitor* v) final { void VisitNonDefaultAttrs(AttrVisitor* v) {
::tvm::detail::AttrNonDefaultVisitor vis(v); ::tvm::detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis); self()->__VisitAttrs__(vis);
} }
......
...@@ -19,89 +19,16 @@ ...@@ -19,89 +19,16 @@
/*! /*!
* \file tvm/base.h * \file tvm/base.h
* \brief Defines the base data structure * \brief Base utilities
*/ */
#ifndef TVM_BASE_H_ #ifndef TVM_BASE_H_
#define TVM_BASE_H_ #define TVM_BASE_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node/node.h>
#include <string>
#include <memory>
#include <functional>
#include <utility> #include <utility>
#include "runtime/registry.h"
namespace tvm { namespace tvm {
using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
/*! /*!
* \brief RAII wrapper function to enter and exit a context object * \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax. * similar to python's with syntax.
...@@ -146,100 +73,6 @@ class With { ...@@ -146,100 +73,6 @@ class With {
ContextType ctx_; ContextType ctx_;
}; };
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);
/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
ObjectPtr<Object> LoadJSON_(std::string json_str);
/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}
/*!
* \brief Registry entry for NodeFactory.
*
* There are two types of Nodes that can be serialized.
* The normal node requires a registration a creator function that
* constructs an empty Node of the corresponding type.
*
* The global singleton(e.g. global operator) where only global_key need to be serialized,
* in this case, FGlobalKey need to be defined.
*/
struct NodeFactoryReg {
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey
* \return The created function.
*/
using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/
using FGlobalKey = std::function<std::string(const Node* node)>;
/*! \brief registered name */
std::string name;
/*!
* \brief The creator function
*/
FCreate fcreator = nullptr;
/*!
* \brief The global key function.
*/
FGlobalKey fglobal_key = nullptr;
// setter of creator
NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*)
this->fcreator = f;
return *this;
}
// setter of creator
NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*)
this->fglobal_key = f;
return *this;
}
// global registry singleton
TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
};
/*!
* \brief Register a Node type
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
#define TVM_STRINGIZE_DETAIL(x) #x #define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) #define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) #define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
......
...@@ -135,7 +135,7 @@ class BufferNode : public Node { ...@@ -135,7 +135,7 @@ class BufferNode : public Node {
/*! \brief constructor */ /*! \brief constructor */
BufferNode() {} BufferNode() {}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("shape", &shape); v->Visit("shape", &shape);
......
...@@ -61,7 +61,7 @@ class TargetNode : public Node { ...@@ -61,7 +61,7 @@ class TargetNode : public Node {
/*! \return the full device string to pass to codegen::Build */ /*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const; TVM_DLL const std::string& str() const;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name); v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name); v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type); v->Visit("device_type", &device_type);
...@@ -229,7 +229,7 @@ class BuildConfigNode : public Node { ...@@ -229,7 +229,7 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */ /*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false; bool disable_vectorize = false;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop); v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
...@@ -473,6 +473,8 @@ class GenericFuncNode : public Node { ...@@ -473,6 +473,8 @@ class GenericFuncNode : public Node {
/* \brief map from keys to registered functions */ /* \brief map from keys to registered functions */
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_; std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "GenericFunc"; static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node); TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
}; };
......
...@@ -54,7 +54,7 @@ struct ChannelNode : public Node { ...@@ -54,7 +54,7 @@ struct ChannelNode : public Node {
/*! \brief default data type in read/write */ /*! \brief default data type in read/write */
Type dtype; Type dtype;
// visit all attributes // visit all attributes
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var); v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
} }
......
...@@ -104,7 +104,7 @@ class LayoutNode : public Node { ...@@ -104,7 +104,7 @@ class LayoutNode : public Node {
*/ */
Array<IterVar> axes; Array<IterVar> axes;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("axes", &axes); v->Visit("axes", &axes);
} }
...@@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node { ...@@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node {
/*! \brief The destination layout */ /*! \brief The destination layout */
Layout dst_layout; Layout dst_layout;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("src_layout", &src_layout); v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout); v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule); v->Visit("forward_rule", &forward_rule);
......
...@@ -27,8 +27,10 @@ ...@@ -27,8 +27,10 @@
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <iostream>
#include "base.h" #include "base.h"
#include "dtype.h" #include "dtype.h"
#include "node/node.h"
#include "node/container.h" #include "node/container.h"
#include "node/ir_functor.h" #include "node/ir_functor.h"
#include "runtime/c_runtime_api.h" #include "runtime/c_runtime_api.h"
...@@ -110,7 +112,7 @@ class Variable : public ExprNode { ...@@ -110,7 +112,7 @@ class Variable : public ExprNode {
static Var make(DataType dtype, std::string name_hint); static Var make(DataType dtype, std::string name_hint);
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("name", &name_hint); v->Visit("name", &name_hint);
} }
...@@ -164,7 +166,7 @@ class IntImm : public ExprNode { ...@@ -164,7 +166,7 @@ class IntImm : public ExprNode {
/*! \brief the Internal value. */ /*! \brief the Internal value. */
int64_t value; int64_t value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -230,7 +232,7 @@ class RangeNode : public Node { ...@@ -230,7 +232,7 @@ class RangeNode : public Node {
RangeNode() {} RangeNode() {}
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min); v->Visit("min", &min);
v->Visit("extent", &extent); v->Visit("extent", &extent);
} }
...@@ -406,7 +408,7 @@ class IterVarNode : public Node { ...@@ -406,7 +408,7 @@ class IterVarNode : public Node {
*/ */
std::string thread_tag; std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dom", &dom); v->Visit("dom", &dom);
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("iter_type", &iter_type); v->Visit("iter_type", &iter_type);
...@@ -490,7 +492,7 @@ class IRPrinter { ...@@ -490,7 +492,7 @@ class IRPrinter {
}; };
// default print function for all nodes // default print function for all nodes
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n); IRPrinter(os).Print(n);
return os; return os;
} }
......
...@@ -45,7 +45,7 @@ class UIntImm : public ExprNode { ...@@ -45,7 +45,7 @@ class UIntImm : public ExprNode {
/*! \brief The constant value content. */ /*! \brief The constant value content. */
uint64_t value; uint64_t value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -62,7 +62,7 @@ class FloatImm : public ExprNode { ...@@ -62,7 +62,7 @@ class FloatImm : public ExprNode {
/*! \brief The constant value content. */ /*! \brief The constant value content. */
double value; double value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -79,7 +79,7 @@ class StringImm : public ExprNode { ...@@ -79,7 +79,7 @@ class StringImm : public ExprNode {
/*! \brief The constant value content. */ /*! \brief The constant value content. */
std::string value; std::string value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -99,7 +99,7 @@ class Cast : public ExprNode { ...@@ -99,7 +99,7 @@ class Cast : public ExprNode {
/*! \brief Original data type. */ /*! \brief Original data type. */
Expr value; Expr value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -122,7 +122,7 @@ class BinaryOpNode : public ExprNode { ...@@ -122,7 +122,7 @@ class BinaryOpNode : public ExprNode {
/*! \brief The right operand. */ /*! \brief The right operand. */
Expr b; Expr b;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->type)); v->Visit("dtype", &(this->type));
v->Visit("a", &a); v->Visit("a", &a);
v->Visit("b", &b); v->Visit("b", &b);
...@@ -214,7 +214,7 @@ class CmpOpNode : public ExprNode { ...@@ -214,7 +214,7 @@ class CmpOpNode : public ExprNode {
/*! \brief The right operand. */ /*! \brief The right operand. */
Expr b; Expr b;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->type)); v->Visit("dtype", &(this->type));
v->Visit("a", &a); v->Visit("a", &a);
v->Visit("b", &b); v->Visit("b", &b);
...@@ -278,7 +278,7 @@ class And : public ExprNode { ...@@ -278,7 +278,7 @@ class And : public ExprNode {
/*! \brief The right operand. */ /*! \brief The right operand. */
Expr b; Expr b;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->type)); v->Visit("dtype", &(this->type));
v->Visit("a", &a); v->Visit("a", &a);
v->Visit("b", &b); v->Visit("b", &b);
...@@ -298,7 +298,7 @@ class Or : public ExprNode { ...@@ -298,7 +298,7 @@ class Or : public ExprNode {
/*! \brief The right operand. */ /*! \brief The right operand. */
Expr b; Expr b;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("a", &a); v->Visit("a", &a);
v->Visit("b", &b); v->Visit("b", &b);
...@@ -316,7 +316,7 @@ class Not : public ExprNode { ...@@ -316,7 +316,7 @@ class Not : public ExprNode {
/*! \brief The input operand. */ /*! \brief The input operand. */
Expr a; Expr a;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("a", &a); v->Visit("a", &a);
} }
...@@ -343,7 +343,7 @@ class Select : public ExprNode { ...@@ -343,7 +343,7 @@ class Select : public ExprNode {
/*! \brief value to be returned when condition is false. */ /*! \brief value to be returned when condition is false. */
Expr false_value; Expr false_value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("condition", &condition); v->Visit("condition", &condition);
v->Visit("true_value", &true_value); v->Visit("true_value", &true_value);
...@@ -380,7 +380,7 @@ class Load : public ExprNode { ...@@ -380,7 +380,7 @@ class Load : public ExprNode {
/*! \brief The predicate to mask which lanes would be loaded. */ /*! \brief The predicate to mask which lanes would be loaded. */
Expr predicate; Expr predicate;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("buffer_var", &buffer_var); v->Visit("buffer_var", &buffer_var);
v->Visit("index", &index); v->Visit("index", &index);
...@@ -411,7 +411,7 @@ class Ramp : public ExprNode { ...@@ -411,7 +411,7 @@ class Ramp : public ExprNode {
/*! \brief Total number of lanes. */ /*! \brief Total number of lanes. */
int lanes; int lanes;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("base", &base); v->Visit("base", &base);
v->Visit("stride", &stride); v->Visit("stride", &stride);
...@@ -432,7 +432,7 @@ class Broadcast : public ExprNode { ...@@ -432,7 +432,7 @@ class Broadcast : public ExprNode {
/*! \brief The numerb of lanes. */ /*! \brief The numerb of lanes. */
int lanes; int lanes;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("lanes", &lanes); v->Visit("lanes", &lanes);
...@@ -456,7 +456,7 @@ class Let : public ExprNode { ...@@ -456,7 +456,7 @@ class Let : public ExprNode {
/*! \brief The result expression. */ /*! \brief The result expression. */
Expr body; Expr body;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("value", &value); v->Visit("value", &value);
...@@ -522,7 +522,7 @@ class Call : public ExprNode { ...@@ -522,7 +522,7 @@ class Call : public ExprNode {
/*! \brief The output value index if func's value is a tuple. */ /*! \brief The output value index if func's value is a tuple. */
int value_index{0}; int value_index{0};
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("args", &args); v->Visit("args", &args);
...@@ -592,7 +592,7 @@ class Shuffle : public ExprNode { ...@@ -592,7 +592,7 @@ class Shuffle : public ExprNode {
/*! \brief The indices of each element. */ /*! \brief The indices of each element. */
Array<Expr> indices; Array<Expr> indices;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("vectors", &vectors); v->Visit("vectors", &vectors);
v->Visit("indices", &indices); v->Visit("indices", &indices);
} }
...@@ -652,7 +652,7 @@ class CommReducerNode : public Node { ...@@ -652,7 +652,7 @@ class CommReducerNode : public Node {
Array<Expr> result, Array<Expr> result,
Array<Expr> identity_element); Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("lhs", &lhs); v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs); v->Visit("rhs", &rhs);
v->Visit("result", &result); v->Visit("result", &result);
...@@ -694,7 +694,7 @@ class Reduce : public ExprNode { ...@@ -694,7 +694,7 @@ class Reduce : public ExprNode {
Expr condition, Expr condition,
int value_index); int value_index);
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("combiner", &combiner); v->Visit("combiner", &combiner);
v->Visit("source", &source); v->Visit("source", &source);
...@@ -710,7 +710,7 @@ class Reduce : public ExprNode { ...@@ -710,7 +710,7 @@ class Reduce : public ExprNode {
/*! \brief Any shape. */ /*! \brief Any shape. */
class Any : public ExprNode { class Any : public ExprNode {
public: public:
void VisitAttrs(AttrVisitor* v) final {} void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */ /*! \brief Convert to var. */
Var ToVar() const { Var ToVar() const {
return Variable::make(Int(32), "any_dim"); return Variable::make(Int(32), "any_dim");
...@@ -735,7 +735,7 @@ class LetStmt : public StmtNode { ...@@ -735,7 +735,7 @@ class LetStmt : public StmtNode {
/*! \brief The body block. */ /*! \brief The body block. */
Stmt body; Stmt body;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("body", &body); v->Visit("body", &body);
...@@ -768,7 +768,7 @@ class AttrStmt : public StmtNode { ...@@ -768,7 +768,7 @@ class AttrStmt : public StmtNode {
/*! \brief The body statement to be executed */ /*! \brief The body statement to be executed */
Stmt body; Stmt body;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("node", &node); v->Visit("node", &node);
v->Visit("attr_key", &attr_key); v->Visit("attr_key", &attr_key);
v->Visit("value", &value); v->Visit("value", &value);
...@@ -799,7 +799,7 @@ class AssertStmt : public StmtNode { ...@@ -799,7 +799,7 @@ class AssertStmt : public StmtNode {
*/ */
Stmt body; Stmt body;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition); v->Visit("condition", &condition);
v->Visit("message", &message); v->Visit("message", &message);
v->Visit("body", &body); v->Visit("body", &body);
...@@ -822,7 +822,7 @@ class ProducerConsumer : public StmtNode { ...@@ -822,7 +822,7 @@ class ProducerConsumer : public StmtNode {
/*! \brief Body to be executed. */ /*! \brief Body to be executed. */
Stmt body; Stmt body;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("is_producer", &is_producer); v->Visit("is_producer", &is_producer);
v->Visit("body", &body); v->Visit("body", &body);
...@@ -863,7 +863,7 @@ class Store : public StmtNode { ...@@ -863,7 +863,7 @@ class Store : public StmtNode {
/*! \brief The predicate to mask which lanes would be stored. */ /*! \brief The predicate to mask which lanes would be stored. */
Expr predicate; Expr predicate;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var); v->Visit("buffer_var", &buffer_var);
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("index", &index); v->Visit("index", &index);
...@@ -893,7 +893,7 @@ class Provide : public StmtNode { ...@@ -893,7 +893,7 @@ class Provide : public StmtNode {
/*! \brief The index arguments of the function. */ /*! \brief The index arguments of the function. */
Array<Expr> args; Array<Expr> args;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
v->Visit("value", &value); v->Visit("value", &value);
...@@ -929,7 +929,7 @@ class Allocate : public StmtNode { ...@@ -929,7 +929,7 @@ class Allocate : public StmtNode {
Expr new_expr; Expr new_expr;
std::string free_function; std::string free_function;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var); v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("extents", &extents); v->Visit("extents", &extents);
...@@ -972,7 +972,7 @@ class Free : public StmtNode { ...@@ -972,7 +972,7 @@ class Free : public StmtNode {
/*! \brief The buffer variable. */ /*! \brief The buffer variable. */
Var buffer_var; Var buffer_var;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var); v->Visit("buffer_var", &buffer_var);
} }
...@@ -1001,7 +1001,7 @@ class Realize : public StmtNode { ...@@ -1001,7 +1001,7 @@ class Realize : public StmtNode {
/*! \brief The body of realization. */ /*! \brief The body of realization. */
Stmt body; Stmt body;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
v->Visit("dtype", &type); v->Visit("dtype", &type);
...@@ -1031,7 +1031,7 @@ class Block : public StmtNode { ...@@ -1031,7 +1031,7 @@ class Block : public StmtNode {
/*! \brief The restof statments. */ /*! \brief The restof statments. */
Stmt rest; Stmt rest;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("first", &first); v->Visit("first", &first);
v->Visit("rest", &rest); v->Visit("rest", &rest);
} }
...@@ -1055,7 +1055,7 @@ class IfThenElse : public StmtNode { ...@@ -1055,7 +1055,7 @@ class IfThenElse : public StmtNode {
/*! \brief The branch to be executed when condition is false, can be null. */ /*! \brief The branch to be executed when condition is false, can be null. */
Stmt else_case; Stmt else_case;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition); v->Visit("condition", &condition);
v->Visit("then_case", &then_case); v->Visit("then_case", &then_case);
v->Visit("else_case", &else_case); v->Visit("else_case", &else_case);
...@@ -1078,7 +1078,7 @@ class Evaluate : public StmtNode { ...@@ -1078,7 +1078,7 @@ class Evaluate : public StmtNode {
/*! \brief The expression to be evaluated. */ /*! \brief The expression to be evaluated. */
Expr value; Expr value;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -1142,7 +1142,7 @@ class For : public StmtNode { ...@@ -1142,7 +1142,7 @@ class For : public StmtNode {
DeviceAPI device_api, DeviceAPI device_api,
Stmt body); Stmt body);
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var); v->Visit("loop_var", &loop_var);
v->Visit("min", &min); v->Visit("min", &min);
v->Visit("extent", &extent); v->Visit("extent", &extent);
...@@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode { ...@@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode {
/*! \brief Bounds to be prefetched. */ /*! \brief Bounds to be prefetched. */
Region bounds; Region bounds;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
v->Visit("type", &type); v->Visit("type", &type);
......
...@@ -119,7 +119,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode { ...@@ -119,7 +119,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
int num_outputs() const final { int num_outputs() const final {
return 1; return 1;
} }
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("args", &args); v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis); v->Visit("thread_axis", &thread_axis);
......
...@@ -40,8 +40,7 @@ class ArrayNode : public Node { ...@@ -40,8 +40,7 @@ class ArrayNode : public Node {
/*! \brief the data content */ /*! \brief the data content */
std::vector<ObjectRef> data; std::vector<ObjectRef> data;
void VisitAttrs(AttrVisitor* visitor) final { void VisitAttrs(AttrVisitor* visitor) {
// Visitor to array have no effect.
} }
static constexpr const char* _type_key = "Array"; static constexpr const char* _type_key = "Array";
...@@ -51,9 +50,9 @@ class ArrayNode : public Node { ...@@ -51,9 +50,9 @@ class ArrayNode : public Node {
/*! \brief map node content */ /*! \brief map node content */
class MapNode : public Node { class MapNode : public Node {
public: public:
void VisitAttrs(AttrVisitor* visitor) final { void VisitAttrs(AttrVisitor* visitor) {
// Visitor to map have no effect.
} }
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map< using ContainerType = std::unordered_map<
ObjectRef, ObjectRef,
...@@ -71,12 +70,12 @@ class MapNode : public Node { ...@@ -71,12 +70,12 @@ class MapNode : public Node {
/*! \brief specialized map node with string as key */ /*! \brief specialized map node with string as key */
class StrMapNode : public Node { class StrMapNode : public Node {
public: public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>; using ContainerType = std::unordered_map<std::string, ObjectRef>;
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief the data content */ /*! \brief the data content */
ContainerType data; ContainerType data;
......
...@@ -18,113 +18,68 @@ ...@@ -18,113 +18,68 @@
*/ */
/*! /*!
* \file tvm/node/node.h * \file tvm/node/node.h
* \brief Node system data structure. * \brief Definitions and helper macros for IR/AST nodes.
*
* The node folder contains base utilities for IR/AST nodes,
* invariant of which specific language dialect.
*
* We implement AST/IR nodes as sub-classes of runtime::Object.
* The base class Node is just an alias of runtime::Object.
*
* Besides the runtime type checking provided by Object,
* node folder contains additional functionalities such as
* reflection and serialization, which are important features
* for building a compiler infra.
*/ */
#ifndef TVM_NODE_NODE_H_ #ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/ndarray.h> #include <tvm/node/reflection.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <type_traits> #include <type_traits>
namespace tvm { namespace tvm {
// forward declaration
class DataType;
class Node;
class NodeRef;
/*! using runtime::TypeIndex;
* \brief Visitor class to each node content. using runtime::Object;
* The content is going to be called for each field. using runtime::ObjectPtr;
*/ using runtime::ObjectRef;
class TVM_DLL AttrVisitor { using runtime::GetRef;
public: using runtime::Downcast;
//! \cond Doxygen_Suppress using runtime::ObjectHash;
virtual ~AttrVisitor() = default; using runtime::ObjectEqual;
virtual void Visit(const char* key, double* value) = 0; using runtime::make_object;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*! \brief Reuse the type index in he runtime. */ using NodeHash = ObjectHash;
using TypeIndex = runtime::TypeIndex; using NodeEqual = ObjectEqual;
using Node = Object;
/*! /*!
* \brief base class of node container in DSL AST. * \brief Base class of all references to AST/IR nodes.
*/ */
class Node : public runtime::Object { class NodeRef : public ObjectRef {
public: public:
/*! \brief virtual destructor */ NodeRef() {}
virtual ~Node() {} explicit NodeRef(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
static constexpr const char* _type_key = "Node";
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object);
}; };
/*! /*!
* \brief Base class of all node reference object * \brief Allocate a node object.
* NodeRef is just a alias of ObjectRef. * \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
* \note This function is an alias of make_object.
*/ */
class NodeRef : public runtime::ObjectRef { template<typename T, typename... Args>
public: inline NodePtr<T> make_node(Args&&... args) {
/*! \brief type indicate the container type */ return runtime::make_object<T>(std::forward<Args>(args)...);
using ContainerType = Node; }
/*! \return the internal node pointer */
const Node* get() const {
return static_cast<const Node*>(ObjectRef::get());
}
/*! \return the internal node pointer */
const Node* operator->() const {
return get();
}
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
const T *as_derived() const {
return as<T>();
}
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(runtime::ObjectPtr<runtime::Object> ptr) : ObjectRef(ptr) {}
};
/*! /*!
* \brief helper macro to declare type information in a base node. * \brief helper macro to declare type information in a base node.
...@@ -139,27 +94,67 @@ class NodeRef : public runtime::ObjectRef { ...@@ -139,27 +94,67 @@ class NodeRef : public runtime::ObjectRef {
TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent);
using runtime::Object; /*!
using runtime::ObjectPtr; * \brief Macro to define common node ref methods.
using runtime::ObjectRef; * \param TypeName The name of the NodeRef.
using runtime::GetRef; * \param BaseTypeName The Base type.
using runtime::Downcast; * \param NodeName The node container type.
using runtime::make_object; */
using runtime::ObjectHash; #define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
using runtime::ObjectEqual; TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
using NodeHash = ObjectHash; /*!
using NodeEqual = ObjectEqual; * \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*! /*!
* \brief Allocate a node object. * \brief Macro to make it easy to define node ref type that
* \param args arguments to the constructor. * has a CopyOnWrite member function.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/ */
template<typename T, typename... Args> #define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
inline NodePtr<T> make_node(Args&&... args) { class TypeName : public BaseType { \
return runtime::make_object<T>(std::forward<Args>(args)...); public: \
} TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_NODE_H_ #endif // TVM_NODE_NODE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/reflection.h
* \brief Reflection and serialization of compiler IR/AST nodes.
*/
#ifndef TVM_NODE_REFLECTION_H_
#define TVM_NODE_REFLECTION_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
#include <string>
namespace tvm {
// forward declaration
class DataType;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
/*!
* \brief Visitor class for to get the attributesof a AST/IR node.
* The content is going to be called for each field.
*
* Each objects that wants reflection will need to implement
* a VisitAttrs function and call visitor->Visit on each of its field.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual ~AttrVisitor() = default;
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief Virtual function table to support IR/AST node reflection.
*
* Functions are stored in columar manner.
* Each column is a vector indexed by Object's type_index.
*/
class ReflectionVTable {
public:
/*!
* \brief Visitor function.
* \note We use function pointer, instead of std::function
* to reduce the dispatch overhead as field visit
* does not need as much customization.
*/
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
* \return The created function.
*/
using FCreate = std::function<ObjectPtr<Object>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/
using FGlobalKey = std::function<std::string(const Object* self)>;
/*!
* \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object.
* \param visitor The attribute visitor.
*/
inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
/*!
* \brief Get global key of the object, if any.
* \param self The pointer to the object.
* \return the global key if object has one, otherwise return empty string.
*/
inline std::string GetGlobalKey(Object* self) const;
/*!
* \brief Create an initial object using default constructor
* by type_key and global key.
*
* \param type_key The type key of the object.
* \param global_key A global key that can be used to uniquely identify the object if any.
*/
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
const std::string& global_key = "") const;
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
* \param attr_name The name of the field.
* \return The corresponding attribute value.
* \note This function will throw an exception if the object does not contain the field.
*/
TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const;
/*!
* \brief List all the fields in the object.
* \return All the fields.
*/
TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
/*! \return The global singleton. */
TVM_DLL static ReflectionVTable* Global();
class Registry;
template<typename T>
inline Registry Register();
private:
/*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
std::vector<FGlobalKey> fglobal_key_;
};
/*! \brief Registry of a reflection table. */
class ReflectionVTable::Registry {
public:
Registry(ReflectionVTable* parent, uint32_t type_index)
: parent_(parent), type_index_(type_index) { }
/*!
* \brief Set fcreate function.
* \param f The creator function.
* \return rference to self.
*/
Registry& set_creator(FCreate f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fcreate_.size());
parent_->fcreate_[type_index_] = f;
return *this;
}
/*!
* \brief Set global_key function.
* \param f The creator function.
* \return rference to self.
*/
Registry& set_global_key(FGlobalKey f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fglobal_key_.size());
parent_->fglobal_key_[type_index_] = f;
return *this;
}
private:
ReflectionVTable* parent_;
uint32_t type_index_;
};
/*!
* \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type.
* \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \
__make_Node ## _ ## TypeName ## __ = \
::tvm::ReflectionVTable::Global()->Register<TypeName>() \
.set_creator([](const std::string&) { \
return ::tvm::runtime::make_object<TypeName>(); \
})
// Implementation details
template<typename T>
inline ReflectionVTable::Registry
ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex();
if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
struct Functor {
static void VisitAttrs(Object* self, AttrVisitor* v) {
static_cast<T*>(self)->VisitAttrs(v);
}
};
fvisit_attrs_[tindex] = Functor::VisitAttrs;
return Registry(this, tindex);
}
inline void ReflectionVTable::
VisitAttrs(Object* self, AttrVisitor* visitor) const {
uint32_t tindex = self->type_index();
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
fvisit_attrs_[tindex](self, visitor);
}
inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
uint32_t tindex = self->type_index();
if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
return fglobal_key_[tindex](self);
} else {
return std::string();
}
}
} // namespace tvm
#endif // TVM_NODE_REFLECTION_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Utility functions for serialization.
* \file tvm/node/serialization.h
*/
#ifndef TVM_NODE_SERIALIZATION_H_
#define TVM_NODE_SERIALIZATION_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <string>
namespace tvm {
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node);
/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
TVM_DLL runtime::ObjectRef LoadJSON(std::string json_str);
} // namespace tvm
#endif // TVM_NODE_SERIALIZATION_H_
...@@ -188,7 +188,7 @@ class PlaceholderOpNode : public OperationNode { ...@@ -188,7 +188,7 @@ class PlaceholderOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
...@@ -259,7 +259,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { ...@@ -259,7 +259,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
bool debug_keep_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final; size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
...@@ -312,7 +312,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { ...@@ -312,7 +312,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
bool debug_keep_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final; size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("axis", &axis); v->Visit("axis", &axis);
...@@ -394,7 +394,7 @@ class ScanOpNode : public OperationNode { ...@@ -394,7 +394,7 @@ class ScanOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
...@@ -461,7 +461,7 @@ class ExternOpNode : public OperationNode { ...@@ -461,7 +461,7 @@ class ExternOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
...@@ -529,7 +529,7 @@ class HybridOpNode : public OperationNode { ...@@ -529,7 +529,7 @@ class HybridOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
/*! /*!
* \file tvm/packed_func_ext.h * \file tvm/packed_func_ext.h
* \brief Extension package to PackedFunc * \brief Extension package to PackedFunc
* This enales pass NodeRef types into/from PackedFunc. * This enales pass ObjectRef types into/from PackedFunc.
*/ */
#ifndef TVM_PACKED_FUNC_EXT_H_ #ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_ #define TVM_PACKED_FUNC_EXT_H_
...@@ -129,18 +129,18 @@ inline std::string ObjectTypeName() { ...@@ -129,18 +129,18 @@ inline std::string ObjectTypeName() {
// extensions for tvm arg value // extensions for tvm arg value
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef TVMArgValue::AsNodeRef() const { inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert( static_assert(
std::is_base_of<NodeRef, TNodeRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for NodeRef"); "Conversion only works for ObjectRef");
if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr)); if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TNodeRef>::Check(ptr)) CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TNodeRef>() << "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return TNodeRef(ObjectPtr<Node>(ptr)); return TObjectRef(ObjectPtr<Node>(ptr));
} }
inline TVMArgValue::operator tvm::Expr() const { inline TVMArgValue::operator tvm::Expr() const {
...@@ -184,28 +184,28 @@ inline TVMArgValue::operator tvm::Integer() const { ...@@ -184,28 +184,28 @@ inline TVMArgValue::operator tvm::Integer() const {
return Integer(ObjectPtr<Node>(ptr)); return Integer(ObjectPtr<Node>(ptr));
} }
template<typename TNodeRef, typename> template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const { inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TNodeRef>::Check(ptr); return ObjectTypeChecker<TObjectRef>::Check(ptr);
} }
// extensions for TVMRetValue // extensions for TVMRetValue
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef TVMRetValue::AsNodeRef() const { inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert( static_assert(
std::is_base_of<NodeRef, TNodeRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for NodeRef"); "Conversion only works for ObjectRef");
if (type_code_ == kNull) return TNodeRef(); if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TNodeRef>::Check(ptr)) CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TNodeRef>() << "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return TNodeRef(ObjectPtr<Object>(ptr)); return TObjectRef(ObjectPtr<Object>(ptr));
} }
// type related stuffs // type related stuffs
......
...@@ -66,7 +66,7 @@ class PatternWildcardNode : public PatternNode { ...@@ -66,7 +66,7 @@ class PatternWildcardNode : public PatternNode {
TVM_DLL static PatternWildcard make(); TVM_DLL static PatternWildcard make();
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -88,7 +88,7 @@ class PatternVarNode : public PatternNode { ...@@ -88,7 +88,7 @@ class PatternVarNode : public PatternNode {
TVM_DLL static PatternVar make(tvm::relay::Var var); TVM_DLL static PatternVar make(tvm::relay::Var var);
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -122,7 +122,7 @@ class ConstructorNode : public ExprNode { ...@@ -122,7 +122,7 @@ class ConstructorNode : public ExprNode {
tvm::Array<Type> inputs, tvm::Array<Type> inputs,
GlobalTypeVar belong_to); GlobalTypeVar belong_to);
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to); v->Visit("belong_to", &belong_to);
...@@ -151,7 +151,7 @@ class PatternConstructorNode : public PatternNode { ...@@ -151,7 +151,7 @@ class PatternConstructorNode : public PatternNode {
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var); TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("constructor", &constructor); v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns); v->Visit("patterns", &patterns);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -175,7 +175,7 @@ class PatternTupleNode : public PatternNode { ...@@ -175,7 +175,7 @@ class PatternTupleNode : public PatternNode {
TVM_DLL static PatternTuple make(tvm::Array<Pattern> var); TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("patterns", &patterns); v->Visit("patterns", &patterns);
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -213,7 +213,7 @@ class TypeDataNode : public TypeNode { ...@@ -213,7 +213,7 @@ class TypeDataNode : public TypeNode {
/*! \brief The constructors. */ /*! \brief The constructors. */
tvm::Array<Constructor> constructors; tvm::Array<Constructor> constructors;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("header", &header); v->Visit("header", &header);
v->Visit("type_vars", &type_vars); v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors); v->Visit("constructors", &constructors);
...@@ -240,7 +240,7 @@ class ClauseNode : public Node { ...@@ -240,7 +240,7 @@ class ClauseNode : public Node {
/*! \brief The resulting value. */ /*! \brief The resulting value. */
Expr rhs; Expr rhs;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("lhs", &lhs); v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs); v->Visit("rhs", &rhs);
} }
...@@ -269,7 +269,7 @@ class MatchNode : public ExprNode { ...@@ -269,7 +269,7 @@ class MatchNode : public ExprNode {
*/ */
bool complete; bool complete;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("clauses", &clauses); v->Visit("clauses", &clauses);
v->Visit("complete", &complete); v->Visit("complete", &complete);
......
...@@ -107,7 +107,7 @@ class SourceNameNode : public Node { ...@@ -107,7 +107,7 @@ class SourceNameNode : public Node {
/*! \brief The source name. */ /*! \brief The source name. */
std::string name; std::string name;
// override attr visitor // override attr visitor
void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName"; static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
...@@ -160,7 +160,7 @@ class SpanNode : public Node { ...@@ -160,7 +160,7 @@ class SpanNode : public Node {
/*! \brief column offset */ /*! \brief column offset */
int col_offset; int col_offset;
// override attr visitor // override attr visitor
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("source", &source); v->Visit("source", &source);
v->Visit("lineno", &lineno); v->Visit("lineno", &lineno);
v->Visit("col_offset", &col_offset); v->Visit("col_offset", &col_offset);
...@@ -204,7 +204,7 @@ class IdNode : public Node { ...@@ -204,7 +204,7 @@ class IdNode : public Node {
*/ */
std::string name_hint; std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
} }
......
...@@ -95,7 +95,7 @@ class ConstantNode : public ExprNode { ...@@ -95,7 +95,7 @@ class ConstantNode : public ExprNode {
return data->ndim == 0; return data->ndim == 0;
} }
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -117,7 +117,7 @@ class TupleNode : public ExprNode { ...@@ -117,7 +117,7 @@ class TupleNode : public ExprNode {
/*! \brief the fields of the tuple */ /*! \brief the fields of the tuple */
tvm::Array<relay::Expr> fields; tvm::Array<relay::Expr> fields;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields); v->Visit("fields", &fields);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -165,7 +165,7 @@ class VarNode : public ExprNode { ...@@ -165,7 +165,7 @@ class VarNode : public ExprNode {
return vid->name_hint; return vid->name_hint;
} }
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid); v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation); v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -197,7 +197,7 @@ class GlobalVarNode : public ExprNode { ...@@ -197,7 +197,7 @@ class GlobalVarNode : public ExprNode {
/*! \brief The name of the variable, this only acts as a hint. */ /*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint; std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -243,7 +243,7 @@ class FunctionNode : public ExprNode { ...@@ -243,7 +243,7 @@ class FunctionNode : public ExprNode {
*/ */
tvm::Attrs attrs; tvm::Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params); v->Visit("params", &params);
v->Visit("body", &body); v->Visit("body", &body);
v->Visit("ret_type", &ret_type); v->Visit("ret_type", &ret_type);
...@@ -327,7 +327,7 @@ class CallNode : public ExprNode { ...@@ -327,7 +327,7 @@ class CallNode : public ExprNode {
*/ */
tvm::Array<Type> type_args; tvm::Array<Type> type_args;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("args", &args); v->Visit("args", &args);
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
...@@ -369,7 +369,7 @@ class LetNode : public ExprNode { ...@@ -369,7 +369,7 @@ class LetNode : public ExprNode {
/*! \brief The body of the let binding */ /*! \brief The body of the let binding */
Expr body; Expr body;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("body", &body); v->Visit("body", &body);
...@@ -407,7 +407,7 @@ class IfNode : public ExprNode { ...@@ -407,7 +407,7 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */ /*! \brief The expression evaluated when condition is false */
Expr false_branch; Expr false_branch;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond); v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch); v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch); v->Visit("false_branch", &false_branch);
...@@ -432,7 +432,7 @@ class TupleGetItemNode : public ExprNode { ...@@ -432,7 +432,7 @@ class TupleGetItemNode : public ExprNode {
/*! \brief which value to get */ /*! \brief which value to get */
int index; int index;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple); v->Visit("tuple_value", &tuple);
v->Visit("index", &index); v->Visit("index", &index);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -454,7 +454,7 @@ class RefCreateNode : public ExprNode { ...@@ -454,7 +454,7 @@ class RefCreateNode : public ExprNode {
/*! \brief The initial value of the Reference. */ /*! \brief The initial value of the Reference. */
Expr value; Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -475,7 +475,7 @@ class RefReadNode : public ExprNode { ...@@ -475,7 +475,7 @@ class RefReadNode : public ExprNode {
/*! \brief The Reference Expression. */ /*! \brief The Reference Expression. */
Expr ref; Expr ref;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref); v->Visit("ref", &ref);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -498,7 +498,7 @@ class RefWriteNode : public ExprNode { ...@@ -498,7 +498,7 @@ class RefWriteNode : public ExprNode {
/*! \brief The value to write into. */ /*! \brief The value to write into. */
Expr value; Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref); v->Visit("ref", &ref);
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("span", &span); v->Visit("span", &span);
......
...@@ -106,7 +106,7 @@ class ClosureNode : public ValueNode { ...@@ -106,7 +106,7 @@ class ClosureNode : public ValueNode {
ClosureNode() {} ClosureNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env); v->Visit("env", &env);
v->Visit("func", &func); v->Visit("func", &func);
} }
...@@ -154,7 +154,7 @@ struct TupleValueNode : ValueNode { ...@@ -154,7 +154,7 @@ struct TupleValueNode : ValueNode {
TupleValueNode() {} TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<Value> value); TVM_DLL static TupleValue make(tvm::Array<Value> value);
...@@ -173,7 +173,7 @@ struct TensorValueNode : ValueNode { ...@@ -173,7 +173,7 @@ struct TensorValueNode : ValueNode {
TensorValueNode() {} TensorValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
/*! \brief Build a value from an NDArray. */ /*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data); TVM_DLL static TensorValue make(runtime::NDArray data);
...@@ -192,7 +192,7 @@ struct RefValueNode : ValueNode { ...@@ -192,7 +192,7 @@ struct RefValueNode : ValueNode {
RefValueNode() {} RefValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value); v->Visit("value", &value);
} }
...@@ -215,7 +215,7 @@ struct ConstructorValueNode : ValueNode { ...@@ -215,7 +215,7 @@ struct ConstructorValueNode : ValueNode {
/*! \brief Optional field tracking ADT constructor. */ /*! \brief Optional field tracking ADT constructor. */
Constructor constructor; Constructor constructor;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tag", &tag); v->Visit("tag", &tag);
v->Visit("fields", &fields); v->Visit("fields", &fields);
v->Visit("constructor", &constructor); v->Visit("constructor", &constructor);
......
...@@ -68,7 +68,7 @@ class ModuleNode : public RelayNode { ...@@ -68,7 +68,7 @@ class ModuleNode : public RelayNode {
ModuleNode() {} ModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("functions", &functions); v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions); v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_); v->Visit("global_var_map_", &global_var_map_);
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#ifndef TVM_RELAY_OP_H_ #ifndef TVM_RELAY_OP_H_
#define TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_
#include <dmlc/registry.h>
#include <functional> #include <functional>
#include <limits> #include <limits>
#include <string> #include <string>
...@@ -82,7 +84,7 @@ class OpNode : public relay::ExprNode { ...@@ -82,7 +84,7 @@ class OpNode : public relay::ExprNode {
*/ */
int32_t support_level = 10; int32_t support_level = 10;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("op_type", &op_type); v->Visit("op_type", &op_type);
v->Visit("description", &description); v->Visit("description", &description);
......
...@@ -101,7 +101,7 @@ class PassContextNode : public RelayNode { ...@@ -101,7 +101,7 @@ class PassContextNode : public RelayNode {
PassContextNode() = default; PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level); v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device); v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass); v->Visit("required_pass", &required_pass);
...@@ -196,7 +196,7 @@ class PassInfoNode : public RelayNode { ...@@ -196,7 +196,7 @@ class PassInfoNode : public RelayNode {
PassInfoNode() = default; PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level); v->Visit("opt_level", &opt_level);
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("required", &required); v->Visit("required", &required);
...@@ -221,6 +221,7 @@ class Pass; ...@@ -221,6 +221,7 @@ class Pass;
*/ */
class PassNode : public RelayNode { class PassNode : public RelayNode {
public: public:
virtual ~PassNode() {}
/*! /*!
* \brief Get the pass information/meta data. */ * \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0; virtual PassInfo Info() const = 0;
...@@ -247,7 +248,7 @@ class PassNode : public RelayNode { ...@@ -247,7 +248,7 @@ class PassNode : public RelayNode {
virtual Module operator()(const Module& mod, virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0; const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {} void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass"; static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
......
...@@ -96,7 +96,7 @@ class TensorTypeNode : public BaseTensorTypeNode { ...@@ -96,7 +96,7 @@ class TensorTypeNode : public BaseTensorTypeNode {
/*! \brief The content data type */ /*! \brief The content data type */
DataType dtype; DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -159,7 +159,7 @@ class TypeVarNode : public TypeNode { ...@@ -159,7 +159,7 @@ class TypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
Kind kind; Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -188,7 +188,7 @@ class GlobalTypeVarNode : public TypeNode { ...@@ -188,7 +188,7 @@ class GlobalTypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
Kind kind; Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -216,7 +216,7 @@ class TypeCallNode : public TypeNode { ...@@ -216,7 +216,7 @@ class TypeCallNode : public TypeNode {
/*! \brief The arguments. */ /*! \brief The arguments. */
tvm::Array<Type> args; tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("args", &args); v->Visit("args", &args);
v->Visit("span", &span); v->Visit("span", &span);
...@@ -245,7 +245,7 @@ class IncompleteTypeNode : public TypeNode { ...@@ -245,7 +245,7 @@ class IncompleteTypeNode : public TypeNode {
public: public:
Kind kind; Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -297,7 +297,7 @@ class FuncTypeNode : public TypeNode { ...@@ -297,7 +297,7 @@ class FuncTypeNode : public TypeNode {
*/ */
tvm::Array<TypeConstraint> type_constraints; tvm::Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("arg_types", &arg_types); v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type); v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params); v->Visit("type_params", &type_params);
...@@ -330,7 +330,7 @@ class TupleTypeNode : public TypeNode { ...@@ -330,7 +330,7 @@ class TupleTypeNode : public TypeNode {
TupleTypeNode() {} TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields); v->Visit("fields", &fields);
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -357,7 +357,7 @@ class RefTypeNode : public TypeNode { ...@@ -357,7 +357,7 @@ class RefTypeNode : public TypeNode {
RefTypeNode() {} RefTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -417,7 +417,7 @@ class TypeReporterNode : public Node { ...@@ -417,7 +417,7 @@ class TypeReporterNode : public Node {
TVM_DLL virtual Module GetModule() = 0; TVM_DLL virtual Module GetModule() = 0;
// solver is not serializable. // solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {} void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter"; static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
...@@ -488,7 +488,7 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -488,7 +488,7 @@ class TypeRelationNode : public TypeConstraintNode {
/*! \brief Attributes to the relation function */ /*! \brief Attributes to the relation function */
Attrs attrs; Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("args", &args); v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs); v->Visit("num_inputs", &num_inputs);
......
...@@ -230,6 +230,7 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*) ...@@ -230,6 +230,7 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")"; os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
return os; return os;
} }
#endif #endif
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -82,6 +82,8 @@ class SimpleObjAllocator : ...@@ -82,6 +82,8 @@ class SimpleObjAllocator :
template<typename T> template<typename T>
class Handler { class Handler {
public: public:
using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
template<typename... Args> template<typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) { static T* New(SimpleObjAllocator*, Args&&... args) {
// NOTE: the first argument is not needed for SimpleObjAllocator // NOTE: the first argument is not needed for SimpleObjAllocator
...@@ -91,7 +93,15 @@ class SimpleObjAllocator : ...@@ -91,7 +93,15 @@ class SimpleObjAllocator :
// In the case of an object pool, an allocator needs to create // In the case of an object pool, an allocator needs to create
// a special chunk memory that hides reference to the allocator // a special chunk memory that hides reference to the allocator
// and call allocator's release function in the deleter. // and call allocator's release function in the deleter.
return new T(std::forward<Args>(args)...);
// NOTE2: Use inplace new to allocate
// This is used to get rid of warning when deleting a virtual
// class with non-virtual destructor.
// We are fine here as we captured the right deleter during construction.
// This is also the right way to get storage type for an object pool.
StorageType* data = new StorageType();
new (data) T(std::forward<Args>(args)...);
return reinterpret_cast<T*>(data);
} }
static Object::FDeleter Deleter() { static Object::FDeleter Deleter() {
...@@ -99,8 +109,17 @@ class SimpleObjAllocator : ...@@ -99,8 +109,17 @@ class SimpleObjAllocator :
} }
private: private:
static void Deleter_(Object* ptr) { static void Deleter_(Object* objptr) {
delete static_cast<T*>(ptr); // NOTE: this is important to cast back to T*
// because objptr and tptr may not be the same
// depending on how sub-class allocates the space.
T* tptr = static_cast<T*>(objptr);
// It is important to do tptr->T::~T(),
// so that we explicitly call the specific destructor
// instead of tptr->~T(), which could mean the intention
// call a virtual destructor(which may not be available and is not required).
tptr->T::~T();
delete reinterpret_cast<StorageType*>(tptr);
} }
}; };
}; };
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#ifndef TVM_RUNTIME_OBJECT_H_ #ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <type_traits> #include <type_traits>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -189,7 +190,7 @@ class Object { ...@@ -189,7 +190,7 @@ class Object {
* \param key The type key. * \param key The type key.
* \return the result. * \return the result.
*/ */
TVM_DLL static uint32_t TypeKey2Index(const char* key); TVM_DLL static uint32_t TypeKey2Index(const std::string& key);
#if TVM_OBJECT_ATOMIC_REF_COUNTER #if TVM_OBJECT_ATOMIC_REF_COUNTER
using RefCounterType = std::atomic<int32_t>; using RefCounterType = std::atomic<int32_t>;
...@@ -197,18 +198,24 @@ class Object { ...@@ -197,18 +198,24 @@ class Object {
using RefCounterType = int32_t; using RefCounterType = int32_t;
#endif #endif
// Object type properties
static constexpr const char* _type_key = "Object"; static constexpr const char* _type_key = "Object";
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
static uint32_t _GetOrAllocRuntimeTypeIndex() { static uint32_t _GetOrAllocRuntimeTypeIndex() {
return 0; return TypeIndex::kRoot;
} }
static uint32_t RuntimeTypeIndex() { static uint32_t RuntimeTypeIndex() {
return 0; return TypeIndex::kRoot;
} }
// Default object type properties for sub-classes
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
// NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
// Default constructor and copy constructor // Default constructor and copy constructor
Object() {} Object() {}
// Override the copy and assign constructors to do nothing. // Override the copy and assign constructors to do nothing.
...@@ -262,13 +269,12 @@ class Object { ...@@ -262,13 +269,12 @@ class Object {
* \return The allocated type index. * \return The allocated type index.
*/ */
TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
const char* key, const std::string& key,
uint32_t static_tindex, uint32_t static_tindex,
uint32_t parent_tindex, uint32_t parent_tindex,
uint32_t type_child_slots, uint32_t type_child_slots,
bool type_child_slots_can_overflow); bool type_child_slots_can_overflow);
private:
// reference counter related operations // reference counter related operations
/*! \brief developer function, increases reference counter. */ /*! \brief developer function, increases reference counter. */
inline void IncRef(); inline void IncRef();
...@@ -621,8 +627,8 @@ struct ObjectEqual { ...@@ -621,8 +627,8 @@ struct ObjectEqual {
*/ */
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static const uint32_t RuntimeTypeIndex() { \ static const uint32_t RuntimeTypeIndex() { \
if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return _type_index; \ return TypeName::_type_index; \
} \ } \
return _GetOrAllocRuntimeTypeIndex(); \ return _GetOrAllocRuntimeTypeIndex(); \
} \ } \
......
...@@ -51,8 +51,6 @@ namespace tvm { ...@@ -51,8 +51,6 @@ namespace tvm {
class Integer; class Integer;
class DataType; class DataType;
class Expr; class Expr;
class Node;
class NodeRef;
namespace runtime { namespace runtime {
...@@ -516,9 +514,9 @@ class TVMPODValue_ { ...@@ -516,9 +514,9 @@ class TVMPODValue_ {
CHECK_LT(type_code_, kExtEnd); CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0]; return static_cast<TExtension*>(value_.v_handle)[0];
} }
template<typename TNodeRef, template<typename TObjectRef,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type> std::is_class<TObjectRef>::value>::type>
inline bool IsObjectRef() const; inline bool IsObjectRef() const;
int type_code() const { int type_code() const {
return type_code_; return type_code_;
...@@ -620,8 +618,8 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -620,8 +618,8 @@ class TVMArgValue : public TVMPODValue_ {
return value_; return value_;
} }
// Deferred extension handler. // Deferred extension handler.
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef AsNodeRef() const; inline TObjectRef AsObjectRef() const;
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
...@@ -834,13 +832,13 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -834,13 +832,13 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_; return value_;
} }
// NodeRef related extenstions: in tvm/packed_func_ext.h // ObjectRef related extenstions: in tvm/packed_func_ext.h
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef AsNodeRef() const; inline TObjectRef AsObjectRef() const;
// type related // type related
inline operator tvm::DataType() const; inline operator tvm::DataType() const;
inline TVMRetValue& operator=(const tvm::DataType& other); inline TVMRetValue& operator=(const tvm::DataType& other);
...@@ -1306,7 +1304,7 @@ template<typename T, typename TSrc, bool is_ext, bool is_nd> ...@@ -1306,7 +1304,7 @@ template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast { struct TVMValueCast {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
return self->template AsNodeRef<T>(); return self->template AsObjectRef<T>();
} }
}; };
......
...@@ -91,7 +91,7 @@ class Registry { ...@@ -91,7 +91,7 @@ class Registry {
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
* *
* \code * \code
* *
* int multiply(int x, int y) { * int multiply(int x, int y) {
* return x * y; * return x * y;
* } * }
...@@ -115,7 +115,7 @@ class Registry { ...@@ -115,7 +115,7 @@ class Registry {
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
* *
* \code * \code
* *
* // node subclass: * // node subclass:
* struct Example { * struct Example {
* int doThing(int x); * int doThing(int x);
...@@ -143,7 +143,7 @@ class Registry { ...@@ -143,7 +143,7 @@ class Registry {
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
* *
* \code * \code
* *
* // node subclass: * // node subclass:
* struct Example { * struct Example {
* int doThing(int x); * int doThing(int x);
...@@ -168,22 +168,22 @@ class Registry { ...@@ -168,22 +168,22 @@ class Registry {
/*! /*!
* \brief set the body of the function to be the passed method pointer. * \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass. * Used when calling a method on a Node subclass through a ObjectRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
* *
* \code * \code
* *
* // node subclass: * // node subclass:
* struct ExampleNode: BaseNode { * struct ExampleNode: BaseNode {
* int doThing(int x); * int doThing(int x);
* } * }
* *
* // noderef subclass * // noderef subclass
* struct Example; * struct Example;
* *
* TVM_REGISTER_API("Example_doThing") * TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int) * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
* *
* // note that just doing: * // note that just doing:
* // .set_body_method(&ExampleNode::doThing); * // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
...@@ -191,15 +191,15 @@ class Registry { ...@@ -191,15 +191,15 @@ class Registry {
* \endcode * \endcode
* *
* \param f the method pointer to forward to. * \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on * \tparam TObjectRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred). * \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred). * \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred). * \tparam Args the argument types of the function (inferred).
*/ */
template<typename TNodeRef, typename TNode, typename R, typename ...Args, template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type> typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) { Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) { return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->(); TNode* target = ref.operator->();
// call method pointer // call method pointer
return (target->*f)(params...); return (target->*f)(params...);
...@@ -208,22 +208,22 @@ class Registry { ...@@ -208,22 +208,22 @@ class Registry {
/*! /*!
* \brief set the body of the function to be the passed method pointer. * \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass. * Used when calling a method on a Node subclass through a ObjectRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
* *
* \code * \code
* *
* // node subclass: * // node subclass:
* struct ExampleNode: BaseNode { * struct ExampleNode: BaseNode {
* int doThing(int x); * int doThing(int x);
* } * }
* *
* // noderef subclass * // noderef subclass
* struct Example; * struct Example;
* *
* TVM_REGISTER_API("Example_doThing") * TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int) * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
* *
* // note that just doing: * // note that just doing:
* // .set_body_method(&ExampleNode::doThing); * // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
...@@ -231,15 +231,15 @@ class Registry { ...@@ -231,15 +231,15 @@ class Registry {
* \endcode * \endcode
* *
* \param f the method pointer to forward to. * \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on * \tparam TObjectRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred). * \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred). * \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred). * \tparam Args the argument types of the function (inferred).
*/ */
template<typename TNodeRef, typename TNode, typename R, typename ...Args, template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type> typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) { Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) { return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->(); const TNode* target = ref.operator->();
// call method pointer // call method pointer
return (target->*f)(params...); return (target->*f)(params...);
......
...@@ -495,7 +495,7 @@ class StageNode : public Node { ...@@ -495,7 +495,7 @@ class StageNode : public Node {
/*! \brief Number of direct child stages, only used for group stage.*/ /*! \brief Number of direct child stages, only used for group stage.*/
int num_child_stages{0}; int num_child_stages{0};
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("origin_op", &origin_op); v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars); v->Visit("all_iter_vars", &all_iter_vars);
...@@ -540,7 +540,7 @@ class ScheduleNode : public Node { ...@@ -540,7 +540,7 @@ class ScheduleNode : public Node {
*/ */
std::unordered_map<const Node*, Stage> op2stage_cache_; std::unordered_map<const Node*, Stage> op2stage_cache_;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("outputs", &outputs); v->Visit("outputs", &outputs);
v->Visit("stages", &stages); v->Visit("stages", &stages);
v->Visit("groups", &groups); v->Visit("groups", &groups);
...@@ -617,7 +617,7 @@ class IterVarAttrNode : public Node { ...@@ -617,7 +617,7 @@ class IterVarAttrNode : public Node {
*/ */
Array<Expr> pragma_values; Array<Expr> pragma_values;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_type", &iter_type); v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread); v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data); v->Visit("prefetch_data", &prefetch_data);
...@@ -657,7 +657,7 @@ class SplitNode : public IterVarRelationNode { ...@@ -657,7 +657,7 @@ class SplitNode : public IterVarRelationNode {
/*! \brief Number of parts, only factor or nparts can be given */ /*! \brief Number of parts, only factor or nparts can be given */
Expr nparts; Expr nparts;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent); v->Visit("parent", &parent);
v->Visit("outer", &outer); v->Visit("outer", &outer);
v->Visit("inner", &inner); v->Visit("inner", &inner);
...@@ -687,7 +687,7 @@ class FuseNode : public IterVarRelationNode { ...@@ -687,7 +687,7 @@ class FuseNode : public IterVarRelationNode {
/*! \brief The target domain */ /*! \brief The target domain */
IterVar fused; IterVar fused;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("outer", &outer); v->Visit("outer", &outer);
v->Visit("inner", &inner); v->Visit("inner", &inner);
v->Visit("fused", &fused); v->Visit("fused", &fused);
...@@ -712,7 +712,7 @@ class RebaseNode : public IterVarRelationNode { ...@@ -712,7 +712,7 @@ class RebaseNode : public IterVarRelationNode {
/*! \brief The inner domain */ /*! \brief The inner domain */
IterVar rebased; IterVar rebased;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent); v->Visit("parent", &parent);
v->Visit("rebased", &rebased); v->Visit("rebased", &rebased);
} }
...@@ -732,7 +732,7 @@ class SingletonNode : public IterVarRelationNode { ...@@ -732,7 +732,7 @@ class SingletonNode : public IterVarRelationNode {
/*! \brief The singleton iterator */ /*! \brief The singleton iterator */
IterVar iter; IterVar iter;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("iter", &iter); v->Visit("iter", &iter);
} }
......
...@@ -47,7 +47,7 @@ struct MemoryInfoNode : public Node { ...@@ -47,7 +47,7 @@ struct MemoryInfoNode : public Node {
*/ */
Expr head_address; Expr head_address;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("unit_bits", &unit_bits); v->Visit("unit_bits", &unit_bits);
v->Visit("max_num_bits", &max_num_bits); v->Visit("max_num_bits", &max_num_bits);
v->Visit("max_simd_bits", &max_simd_bits); v->Visit("max_simd_bits", &max_simd_bits);
......
...@@ -171,7 +171,7 @@ class TensorNode : public Node { ...@@ -171,7 +171,7 @@ class TensorNode : public Node {
/*! \brief constructor */ /*! \brief constructor */
TensorNode() {} TensorNode() {}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("op", &op); v->Visit("op", &op);
......
...@@ -87,7 +87,7 @@ class TensorIntrinNode : public Node { ...@@ -87,7 +87,7 @@ class TensorIntrinNode : public Node {
/*! \brief constructor */ /*! \brief constructor */
TensorIntrinNode() {} TensorIntrinNode() {}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
...@@ -152,7 +152,7 @@ class TensorIntrinCallNode : public Node { ...@@ -152,7 +152,7 @@ class TensorIntrinCallNode : public Node {
/*! \brief scalar expression inputs */ /*! \brief scalar expression inputs */
Array<Expr> scalar_inputs; Array<Expr> scalar_inputs;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("intrin", &intrin); v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors); v->Visit("tensors", &tensors);
v->Visit("regions", &regions); v->Visit("regions", &regions);
......
...@@ -55,7 +55,7 @@ struct GraphFuncNode : public tvm::Node { ...@@ -55,7 +55,7 @@ struct GraphFuncNode : public tvm::Node {
/*! \brief The lowered functions */ /*! \brief The lowered functions */
tvm::Array<tvm::LoweredFunc> funcs; tvm::Array<tvm::LoweredFunc> funcs;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("target", &target); v->Visit("target", &target);
v->Visit("func_name", &func_name); v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
...@@ -78,7 +78,7 @@ struct GraphCacheEntryNode : public tvm::Node { ...@@ -78,7 +78,7 @@ struct GraphCacheEntryNode : public tvm::Node {
/*! \brief Index of the master node for calling schedule*/ /*! \brief Index of the master node for calling schedule*/
int master_idx; int master_idx;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("graph_func", &graph_func); v->Visit("graph_func", &graph_func);
v->Visit("use_count", &use_count); v->Visit("use_count", &use_count);
v->Visit("master_idx", &master_idx); v->Visit("master_idx", &master_idx);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -48,7 +48,7 @@ struct GraphKeyNode : public tvm::Node { ...@@ -48,7 +48,7 @@ struct GraphKeyNode : public tvm::Node {
// The graph hash key is ensured always not to be 0 // The graph hash key is ensured always not to be 0
mutable size_t cache_hash_key_{0}; mutable size_t cache_hash_key_{0};
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
v->Visit("target", &target); v->Visit("target", &target);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,11 +18,12 @@ ...@@ -18,11 +18,12 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc * \file graph_runtime.cc
* \brief Interface code with TVM graph runtime. * \brief Interface code with TVM graph runtime.
*/ */
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <utility> #include <utility>
#include "graph_runtime.h" #include "graph_runtime.h"
......
...@@ -61,13 +61,13 @@ struct NDArrayWrapperNode : public ::tvm::Node { ...@@ -61,13 +61,13 @@ struct NDArrayWrapperNode : public ::tvm::Node {
std::string name; std::string name;
tvm::runtime::NDArray array; tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("array", &array); v->Visit("array", &array);
} }
static constexpr const char* _type_key = "NDArrayWrapper"; static constexpr const char* _type_key = "NDArrayWrapper";
TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node); TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, tvm::Node);
}; };
TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode); TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode);
......
...@@ -22,6 +22,8 @@ There can be internal header files within each module that sit in src. ...@@ -22,6 +22,8 @@ There can be internal header files within each module that sit in src.
## Modules ## Modules
- common: Internal common utilities. - common: Internal common utilities.
- runtime: Minimum runtime related codes.
- node: base infra for IR/AST nodes that is dialect independent.
- api: API function registration. - api: API function registration.
- lang: The definition of DSL related data structure. - lang: The definition of DSL related data structure.
- arithmetic: Arithmetic expression and set simplification. - arithmetic: Arithmetic expression and set simplification.
...@@ -29,7 +31,6 @@ There can be internal header files within each module that sit in src. ...@@ -29,7 +31,6 @@ There can be internal header files within each module that sit in src.
- schedule: The operations on the schedule graph before converting to IR. - schedule: The operations on the schedule graph before converting to IR.
- pass: The optimization pass on the IR structure. - pass: The optimization pass on the IR structure.
- codegen: The code generator. - codegen: The code generator.
- runtime: Minimum runtime related codes.
- autotvm: The auto-tuning module. - autotvm: The auto-tuning module.
- relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks. - relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks.
- contrib: Contrib extension libraries. - contrib: Contrib extension libraries.
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/node/serialization.h>
namespace tvm { namespace tvm {
TVM_REGISTER_API("_format_str") TVM_REGISTER_API("_format_str")
...@@ -43,10 +44,10 @@ TVM_REGISTER_API("_raw_ptr") ...@@ -43,10 +44,10 @@ TVM_REGISTER_API("_raw_ptr")
}); });
TVM_REGISTER_API("_save_json") TVM_REGISTER_API("_save_json")
.set_body_typed<std::string(NodeRef)>(SaveJSON); .set_body_typed<std::string(ObjectRef)>(SaveJSON);
TVM_REGISTER_API("_load_json") TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>); .set_body_typed<ObjectRef(std::string)>(LoadJSON);
TVM_REGISTER_API("_TVMSetStream") TVM_REGISTER_API("_TVMSetStream")
.set_body_typed(TVMSetStream); .set_body_typed(TVMSetStream);
......
...@@ -53,17 +53,17 @@ class VariablePathFinder: public IRVisitor { ...@@ -53,17 +53,17 @@ class VariablePathFinder: public IRVisitor {
if (!found_) path_.pop_back(); if (!found_) path_.pop_back();
} }
std::vector<const Node*> path_; std::vector<const Object*> path_;
private: private:
bool found_{false}; bool found_{false};
Expr target_; Expr target_;
std::unordered_set<const Node*> visited_; std::unordered_set<const Object*> visited_;
}; };
// get the path to the variable, // get the path to the variable,
// return empty vector to represent failure // return empty vector to represent failure
std::vector<const Node*> GetPath(Expr target, Expr expr) { std::vector<const Object*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target); VariablePathFinder v(target);
v.Visit(expr); v.Visit(expr);
return v.path_; return v.path_;
...@@ -189,7 +189,7 @@ class BoundDeducer: public IRVisitor { ...@@ -189,7 +189,7 @@ class BoundDeducer: public IRVisitor {
const std::unordered_map<const Variable*, IntSet>& hint_map_; const std::unordered_map<const Variable*, IntSet>& hint_map_;
const std::unordered_map<const Variable*, IntSet>& relax_map_; const std::unordered_map<const Variable*, IntSet>& relax_map_;
ExprIntSetMap expr_map_; ExprIntSetMap expr_map_;
std::vector<const Node*> path_; std::vector<const Object*> path_;
size_t iter_{0}; size_t iter_{0};
// internal analzyer // internal analzyer
Analyzer analyzer_; Analyzer analyzer_;
......
...@@ -43,6 +43,7 @@ class SplitExpr; ...@@ -43,6 +43,7 @@ class SplitExpr;
*/ */
class CanonicalExprNode : public BaseExprNode { class CanonicalExprNode : public BaseExprNode {
public: public:
virtual ~CanonicalExprNode() {}
/*! /*!
* \brief Return the normal Expr that is equivalent to self. * \brief Return the normal Expr that is equivalent to self.
* \note Can mutate the internal data structure. * \note Can mutate the internal data structure.
...@@ -51,7 +52,7 @@ class CanonicalExprNode : public BaseExprNode { ...@@ -51,7 +52,7 @@ class CanonicalExprNode : public BaseExprNode {
virtual Expr Normalize() const = 0; virtual Expr Normalize() const = 0;
// overrides // overrides
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
} }
static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const char* _type_key = "arith.CanonicalExpr";
...@@ -485,7 +486,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -485,7 +486,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
* \return Normalized expr. * \return Normalized expr.
*/ */
Expr Normalize(Expr expr) { Expr Normalize(Expr expr) {
if (const auto* op = expr.as_derived<CanonicalExprNode>()) { if (const auto* op = expr.as<CanonicalExprNode>()) {
return op->Normalize(); return op->Normalize();
} else { } else {
return expr; return expr;
...@@ -503,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -503,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SumExprNode>()) { if (const auto* op = expr.as<SumExprNode>()) {
if (op->base == 0 && op->args.size() == 1) return op->args[0]; if (op->base == 0 && op->args.size() == 1) return op->args[0];
} }
if (const auto* op = expr.as_derived<CanonicalExprNode>()) { if (const auto* op = expr.as<CanonicalExprNode>()) {
expr = op->Normalize(); expr = op->Normalize();
} }
NodePtr<SplitExprNode> n = make_node<SplitExprNode>(); NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
......
...@@ -807,6 +807,8 @@ IntSet EvalSet(Range r, ...@@ -807,6 +807,8 @@ IntSet EvalSet(Range r,
return EvalSet(r, ConvertDomMap(dom_map)); return EvalSet(r, ConvertDomMap(dom_map));
} }
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) { .set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
p->stream << "IntervalSet" p->stream << "IntervalSet"
......
...@@ -47,7 +47,7 @@ class IntervalSetNode : public IntSetNode { ...@@ -47,7 +47,7 @@ class IntervalSetNode : public IntSetNode {
Expr max_value; Expr max_value;
// visitor overload. // visitor overload.
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value); v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value); v->Visit("max_value", &max_value);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc * \file intrin_rule_spirv.cc
*/ */
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <GLSL.std.450.h> #include <GLSL.std.450.h>
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -62,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc") ...@@ -62,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc")
TVM_REGISTER_NODE_TYPE(EnvFuncNode) TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode) .set_creator(CreateEnvNode)
.set_global_key([](const Node* n) { .set_global_key([](const Object* n) {
return static_cast<const EnvFuncNode*>(n)->name; return static_cast<const EnvFuncNode*>(n)->name;
}); });
......
...@@ -1150,6 +1150,8 @@ TVM_REGISTER_NODE_TYPE(Select); ...@@ -1150,6 +1150,8 @@ TVM_REGISTER_NODE_TYPE(Select);
TVM_REGISTER_NODE_TYPE(Load); TVM_REGISTER_NODE_TYPE(Load);
TVM_REGISTER_NODE_TYPE(Ramp); TVM_REGISTER_NODE_TYPE(Ramp);
TVM_REGISTER_NODE_TYPE(Broadcast); TVM_REGISTER_NODE_TYPE(Broadcast);
TVM_REGISTER_NODE_TYPE(Shuffle);
TVM_REGISTER_NODE_TYPE(Prefetch);
TVM_REGISTER_NODE_TYPE(Call); TVM_REGISTER_NODE_TYPE(Call);
TVM_REGISTER_NODE_TYPE(Let); TVM_REGISTER_NODE_TYPE(Let);
TVM_REGISTER_NODE_TYPE(LetStmt); TVM_REGISTER_NODE_TYPE(LetStmt);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file target_info.cc * \file target_info.cc
*/ */
#include <tvm/runtime/registry.h>
#include <tvm/target_info.h> #include <tvm/target_info.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
......
...@@ -18,22 +18,27 @@ ...@@ -18,22 +18,27 @@
*/ */
/*! /*!
* Implementation of DSL API * Reflection utilities.
* \file dsl_api.cc * \file node/reflection.cc
*/ */
#include <dmlc/logging.h> #include <tvm/runtime/registry.h>
#include <tvm/api_registry.h> #include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/expr.h>
#include <vector>
#include <string>
namespace tvm { namespace tvm {
namespace runtime {
struct APIAttrGetter : public AttrVisitor { // Attr getter.
std::string skey; class AttrGetter : public AttrVisitor {
public:
const std::string& skey;
TVMRetValue* ret; TVMRetValue* ret;
AttrGetter(const std::string &skey,
TVMRetValue* ret)
: skey(skey), ret(ret) {}
bool found_ref_object{false}; bool found_ref_object{false};
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
...@@ -62,12 +67,7 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -62,12 +67,7 @@ struct APIAttrGetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final { void Visit(const char* key, std::string* value) final {
if (skey == key) *ret = value[0]; if (skey == key) *ret = value[0];
} }
void Visit(const char* key, NodeRef* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
if (skey == key) { if (skey == key) {
*ret = value[0]; *ret = value[0];
...@@ -82,7 +82,39 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -82,7 +82,39 @@ struct APIAttrGetter : public AttrVisitor {
} }
}; };
struct APIAttrDir : public AttrVisitor { runtime::TVMRetValue ReflectionVTable::GetAttr(
Object* self, const std::string& field_name) const {
runtime::TVMRetValue ret;
AttrGetter getter(field_name, &ret);
bool success;
if (getter.skey == "type_key") {
ret = self->GetTypeKey();
success = true;
} else if (!self->IsInstance<DictAttrsNode>()) {
VisitAttrs(self, &getter);
success = getter.found_ref_object || ret.type_code() != kNull;
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
auto it = dnode->dict.find(getter.skey);
if (it != dnode->dict.end()) {
success = true;
ret = (*it).second;
} else {
success = false;
}
}
if (!success) {
LOG(FATAL) << "AttributeError: " << self->GetTypeKey()
<< " object has no attributed " << getter.skey;
}
return ret;
}
// List names;
class AttrDir : public AttrVisitor {
public:
std::vector<std::string>* names; std::vector<std::string>* names;
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
...@@ -109,9 +141,6 @@ struct APIAttrDir : public AttrVisitor { ...@@ -109,9 +141,6 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, std::string* value) final { void Visit(const char* key, std::string* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, NodeRef* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key); names->push_back(key);
} }
...@@ -120,71 +149,158 @@ struct APIAttrDir : public AttrVisitor { ...@@ -120,71 +149,158 @@ struct APIAttrDir : public AttrVisitor {
} }
}; };
struct NodeAPI { std::vector<std::string>
static void GetAttr(TVMArgs args, TVMRetValue* ret) { ReflectionVTable::ListAttrNames(Object* self) const {
NodeRef ref = args[0]; std::vector<std::string> names;
Node* tnode = const_cast<Node*>(ref.get()); AttrDir dir;
APIAttrGetter getter; dir.names = &names;
getter.skey = args[1].operator std::string();
getter.ret = ret; if (!self->IsInstance<DictAttrsNode>()) {
VisitAttrs(self, &dir);
bool success; } else {
if (getter.skey == "type_key") { // specially handle dict attr
*ret = tnode->GetTypeKey(); DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
success = true; for (const auto& kv : dnode->dict) {
} else if (!tnode->IsInstance<DictAttrsNode>()) { names.push_back(kv.first);
tnode->VisitAttrs(&getter);
success = getter.found_ref_object || ret->type_code() != kNull;
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
auto it = dnode->dict.find(getter.skey);
if (it != dnode->dict.end()) {
success = true;
*ret = (*it).second;
} else {
success = false;
}
} }
if (!success) { }
LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey() return names;
<< " object has no attributed " << getter.skey; }
ReflectionVTable* ReflectionVTable::Global() {
static ReflectionVTable inst;
return &inst;
}
ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fcreate_[tindex](global_key);
}
class NodeAttrSetter : public AttrVisitor {
public:
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
void Visit(const char* key, double* value) final {
*value = GetAttr(key).operator double();
}
void Visit(const char* key, int64_t* value) final {
*value = GetAttr(key).operator int64_t();
}
void Visit(const char* key, uint64_t* value) final {
*value = GetAttr(key).operator uint64_t();
}
void Visit(const char* key, int* value) final {
*value = GetAttr(key).operator int();
}
void Visit(const char* key, bool* value) final {
*value = GetAttr(key).operator bool();
}
void Visit(const char* key, std::string* value) final {
*value = GetAttr(key).operator std::string();
}
void Visit(const char* key, void** value) final {
*value = GetAttr(key).operator void*();
}
void Visit(const char* key, DataType* value) final {
*value = GetAttr(key).operator DataType();
}
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator ObjectRef();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
} }
runtime::TVMArgValue v = it->second;
attrs.erase(it);
return v;
} }
};
static void ListAttrNames(TVMArgs args, TVMRetValue* ret) { void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
NodeRef ref = args[0]; NodeAttrSetter setter;
Node* tnode = const_cast<Node*>(ref.get()); setter.type_key = n->GetTypeKey();
auto names = std::make_shared<std::vector<std::string> >(); CHECK_EQ(args.size() % 2, 0);
APIAttrDir dir; for (int i = 0; i < args.size(); i += 2) {
dir.names = names.get(); setter.attrs.emplace(args[i].operator std::string(),
args[i + 1]);
}
auto* reflection = ReflectionVTable::Global();
reflection->VisitAttrs(n, &setter);
if (!tnode->IsInstance<DictAttrsNode>()) { if (setter.attrs.size() != 0) {
tnode->VisitAttrs(&dir); std::ostringstream os;
} else { os << setter.type_key << " does not contain field ";
// specially handle dict attr for (const auto &kv : setter.attrs) {
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode); os << " " << kv.first;
for (const auto& kv : dnode->dict) {
names->push_back(kv.first);
}
} }
LOG(FATAL) << os.str();
}
}
// Expose to FFI APIs.
void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle);
*ret = ReflectionVTable::Global()->GetAttr(self, args[1]);
}
void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle);
*ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { auto names = std::make_shared<std::vector<std::string> >(
int64_t i = args[0]; ReflectionVTable::Global()->ListAttrNames(self));
if (i == -1) {
*rv = static_cast<int64_t>(names->size()); *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
} else { int64_t i = args[0];
*rv = (*names)[i]; if (i == -1) {
} *rv = static_cast<int64_t>(names->size());
}); } else {
*rv = (*names)[i];
}
});
}
// API function to make node.
// args format:
// key1, value1, ..., key_n, value_n
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0];
std::string empty_str;
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
auto* reflection = ReflectionVTable::Global();
ObjectPtr<Object> n = reflection->CreateInitObject(type_key);
if (n->IsInstance<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(n.get(), kwargs);
} }
}; *rv = ObjectRef(n);
}
TVM_REGISTER_GLOBAL("_NodeGetAttr") TVM_REGISTER_GLOBAL("_NodeGetAttr")
.set_body(NodeAPI::GetAttr); .set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("_NodeListAttrNames") TVM_REGISTER_GLOBAL("_NodeListAttrNames")
.set_body(NodeAPI::ListAttrNames); .set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -59,7 +59,7 @@ struct CachedFuncNode : public Node { ...@@ -59,7 +59,7 @@ struct CachedFuncNode : public Node {
/*! \brief Parameter usage states in the shape function. */ /*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states; tvm::Array<Integer> shape_func_param_states;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("target", &target); v->Visit("target", &target);
v->Visit("func_name", &func_name); v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
...@@ -84,7 +84,7 @@ class CCacheKeyNode : public Node { ...@@ -84,7 +84,7 @@ class CCacheKeyNode : public Node {
/*! \brief The hardware target.*/ /*! \brief The hardware target.*/
Target target; Target target;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source_func", &source_func); v->Visit("source_func", &source_func);
v->Visit("target", &target); v->Visit("target", &target);
} }
...@@ -141,7 +141,7 @@ class CCacheValueNode : public Node { ...@@ -141,7 +141,7 @@ class CCacheValueNode : public Node {
/*! \brief usage statistics */ /*! \brief usage statistics */
int use_count{0}; int use_count{0};
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cached_func", &cached_func); v->Visit("cached_func", &cached_func);
v->Visit("use_count", &use_count); v->Visit("use_count", &use_count);
} }
...@@ -191,7 +191,7 @@ class CompileEngineNode : public Node { ...@@ -191,7 +191,7 @@ class CompileEngineNode : public Node {
virtual void Clear() = 0; virtual void Clear() = 0;
// VisitAttrs // VisitAttrs
void VisitAttrs(AttrVisitor*) final {} void VisitAttrs(AttrVisitor*) {}
static constexpr const char* _type_key = "relay.CompileEngine"; static constexpr const char* _type_key = "relay.CompileEngine";
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node); TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/interpreter.cc * \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR. * \brief An interpreter for the Relay IR.
*/ */
...@@ -116,6 +115,8 @@ RefValue RefValueNode::make(Value value) { ...@@ -116,6 +115,8 @@ RefValue RefValueNode::make(Value value) {
TVM_REGISTER_API("relay._make.RefValue") TVM_REGISTER_API("relay._make.RefValue")
.set_body_typed(RefValueNode::make); .set_body_typed(RefValueNode::make);
TVM_REGISTER_NODE_TYPE(RefValueNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node, .set_dispatch<RefValueNode>([](const RefValueNode* node,
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
...@@ -135,6 +136,8 @@ ConstructorValue ConstructorValueNode::make(int32_t tag, ...@@ -135,6 +136,8 @@ ConstructorValue ConstructorValueNode::make(int32_t tag,
TVM_REGISTER_API("relay._make.ConstructorValue") TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body_typed(ConstructorValueNode::make); .set_body_typed(ConstructorValueNode::make);
TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node, .set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
...@@ -207,7 +210,7 @@ class InterpreterStateNode : public Node { ...@@ -207,7 +210,7 @@ class InterpreterStateNode : public Node {
/*! \brief The call stack of the interpreter. */ /*! \brief The call stack of the interpreter. */
Stack stack; Stack stack;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("current_expr", &current_expr); v->Visit("current_expr", &current_expr);
v->Visit("stack", &stack); v->Visit("stack", &stack);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,19 +18,21 @@ ...@@ -18,19 +18,21 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file param_dict.cc * \file param_dict.cc
* \brief Implementation and registration of parameter dictionary * \brief Implementation and registration of parameter dictionary
* serializing/deserializing functions. * serializing/deserializing functions.
*/ */
#include "param_dict.h" #include <tvm/runtime/registry.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include "param_dict.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -45,7 +45,7 @@ struct NamedNDArrayNode : public ::tvm::Node { ...@@ -45,7 +45,7 @@ struct NamedNDArrayNode : public ::tvm::Node {
std::string name; std::string name;
tvm::runtime::NDArray array; tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("array", &array); v->Visit("array", &array);
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/tvm/ir/adt.cc * \file src/tvm/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs). * \brief AST nodes for Relay algebraic data types (ADTs).
*/ */
......
...@@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(SourceNameNode) TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode) .set_creator(GetSourceNameNode)
.set_global_key([](const Node* n) { .set_global_key([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name; return static_cast<const SourceNameNode*>(n)->name;
}); });
...@@ -88,7 +88,7 @@ TVM_REGISTER_NODE_TYPE(IdNode); ...@@ -88,7 +88,7 @@ TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span") TVM_REGISTER_API("relay._base.set_span")
.set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) { .set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
auto rn = node_ref.as_derived<RelayNode>(); auto rn = node_ref.as<RelayNode>();
CHECK(rn); CHECK(rn);
rn->span = sp; rn->span = sp;
}); });
......
...@@ -195,7 +195,7 @@ NodePtr<Node> CreateOp(const std::string& name) { ...@@ -195,7 +195,7 @@ NodePtr<Node> CreateOp(const std::string& name) {
TVM_REGISTER_NODE_TYPE(OpNode) TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp) .set_creator(CreateOp)
.set_global_key([](const Node* n) { .set_global_key([](const Object* n) {
return static_cast<const OpNode*>(n)->name; return static_cast<const OpNode*>(n)->name;
}); });
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once. * - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/ */
#include <dmlc/json.h> #include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
...@@ -214,7 +214,7 @@ class PrettyPrinter : ...@@ -214,7 +214,7 @@ class PrettyPrinter :
} }
Doc PrintFinal(const NodeRef& node) { Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) { if (node.as<ExprNode>()) {
Expr expr = Downcast<Expr>(node); Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr); dg_ = DependencyGraph::Create(&arena_, expr);
} }
...@@ -237,13 +237,13 @@ class PrettyPrinter : ...@@ -237,13 +237,13 @@ class PrettyPrinter :
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs); std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) { Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) { if (node.as<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta, try_inline); return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) { } else if (node.as<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta); return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<PatternNode>()) { } else if (node.as<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta); return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as_derived<ModuleNode>()) { } else if (node.as<ModuleNode>()) {
return PrintMod(Downcast<Module>(node)); return PrintMod(Downcast<Module>(node));
} else { } else {
Doc doc; Doc doc;
...@@ -924,14 +924,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { ...@@ -924,14 +924,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, DataType* value) final { void Visit(const char* key, DataType* value) final {
PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value)))); PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
} }
void Visit(const char* key, NodeRef* value) final {
PrintKV(key, parent_->PrintAttr(*value));
}
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument"; LOG(FATAL) << "do not allow NDarray as argument";
} }
void Visit(const char* key, runtime::ObjectRef* obj) final { void Visit(const char* key, runtime::ObjectRef* obj) final {
LOG(FATAL) << "do not allow Object as argument"; PrintKV(key, parent_->PrintAttr(*obj));
} }
private: private:
......
...@@ -132,7 +132,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { ...@@ -132,7 +132,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) { if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin)); type_params.push_back(GetRef<TypeVar>(tin));
} else { } else {
LOG(FATAL) << new_type_param << std::endl; LOG(FATAL) << new_type_param;
} }
} }
...@@ -141,10 +141,10 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { ...@@ -141,10 +141,10 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
auto new_type_cs = VisitType(type_cs); auto new_type_cs = VisitType(type_cs);
changed = changed || !new_type_cs.same_as(type_cs); changed = changed || !new_type_cs.same_as(type_cs);
if (const TypeConstraintNode* tin = if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) { new_type_cs.as<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin)); type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else { } else {
LOG(FATAL) << new_type_cs << std::endl; LOG(FATAL) << new_type_cs;
} }
} }
......
...@@ -140,7 +140,7 @@ class LayoutAlternatedExprNode : public TempExprNode { ...@@ -140,7 +140,7 @@ class LayoutAlternatedExprNode : public TempExprNode {
return tmp_memorizer.Transform(value, new_layout, old_layout); return tmp_memorizer.Transform(value, new_layout, old_layout);
} }
void VisitAttrs(AttrVisitor *v) final { void VisitAttrs(AttrVisitor *v) {
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("old_layout", &old_layout); v->Visit("old_layout", &old_layout);
v->Visit("new_layout", &new_layout); v->Visit("new_layout", &new_layout);
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
*
* \file deivce_annotation.cc * \file deivce_annotation.cc
* \brief Passes to rewrite annotated program and retrieve the device allocation * \brief Passes to rewrite annotated program and retrieve the device allocation
* of expression. * of expression.
...@@ -46,13 +44,15 @@ namespace relay { ...@@ -46,13 +44,15 @@ namespace relay {
namespace { namespace {
bool IsOnDeviceNode(const ExprNode* node) { bool IsOnDeviceNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node); if (!node->IsInstance<CallNode>()) return false;
return call_node != nullptr && call_node->attrs.as<OnDeviceAttrs>(); const auto* call_node = static_cast<const CallNode*>(node);
return call_node->attrs.as<OnDeviceAttrs>();
} }
bool IsDeviceCopyNode(const ExprNode* node) { bool IsDeviceCopyNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node); if (!node->IsInstance<CallNode>()) return false;
return call_node != nullptr && call_node->attrs.as<DeviceCopyAttrs>(); const auto* call_node = static_cast<const CallNode*>(node);
return call_node->attrs.as<DeviceCopyAttrs>();
} }
} // namespace } // namespace
...@@ -447,7 +447,8 @@ class DeviceInfo { ...@@ -447,7 +447,8 @@ class DeviceInfo {
static const ExprNode* GetDeviceCopyNode(const ExprNode* node) { static const ExprNode* GetDeviceCopyNode(const ExprNode* node) {
if (IsDeviceCopyNode(node)) { if (IsDeviceCopyNode(node)) {
return node; return node;
} else if (const auto* call_node = dynamic_cast<const CallNode*>(node)) { } else if (node->IsInstance<CallNode>()) {
const auto* call_node = static_cast<const CallNode*>(node);
if (const auto* fn = call_node->op.as<FunctionNode>()) { if (const auto* fn = call_node->op.as<FunctionNode>()) {
const ExprNode* body = fn->body.operator->(); const ExprNode* body = fn->body.operator->();
if (IsDeviceCopyNode(body)) { if (IsDeviceCopyNode(body)) {
...@@ -472,7 +473,8 @@ class DeviceInfo { ...@@ -472,7 +473,8 @@ class DeviceInfo {
for (auto it = post_visitor_.post_dfs_order_.crbegin(); for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) { it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(it->first)) { if (const auto* node = GetDeviceCopyNode(it->first)) {
last_copy_node = dynamic_cast<const CallNode*>(node); CHECK(node->IsInstance<CallNode>());
last_copy_node = static_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>(); const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type; cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
......
...@@ -37,14 +37,14 @@ Expr EtaExpand(const Expr& e, const Module& mod) { ...@@ -37,14 +37,14 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
Type ret_type; Type ret_type;
if (e->IsInstance<GlobalVarNode>()) { if (e->IsInstance<GlobalVarNode>()) {
auto gvar_node = e.as_derived<GlobalVarNode>(); auto gvar_node = e.as<GlobalVarNode>();
auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node)); auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
original_params = func->params; original_params = func->params;
original_type_params = func->type_params; original_type_params = func->type_params;
ret_type = func->ret_type; ret_type = func->ret_type;
} else { } else {
CHECK(e->IsInstance<FunctionNode>()); CHECK(e->IsInstance<FunctionNode>());
auto func = GetRef<Function>(e.as_derived<FunctionNode>()); auto func = GetRef<Function>(e.as<FunctionNode>());
original_params = func->params; original_params = func->params;
original_type_params = func->type_params; original_type_params = func->type_params;
ret_type = func->ret_type; ret_type = func->ret_type;
......
...@@ -176,7 +176,7 @@ class ScaledExprNode : public TempExprNode { ...@@ -176,7 +176,7 @@ class ScaledExprNode : public TempExprNode {
return value; return value;
} }
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value); v->Visit("value", &value);
v->Visit("axes", &axes); v->Visit("axes", &axes);
v->Visit("scale", &scale); v->Visit("scale", &scale);
...@@ -664,7 +664,7 @@ class BackwardTransformerNode : ...@@ -664,7 +664,7 @@ class BackwardTransformerNode :
} }
// solver is not serializable. // solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {} void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer"; static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node); TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -47,7 +47,7 @@ class TempRealizer : private ExprMutator { ...@@ -47,7 +47,7 @@ class TempRealizer : private ExprMutator {
return it->second; return it->second;
} else { } else {
Expr res; Expr res;
if (const auto* temp = expr.as_derived<TempExprNode>()) { if (const auto* temp = expr.as<TempExprNode>()) {
res = temp->Realize(); res = temp->Realize();
} else { } else {
......
...@@ -102,7 +102,7 @@ class ModulePassNode : public PassNode { ...@@ -102,7 +102,7 @@ class ModulePassNode : public PassNode {
ModulePassNode() = default; ModulePassNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info); v->Visit("pass_info", &pass_info);
} }
...@@ -156,7 +156,7 @@ class FunctionPassNode : public PassNode { ...@@ -156,7 +156,7 @@ class FunctionPassNode : public PassNode {
FunctionPassNode() = default; FunctionPassNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info); v->Visit("pass_info", &pass_info);
} }
...@@ -211,7 +211,7 @@ class SequentialNode : public PassNode { ...@@ -211,7 +211,7 @@ class SequentialNode : public PassNode {
/*! \brief A list of passes that used to compose a sequential pass. */ /*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes; tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info); v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes); v->Visit("passes", &passes);
} }
......
...@@ -41,7 +41,7 @@ class QAnnotateExprNode : public TempExprNode { ...@@ -41,7 +41,7 @@ class QAnnotateExprNode : public TempExprNode {
Expr expr; Expr expr;
QAnnotateKind kind; QAnnotateKind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("expr", &expr); v->Visit("expr", &expr);
v->Visit("kind", &kind); v->Visit("kind", &kind);
} }
......
...@@ -42,7 +42,7 @@ class QPartitionExprNode : public TempExprNode { ...@@ -42,7 +42,7 @@ class QPartitionExprNode : public TempExprNode {
/*! \brief The original expression */ /*! \brief The original expression */
Expr expr; Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("expr", &expr); v->Visit("expr", &expr);
} }
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
*
* \file quantize.cc * \file quantize.cc
* *
* \brief transform a graph to a low-bit graph * \brief transform a graph to a low-bit graph
......
...@@ -76,7 +76,7 @@ class QConfigNode : public Node { ...@@ -76,7 +76,7 @@ class QConfigNode : public Node {
bool round_for_shift = true; bool round_for_shift = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr)); Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) {
v->Visit("nbit_input", &nbit_input); v->Visit("nbit_input", &nbit_input);
v->Visit("nbit_weight", &nbit_weight); v->Visit("nbit_weight", &nbit_weight);
v->Visit("nbit_activation", &nbit_activation); v->Visit("nbit_activation", &nbit_activation);
......
...@@ -56,7 +56,7 @@ class QRealizeIntExprNode : public QRealizeExprNode { ...@@ -56,7 +56,7 @@ class QRealizeIntExprNode : public QRealizeExprNode {
Expr dom_scale; Expr dom_scale;
DataType dtype; DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale); v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
......
...@@ -153,7 +153,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -153,7 +153,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
// default: unify only if alpha-equal // default: unify only if alpha-equal
Type VisitTypeDefault_(const Node* op, const Type& tn) final { Type VisitTypeDefault_(const Node* op, const Type& tn) final {
NodeRef nr = GetRef<NodeRef>(op); NodeRef nr = GetRef<NodeRef>(op);
Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>()); Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) { if (!AlphaEqual(t1, tn)) {
return Type(nullptr); return Type(nullptr);
} }
...@@ -411,7 +411,7 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> { ...@@ -411,7 +411,7 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
void VisitTypeDefault_(const Node* op) override { void VisitTypeDefault_(const Node* op) override {
NodeRef nr = GetRef<NodeRef>(op); NodeRef nr = GetRef<NodeRef>(op);
Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>()); Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
UpdateRelSet(t); UpdateRelSet(t);
} }
...@@ -495,7 +495,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> { ...@@ -495,7 +495,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
void VisitTypeDefault_(const Node* op) override { void VisitTypeDefault_(const Node* op) override {
NodeRef nr = GetRef<NodeRef>(op); NodeRef nr = GetRef<NodeRef>(op);
Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>()); Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
TransferLinks(t); TransferLinks(t);
} }
......
...@@ -280,7 +280,7 @@ TVM_REGISTER_API("relay._analysis.free_vars") ...@@ -280,7 +280,7 @@ TVM_REGISTER_API("relay._analysis.free_vars")
TVM_REGISTER_API("relay._analysis.bound_vars") TVM_REGISTER_API("relay._analysis.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
if (x.as_derived<ExprNode>()) { if (x.as<ExprNode>()) {
*ret = BoundVars(Downcast<Expr>(x)); *ret = BoundVars(Downcast<Expr>(x));
} else { } else {
*ret = BoundVars(Downcast<Pattern>(x)); *ret = BoundVars(Downcast<Pattern>(x));
...@@ -294,7 +294,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars") ...@@ -294,7 +294,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
Module mod = args[1]; Module mod = args[1];
if (x.as_derived<TypeNode>()) { if (x.as<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x), mod); *ret = FreeTypeVars(Downcast<Type>(x), mod);
} else { } else {
*ret = FreeTypeVars(Downcast<Expr>(x), mod); *ret = FreeTypeVars(Downcast<Expr>(x), mod);
...@@ -305,7 +305,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars") ...@@ -305,7 +305,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
Module mod = args[1]; Module mod = args[1];
if (x.as_derived<TypeNode>()) { if (x.as<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x), mod); *ret = BoundTypeVars(Downcast<Type>(x), mod);
} else { } else {
*ret = BoundTypeVars(Downcast<Expr>(x), mod); *ret = BoundTypeVars(Downcast<Expr>(x), mod);
...@@ -316,7 +316,7 @@ TVM_REGISTER_API("relay._analysis.all_type_vars") ...@@ -316,7 +316,7 @@ TVM_REGISTER_API("relay._analysis.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
Module mod = args[1]; Module mod = args[1];
if (x.as_derived<TypeNode>()) { if (x.as<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x), mod); *ret = AllTypeVars(Downcast<Type>(x), mod);
} else { } else {
*ret = AllTypeVars(Downcast<Expr>(x), mod); *ret = AllTypeVars(Downcast<Expr>(x), mod);
......
...@@ -73,13 +73,12 @@ class TypeContext { ...@@ -73,13 +73,12 @@ class TypeContext {
return child_tindex == parent_tindex; return child_tindex == parent_tindex;
} }
uint32_t GetOrAllocRuntimeTypeIndex(const char* key, uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
uint32_t static_tindex, uint32_t static_tindex,
uint32_t parent_tindex, uint32_t parent_tindex,
uint32_t num_child_slots, uint32_t num_child_slots,
bool child_slots_can_overflow) { bool child_slots_can_overflow) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
std::string skey = key;
auto it = type_key2index_.find(skey); auto it = type_key2index_.find(skey);
if (it != type_key2index_.end()) { if (it != type_key2index_.end()) {
return it->second; return it->second;
...@@ -106,7 +105,7 @@ class TypeContext { ...@@ -106,7 +105,7 @@ class TypeContext {
<< "Conflicting static index " << static_tindex << "Conflicting static index " << static_tindex
<< " between " << type_table_[allocated_tindex].name << " between " << type_table_[allocated_tindex].name
<< " and " << " and "
<< key; << skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) { } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
// allocate the slot from parent's reserved pool // allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots; allocated_tindex = parent_tindex + pinfo.allocated_slots;
...@@ -152,11 +151,10 @@ class TypeContext { ...@@ -152,11 +151,10 @@ class TypeContext {
return type_table_[tindex].name_hash; return type_table_[tindex].name_hash;
} }
uint32_t TypeKey2Index(const char* key) { uint32_t TypeKey2Index(const std::string& skey) {
std::string skey = key;
auto it = type_key2index_.find(skey); auto it = type_key2index_.find(skey);
CHECK(it != type_key2index_.end()) CHECK(it != type_key2index_.end())
<< "Cannot find type " << key; << "Cannot find type " << skey;
return it->second; return it->second;
} }
...@@ -176,7 +174,7 @@ class TypeContext { ...@@ -176,7 +174,7 @@ class TypeContext {
std::unordered_map<std::string, uint32_t> type_key2index_; std::unordered_map<std::string, uint32_t> type_key2index_;
}; };
uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key, uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
uint32_t static_tindex, uint32_t static_tindex,
uint32_t parent_tindex, uint32_t parent_tindex,
uint32_t num_child_slots, uint32_t num_child_slots,
...@@ -198,7 +196,7 @@ size_t Object::TypeIndex2KeyHash(uint32_t tindex) { ...@@ -198,7 +196,7 @@ size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
return TypeContext::Global()->TypeIndex2KeyHash(tindex); return TypeContext::Global()->TypeIndex2KeyHash(tindex);
} }
uint32_t Object::TypeKey2Index(const char* key) { uint32_t Object::TypeKey2Index(const std::string& key) {
return TypeContext::Global()->TypeKey2Index(key); return TypeContext::Global()->TypeKey2Index(key);
} }
...@@ -210,7 +208,7 @@ class TVMObjectCAPI { ...@@ -210,7 +208,7 @@ class TVMObjectCAPI {
} }
} }
static uint32_t TypeKey2Index(const char* type_key) { static uint32_t TypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key); return Object::TypeKey2Index(type_key);
} }
}; };
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <topi/cuda/injective.h> #include <topi/cuda/injective.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment