Unverified Commit a8c36921 by Tianqi Chen Committed by GitHub

[REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal to Object (#4603)

* [REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal and macros to Object.

Historically, we have classes like NodePtr/Ref/HashEqual.
After unified object protocol, these names are just alias of the object counterpart.
Moreover, there are helper macros defined over the places for defining these object.

This PR consoldiate the terminologies into the corresponding ones
in the Object system so we have a clean and consistent API moving forward.

* Update include/tvm/attrs.h

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

* fix compilation

Co-authored-by: Wei Chen <ipondering.weic@gmail.com>
parent 475158f6
...@@ -49,7 +49,7 @@ namespace tvm { ...@@ -49,7 +49,7 @@ namespace tvm {
* \brief Node container of EnvFunc * \brief Node container of EnvFunc
* \sa EnvFunc * \sa EnvFunc
*/ */
class EnvFuncNode : public Node { class EnvFuncNode : public Object {
public: public:
/*! \brief Unique name of the global function */ /*! \brief Unique name of the global function */
std::string name; std::string name;
...@@ -63,7 +63,7 @@ class EnvFuncNode : public Node { ...@@ -63,7 +63,7 @@ class EnvFuncNode : public Node {
} }
static constexpr const char* _type_key = "EnvFunc"; static constexpr const char* _type_key = "EnvFunc";
TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
}; };
/*! /*!
...@@ -73,10 +73,10 @@ class EnvFuncNode : public Node { ...@@ -73,10 +73,10 @@ class EnvFuncNode : public Node {
* An EnvFunc is saved by its name in the global registry * An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load. * under the assumption that the same function is registered during load.
*/ */
class EnvFunc : public NodeRef { class EnvFunc : public ObjectRef {
public: public:
EnvFunc() {} EnvFunc() {}
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {} explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The internal global function pointer */ /*! \return The internal global function pointer */
const EnvFuncNode* operator->() const { const EnvFuncNode* operator->() const {
return static_cast<const EnvFuncNode*>(get()); return static_cast<const EnvFuncNode*>(get());
...@@ -119,12 +119,12 @@ class TypedEnvFunc; ...@@ -119,12 +119,12 @@ class TypedEnvFunc;
* \sa EnvFunc * \sa EnvFunc
*/ */
template<typename R, typename... Args> template<typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public NodeRef { class TypedEnvFunc<R(Args...)> : public ObjectRef {
public: public:
/*! \brief short hand for this function type */ /*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>; using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {} TypedEnvFunc() {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : NodeRef(n) {} explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief Assign global function to a TypedEnvFunc * \brief Assign global function to a TypedEnvFunc
* \param other Another global function. * \param other Another global function.
......
...@@ -55,7 +55,7 @@ class Analyzer; ...@@ -55,7 +55,7 @@ class Analyzer;
* *
* set = [min_value, max_value] * set = [min_value, max_value]
*/ */
class ConstIntBoundNode : public Node { class ConstIntBoundNode : public Object {
public: public:
int64_t min_value; int64_t min_value;
int64_t max_value; int64_t max_value;
...@@ -74,14 +74,14 @@ class ConstIntBoundNode : public Node { ...@@ -74,14 +74,14 @@ class ConstIntBoundNode : public Node {
static const constexpr int64_t kNegInf = -kPosInf; static const constexpr int64_t kNegInf = -kPosInf;
static constexpr const char* _type_key = "arith.ConstIntBound"; static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
}; };
/*! /*!
* \brief reference class to ConstIntBoundNode * \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode * \sa ConstIntBoundNode
*/ */
class ConstIntBound : public NodeRef { class ConstIntBound : public ObjectRef {
public: public:
/*! /*!
* \brief constructor by fields. * \brief constructor by fields.
...@@ -92,7 +92,7 @@ class ConstIntBound : public NodeRef { ...@@ -92,7 +92,7 @@ class ConstIntBound : public NodeRef {
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode); TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode);
}; };
/*! /*!
...@@ -155,7 +155,7 @@ class ConstIntBoundAnalyzer { ...@@ -155,7 +155,7 @@ class ConstIntBoundAnalyzer {
* This is useful to decide if the index is dividable by certain value. * This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4. * For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/ */
class ModularSetNode : public Node { class ModularSetNode : public Object {
public: public:
/*! \brief linear co-efficient */ /*! \brief linear co-efficient */
int64_t coeff; int64_t coeff;
...@@ -168,18 +168,18 @@ class ModularSetNode : public Node { ...@@ -168,18 +168,18 @@ class ModularSetNode : public Node {
} }
static constexpr const char* _type_key = "arith.ModularSet"; static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
}; };
/*! /*!
* \brief reference of ModularSetNode * \brief reference of ModularSetNode
* \sa ModularSetNode * \sa ModularSetNode
*/ */
class ModularSet : public NodeRef { class ModularSet : public ObjectRef {
public: public:
TVM_DLL ModularSet(int64_t coeff, int64_t base); TVM_DLL ModularSet(int64_t coeff, int64_t base);
TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode); TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode);
}; };
/*! /*!
...@@ -349,20 +349,20 @@ enum SignType { ...@@ -349,20 +349,20 @@ enum SignType {
/*! /*!
* \brief Base class of all IntSet containers. * \brief Base class of all IntSet containers.
*/ */
struct IntSetNode : public Node { struct IntSetNode : public Object {
static constexpr const char* _type_key = "IntSet"; static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
}; };
/*! /*!
* \brief Integer set class, represent a set of integers in one dimension. * \brief Integer set class, represent a set of integers in one dimension.
*/ */
class IntSet : public NodeRef { class IntSet : public ObjectRef {
public: public:
/*! \brief constructor */ /*! \brief constructor */
IntSet() {} IntSet() {}
// constructor from not container. // constructor from not container.
explicit IntSet(ObjectPtr<Object> n) : NodeRef(n) {} explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -598,7 +598,7 @@ IntSet EvalSet(Range r, ...@@ -598,7 +598,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */ /*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>; using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
/*! /*!
* \brief Find the integer set of every sub-expression, given the * \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables. * domain of each iteration variables.
......
...@@ -65,7 +65,7 @@ namespace tvm { ...@@ -65,7 +65,7 @@ namespace tvm {
*/ */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \ static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode) \ TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template<typename FVisit> \ template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
...@@ -83,9 +83,9 @@ namespace tvm { ...@@ -83,9 +83,9 @@ namespace tvm {
* \tparam TNodeRef the type to be created. * \tparam TNodeRef the type to be created.
* \return A instance that will represent None. * \return A instance that will represent None.
*/ */
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef NullValue() { inline TObjectRef NullValue() {
return TNodeRef(NodePtr<Node>(nullptr)); return TObjectRef(ObjectPtr<Object>(nullptr));
} }
template<> template<>
...@@ -106,7 +106,7 @@ struct AttrError : public dmlc::Error { ...@@ -106,7 +106,7 @@ struct AttrError : public dmlc::Error {
/*! /*!
* \brief Information about attribute fields in string representations. * \brief Information about attribute fields in string representations.
*/ */
class AttrFieldInfoNode : public Node { class AttrFieldInfoNode : public Object {
public: public:
/*! \brief name of the field */ /*! \brief name of the field */
std::string name; std::string name;
...@@ -121,11 +121,14 @@ class AttrFieldInfoNode : public Node { ...@@ -121,11 +121,14 @@ class AttrFieldInfoNode : public Node {
v->Visit("description", &description); v->Visit("description", &description);
} }
static constexpr const char* _type_key = "AttrFieldInfo"; static constexpr const char* _type_key = "AttrFieldInfo";
TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
}; };
/*! \brief AttrFieldInfo */ /*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); class AttrFieldInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
class AttrsHashHandler; class AttrsHashHandler;
class AttrsEqualHandler; class AttrsEqualHandler;
...@@ -217,7 +220,7 @@ class AttrsHash { ...@@ -217,7 +220,7 @@ class AttrsHash {
* subclass AttrsNode instead. * subclass AttrsNode instead.
* \sa AttrsNode * \sa AttrsNode
*/ */
class BaseAttrsNode : public Node { class BaseAttrsNode : public Object {
public: public:
using TVMArgs = runtime::TVMArgs; using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue; using TVMRetValue = runtime::TVMRetValue;
...@@ -271,16 +274,16 @@ class BaseAttrsNode : public Node { ...@@ -271,16 +274,16 @@ class BaseAttrsNode : public Node {
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const char* _type_key = "Attrs"; static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
}; };
/*! \brief Base attribute container for all attributes */ /*! \brief Base attribute container for all attributes */
class Attrs : public NodeRef { class Attrs : public ObjectRef {
public: public:
// normal constructor // normal constructor
Attrs() {} Attrs() {}
// construct from shared ptr. // construct from shared ptr.
explicit Attrs(NodePtr<Node> n) : NodeRef(n) {} explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The attribute node */ /*! \return The attribute node */
const BaseAttrsNode* operator->() const { const BaseAttrsNode* operator->() const {
...@@ -305,13 +308,13 @@ class Attrs : public NodeRef { ...@@ -305,13 +308,13 @@ class Attrs : public NodeRef {
class DictAttrsNode : public BaseAttrsNode { class DictAttrsNode : public BaseAttrsNode {
public: public:
/*! \brief internal attrs map */ /*! \brief internal attrs map */
Map<std::string, NodeRef> dict; Map<std::string, ObjectRef> dict;
/*! /*!
* \brief Consruct a Attrs backed by DictAttrsNode. * \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes. * \param dict The attributes.
* \return The dict attributes. * \return The dict attributes.
*/ */
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict); TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);
// implementations // implementations
void VisitAttrs(AttrVisitor* v) final; void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final;
...@@ -321,7 +324,7 @@ class DictAttrsNode : public BaseAttrsNode { ...@@ -321,7 +324,7 @@ class DictAttrsNode : public BaseAttrsNode {
size_t ContentHash(AttrsHash hasher) const final; size_t ContentHash(AttrsHash hasher) const final;
// type info // type info
static constexpr const char* _type_key = "DictAttrs"; static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
}; };
...@@ -639,7 +642,7 @@ class AttrDocEntry { ...@@ -639,7 +642,7 @@ class AttrDocEntry {
public: public:
using TSelf = AttrDocEntry; using TSelf = AttrDocEntry;
explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info) explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info)
: info_(info) { : info_(info) {
} }
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
...@@ -663,15 +666,15 @@ class AttrDocEntry { ...@@ -663,15 +666,15 @@ class AttrDocEntry {
} }
private: private:
NodePtr<AttrFieldInfoNode> info_; ObjectPtr<AttrFieldInfoNode> info_;
}; };
class AttrDocVisitor { class AttrDocVisitor {
public: public:
template<typename T> template<typename T>
AttrDocEntry operator()(const char* key, T* v) { AttrDocEntry operator()(const char* key, T* v) {
NodePtr<AttrFieldInfoNode> info ObjectPtr<AttrFieldInfoNode> info
= make_node<AttrFieldInfoNode>(); = make_object<AttrFieldInfoNode>();
info->name = key; info->name = key;
info->type_info = TypeName<T>::value; info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info)); fields_.push_back(AttrFieldInfo(info));
......
...@@ -48,10 +48,10 @@ enum BufferType : int { ...@@ -48,10 +48,10 @@ enum BufferType : int {
* It is a composition of primitive symbolic types, * It is a composition of primitive symbolic types,
* used to specify the memory layout of the Tensor used in program input. * used to specify the memory layout of the Tensor used in program input.
*/ */
class Buffer : public NodeRef { class Buffer : public ObjectRef {
public: public:
Buffer() {} Buffer() {}
explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {} explicit Buffer(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief Return a new buffer that is equivalent with current one * \brief Return a new buffer that is equivalent with current one
* but always add stride field. * but always add stride field.
...@@ -101,7 +101,7 @@ class Buffer : public NodeRef { ...@@ -101,7 +101,7 @@ class Buffer : public NodeRef {
}; };
/*! \brief Node to represent a buffer */ /*! \brief Node to represent a buffer */
class BufferNode : public Node { class BufferNode : public Object {
public: public:
// Data fields. // Data fields.
/*! /*!
...@@ -169,7 +169,7 @@ class BufferNode : public Node { ...@@ -169,7 +169,7 @@ class BufferNode : public Node {
BufferType buffer_type); BufferType buffer_type);
static constexpr const char* _type_key = "Buffer"; static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
}; };
inline const BufferNode* Buffer::operator->() const { inline const BufferNode* Buffer::operator->() const {
......
...@@ -39,7 +39,7 @@ namespace tvm { ...@@ -39,7 +39,7 @@ namespace tvm {
* \brief Container for target device information. * \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly. * Use target::llvm, target::cuda etc functions instead of constructing directly.
*/ */
class TargetNode : public Node { class TargetNode : public Object {
public: public:
/*! \brief The name of the target device */ /*! \brief The name of the target device */
std::string target_name; std::string target_name;
...@@ -82,7 +82,7 @@ class TargetNode : public Node { ...@@ -82,7 +82,7 @@ class TargetNode : public Node {
TVM_DLL std::unordered_set<std::string> libs() const; TVM_DLL std::unordered_set<std::string> libs() const;
static constexpr const char* _type_key = "Target"; static constexpr const char* _type_key = "Target";
TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
private: private:
/*! \brief Internal string repr. */ /*! \brief Internal string repr. */
...@@ -90,10 +90,10 @@ class TargetNode : public Node { ...@@ -90,10 +90,10 @@ class TargetNode : public Node {
}; };
/*! \brief reference cpass to the target. */ /*! \brief reference cpass to the target. */
class Target : public NodeRef { class Target : public ObjectRef {
public: public:
Target() {} Target() {}
explicit Target(ObjectPtr<Object> n) : NodeRef(n) {} explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief Create a Target given a string * \brief Create a Target given a string
* \param target_str the string to parse * \param target_str the string to parse
...@@ -178,7 +178,7 @@ TVM_DLL Target ext_dev(const std::vector<std::string>& options = ...@@ -178,7 +178,7 @@ TVM_DLL Target ext_dev(const std::vector<std::string>& options =
/*! /*!
* \brief Container for build configuration options * \brief Container for build configuration options
*/ */
class BuildConfigNode : public Node { class BuildConfigNode : public Object {
public: public:
/*! /*!
* \brief The data alignment to use when constructing buffers. If this is set to * \brief The data alignment to use when constructing buffers. If this is set to
...@@ -254,16 +254,16 @@ class BuildConfigNode : public Node { ...@@ -254,16 +254,16 @@ class BuildConfigNode : public Node {
} }
static constexpr const char* _type_key = "BuildConfig"; static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object);
}; };
/*! /*!
* \brief Build configuration for compilations. * \brief Build configuration for compilations.
*/ */
class BuildConfig : public ::tvm::NodeRef { class BuildConfig : public ::tvm::ObjectRef {
public: public:
BuildConfig() {} BuildConfig() {}
explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {} explicit BuildConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
const BuildConfigNode* operator->() const { const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(get()); return static_cast<const BuildConfigNode*>(get());
} }
...@@ -375,10 +375,10 @@ class GenericFuncNode; ...@@ -375,10 +375,10 @@ class GenericFuncNode;
/*! /*!
* \brief Generic function that can be specialized on a per-target basis. * \brief Generic function that can be specialized on a per-target basis.
*/ */
class GenericFunc : public NodeRef { class GenericFunc : public ObjectRef {
public: public:
GenericFunc() {} GenericFunc() {}
explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {} explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief Set the default function implementaiton. * \brief Set the default function implementaiton.
...@@ -471,7 +471,7 @@ inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const { ...@@ -471,7 +471,7 @@ inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
/*! /*!
* \brief Represents a generic function that can be specialized on a per-target basis. * \brief Represents a generic function that can be specialized on a per-target basis.
*/ */
class GenericFuncNode : public Node { class GenericFuncNode : public Object {
public: public:
/*! \brief name of the function */ /*! \brief name of the function */
std::string name_; std::string name_;
...@@ -483,7 +483,7 @@ class GenericFuncNode : public Node { ...@@ -483,7 +483,7 @@ class GenericFuncNode : public Node {
void VisitAttrs(AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "GenericFunc"; static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
}; };
inline GenericFuncNode* GenericFunc::operator->() { inline GenericFuncNode* GenericFunc::operator->() {
......
...@@ -92,7 +92,7 @@ class LayoutAxis { ...@@ -92,7 +92,7 @@ class LayoutAxis {
class Layout; class Layout;
// Internal node container Buffer // Internal node container Buffer
class LayoutNode : public Node { class LayoutNode : public Object {
public: public:
/*! \brief string representation of layout, "" for scalar. */ /*! \brief string representation of layout, "" for scalar. */
std::string name; std::string name;
...@@ -112,7 +112,7 @@ class LayoutNode : public Node { ...@@ -112,7 +112,7 @@ class LayoutNode : public Node {
TVM_DLL static Layout make(const std::string& layout); TVM_DLL static Layout make(const std::string& layout);
static constexpr const char* _type_key = "Layout"; static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
}; };
/*! /*!
...@@ -125,9 +125,9 @@ class LayoutNode : public Node { ...@@ -125,9 +125,9 @@ class LayoutNode : public Node {
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
* Layout for scalar is defined, while both its name and axes have size 0. * Layout for scalar is defined, while both its name and axes have size 0.
*/ */
class Layout : public NodeRef { class Layout : public ObjectRef {
public: public:
explicit Layout(ObjectPtr<Object> n) : NodeRef(n) {} explicit Layout(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \brief default constructor */ /*! \brief default constructor */
Layout() = default; Layout() = default;
...@@ -311,7 +311,7 @@ class Layout : public NodeRef { ...@@ -311,7 +311,7 @@ class Layout : public NodeRef {
class BijectiveLayout; class BijectiveLayout;
// Internal node container BijectiveLayout // Internal node container BijectiveLayout
class BijectiveLayoutNode : public Node { class BijectiveLayoutNode : public Object {
public: public:
/*! \brief Describes how source axes can be mapped to the destination axes, /*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
...@@ -333,7 +333,7 @@ class BijectiveLayoutNode : public Node { ...@@ -333,7 +333,7 @@ class BijectiveLayoutNode : public Node {
} }
static constexpr const char* _type_key = "BijectiveLayout"; static constexpr const char* _type_key = "BijectiveLayout";
TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
TVM_DLL static BijectiveLayout make(const Layout& src_layout, TVM_DLL static BijectiveLayout make(const Layout& src_layout,
const Layout& dst_layout); const Layout& dst_layout);
...@@ -344,10 +344,10 @@ class BijectiveLayoutNode : public Node { ...@@ -344,10 +344,10 @@ class BijectiveLayoutNode : public Node {
* provides API to transform N-dimention tensor from the source indices (i0, i1, …, im) * provides API to transform N-dimention tensor from the source indices (i0, i1, …, im)
* to the destination indices (j0, j1, … jm). * to the destination indices (j0, j1, … jm).
*/ */
class BijectiveLayout : public NodeRef { class BijectiveLayout : public ObjectRef {
public: public:
BijectiveLayout() = default; BijectiveLayout() = default;
explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {} explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
// Given the source shape, infer the destination shape. // Given the source shape, infer the destination shape.
TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const; TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
......
...@@ -38,20 +38,20 @@ ...@@ -38,20 +38,20 @@
namespace tvm { namespace tvm {
/*! \brief Base node of all expressions. */ /*! \brief Base node of all expressions. */
class ExprNode : public Node { class ExprNode : public Object {
public: public:
/*! \brief The data type of the expression. */ /*! \brief The data type of the expression. */
DataType dtype; DataType dtype;
static constexpr const char* _type_key = "Expr"; static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object);
}; };
/*! \brief Container of all expressions. */ /*! \brief Container of all expressions. */
class Expr : public NodeRef { class Expr : public ObjectRef {
public: public:
Expr() {} Expr() {}
explicit Expr(ObjectPtr<Object> ptr) : NodeRef(ptr) {} explicit Expr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! /*!
* \brief construct from integer. * \brief construct from integer.
* \param value The value to be constructed. * \param value The value to be constructed.
...@@ -78,16 +78,16 @@ class Expr : public NodeRef { ...@@ -78,16 +78,16 @@ class Expr : public NodeRef {
}; };
/*! \brief Base node of all statements. */ /*! \brief Base node of all statements. */
class StmtNode : public Node { class StmtNode : public Object {
public: public:
static constexpr const char* _type_key = "Stmt"; static constexpr const char* _type_key = "Stmt";
TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
}; };
/*! \brief Container of all statements */ /*! \brief Container of all statements */
class Stmt : public NodeRef { class Stmt : public ObjectRef {
public: public:
TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode); TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode);
}; };
class Var; class Var;
...@@ -118,7 +118,7 @@ class Variable : public ExprNode { ...@@ -118,7 +118,7 @@ class Variable : public ExprNode {
} }
static constexpr const char* _type_key = "Variable"; static constexpr const char* _type_key = "Variable";
TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode);
}; };
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
...@@ -156,8 +156,8 @@ class Var : public Expr { ...@@ -156,8 +156,8 @@ class Var : public Expr {
// Backward compatibility, will be removed later. // Backward compatibility, will be removed later.
using VarExpr = Var; using VarExpr = Var;
using BaseExprNode = ExprNode; using BaseExprNode = ExprNode;
using ExprHash = NodeHash; using ExprHash = ObjectHash;
using ExprEqual = NodeEqual; using ExprEqual = ObjectEqual;
class Integer; class Integer;
/*! \brief ExprNode: constant integer. */ /*! \brief ExprNode: constant integer. */
...@@ -174,7 +174,7 @@ class IntImm : public ExprNode { ...@@ -174,7 +174,7 @@ class IntImm : public ExprNode {
TVM_DLL static Integer make(DataType t, int64_t value); TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm"; static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode);
}; };
/*! /*!
...@@ -222,7 +222,7 @@ class Integer : public Expr { ...@@ -222,7 +222,7 @@ class Integer : public Expr {
}; };
/*! \brief range over one dimension */ /*! \brief range over one dimension */
class RangeNode : public Node { class RangeNode : public Object {
public: public:
/*! \brief beginning of the node */ /*! \brief beginning of the node */
Expr min; Expr min;
...@@ -238,11 +238,11 @@ class RangeNode : public Node { ...@@ -238,11 +238,11 @@ class RangeNode : public Node {
} }
static constexpr const char* _type_key = "Range"; static constexpr const char* _type_key = "Range";
TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
}; };
/*! \brief Range constainer */ /*! \brief Range constainer */
class Range : public NodeRef { class Range : public ObjectRef {
public: public:
/*! /*!
* \brief constructor by begin and end * \brief constructor by begin and end
...@@ -261,7 +261,7 @@ class Range : public NodeRef { ...@@ -261,7 +261,7 @@ class Range : public NodeRef {
*/ */
static Range make_by_min_extent(Expr min, Expr extent); static Range make_by_min_extent(Expr min, Expr extent);
// declare range. // declare range.
TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode); TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
}; };
/*! \brief container class of iteration variable. */ /*! \brief container class of iteration variable. */
...@@ -343,12 +343,12 @@ enum IterVarType : int { ...@@ -343,12 +343,12 @@ enum IterVarType : int {
* \brief Iteration Variable, * \brief Iteration Variable,
* represents an iteration over an integer interval. * represents an iteration over an integer interval.
*/ */
class IterVar : public NodeRef { class IterVar : public ObjectRef {
public: public:
// construct a new iter var without a domain // construct a new iter var without a domain
IterVar() {} IterVar() {}
// construct from shared ptr. // construct from shared ptr.
explicit IterVar(ObjectPtr<Object> n) : NodeRef(n) {} explicit IterVar(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -384,14 +384,14 @@ using Domain = Array<Range>; ...@@ -384,14 +384,14 @@ using Domain = Array<Range>;
* \brief Dump the node to stderr, used for debug purposes. * \brief Dump the node to stderr, used for debug purposes.
* \param node The input node * \param node The input node
*/ */
TVM_DLL void Dump(const NodeRef& node); TVM_DLL void Dump(const ObjectRef& node);
// definition of Node. // definition of Node.
/*! /*!
* \brief An iteration variable representing an iteration * \brief An iteration variable representing an iteration
* over a one dimensional interval. * over a one dimensional interval.
*/ */
class IterVarNode : public Node { class IterVarNode : public Object {
public: public:
/*! /*!
* \brief the domain of iteration, if known, can be None * \brief the domain of iteration, if known, can be None
...@@ -420,7 +420,7 @@ class IterVarNode : public Node { ...@@ -420,7 +420,7 @@ class IterVarNode : public Node {
std::string thread_tag = ""); std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar"; static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
}; };
// inline implementations // inline implementations
...@@ -490,17 +490,22 @@ class IRPrinter { ...@@ -490,17 +490,22 @@ class IRPrinter {
using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>; using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
TVM_DLL static FType& vtable(); TVM_DLL static FType& vtable();
}; };
} // namespace tvm
// default print function for all nodes namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n); IRPrinter(os).Print(n);
return os; return os;
} }
} // namespace runtime
} // namespace tvm } // namespace tvm
namespace std { namespace std {
template <> template <>
struct hash<::tvm::IterVar> : public ::tvm::NodeHash { struct hash<::tvm::IterVar> : public ::tvm::ObjectHash {
}; };
} }
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
...@@ -164,7 +164,7 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -164,7 +164,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args ...) { virtual R VisitExprDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R(); return R();
} }
...@@ -255,7 +255,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -255,7 +255,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) { virtual R VisitStmtDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R(); return R();
} }
......
...@@ -418,7 +418,7 @@ Stmt HoistIfThenElse(Stmt stmt); ...@@ -418,7 +418,7 @@ Stmt HoistIfThenElse(Stmt stmt);
*/ */
LoweredFunc MakeAPI(Stmt body, LoweredFunc MakeAPI(Stmt body,
std::string name, std::string name,
Array<NodeRef> api_args, Array<ObjectRef> api_args,
int num_unpacked_args, int num_unpacked_args,
bool is_restricted); bool is_restricted);
......
...@@ -87,7 +87,7 @@ class TVM_DLL IRVisitor { ...@@ -87,7 +87,7 @@ class TVM_DLL IRVisitor {
/*! /*!
* \brief recursively visit an IR node * \brief recursively visit an IR node
*/ */
virtual void Visit(const NodeRef& node) { virtual void Visit(const ObjectRef& node) {
static const FVisit& f = vtable(); static const FVisit& f = vtable();
if (node.defined()) f(node, this); if (node.defined()) f(node, this);
} }
...@@ -152,7 +152,7 @@ class TVM_DLL IRVisitor { ...@@ -152,7 +152,7 @@ class TVM_DLL IRVisitor {
* \param node The ir to be visited. * \param node The ir to be visited.
* \param fvisit The visitor function to be applied. * \param fvisit The visitor function to be applied.
*/ */
TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit); TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -131,7 +131,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode { ...@@ -131,7 +131,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
} }
static constexpr const char* _type_key = "LoweredFunc"; static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
}; };
// Implementations of inline functions // Implementations of inline functions
...@@ -143,7 +143,7 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const { ...@@ -143,7 +143,7 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const {
namespace std { namespace std {
template <> template <>
struct hash<::tvm::LoweredFunc> : public tvm::NodeHash { struct hash<::tvm::LoweredFunc> : public tvm::ObjectHash {
}; };
} }
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
namespace tvm { namespace tvm {
/*! \brief array node content in array */ /*! \brief array node content in array */
class ArrayNode : public Node { class ArrayNode : public Object {
public: public:
/*! \brief the data content */ /*! \brief the data content */
std::vector<ObjectRef> data; std::vector<ObjectRef> data;
...@@ -44,11 +44,11 @@ class ArrayNode : public Node { ...@@ -44,11 +44,11 @@ class ArrayNode : public Node {
} }
static constexpr const char* _type_key = "Array"; static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
}; };
/*! \brief map node content */ /*! \brief map node content */
class MapNode : public Node { class MapNode : public Object {
public: public:
void VisitAttrs(AttrVisitor* visitor) { void VisitAttrs(AttrVisitor* visitor) {
} }
...@@ -63,12 +63,12 @@ class MapNode : public Node { ...@@ -63,12 +63,12 @@ class MapNode : public Node {
ContainerType data; ContainerType data;
static constexpr const char* _type_key = "Map"; static constexpr const char* _type_key = "Map";
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
}; };
/*! \brief specialized map node with string as key */ /*! \brief specialized map node with string as key */
class StrMapNode : public Node { class StrMapNode : public Object {
public: public:
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>; using ContainerType = std::unordered_map<std::string, ObjectRef>;
...@@ -80,7 +80,7 @@ class StrMapNode : public Node { ...@@ -80,7 +80,7 @@ class StrMapNode : public Node {
ContainerType data; ContainerType data;
static constexpr const char* _type_key = "StrMap"; static constexpr const char* _type_key = "StrMap";
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object);
}; };
/*! /*!
...@@ -138,13 +138,13 @@ class IterAdapter { ...@@ -138,13 +138,13 @@ class IterAdapter {
*/ */
template<typename T, template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type > typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
class Array : public NodeRef { class Array : public ObjectRef {
public: public:
/*! /*!
* \brief default constructor * \brief default constructor
*/ */
Array() { Array() {
data_ = make_node<ArrayNode>(); data_ = make_object<ArrayNode>();
} }
/*! /*!
* \brief move constructor * \brief move constructor
...@@ -164,7 +164,7 @@ class Array : public NodeRef { ...@@ -164,7 +164,7 @@ class Array : public NodeRef {
* \brief constructor from pointer * \brief constructor from pointer
* \param n the container pointer * \param n the container pointer
*/ */
explicit Array(ObjectPtr<Object> n) : NodeRef(n) {} explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief constructor from iterator * \brief constructor from iterator
* \param begin begin of iterator * \param begin begin of iterator
...@@ -195,7 +195,7 @@ class Array : public NodeRef { ...@@ -195,7 +195,7 @@ class Array : public NodeRef {
* \param val The init value * \param val The init value
*/ */
explicit Array(size_t n, const T& val) { explicit Array(size_t n, const T& val) {
auto tmp_node = make_node<ArrayNode>(); auto tmp_node = make_object<ArrayNode>();
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
tmp_node->data.push_back(val); tmp_node->data.push_back(val);
} }
...@@ -227,7 +227,7 @@ class Array : public NodeRef { ...@@ -227,7 +227,7 @@ class Array : public NodeRef {
*/ */
template<typename IterType> template<typename IterType>
void assign(IterType begin, IterType end) { void assign(IterType begin, IterType end) {
auto n = make_node<ArrayNode>(); auto n = make_object<ArrayNode>();
for (IterType it = begin; it != end; ++it) { for (IterType it = begin; it != end; ++it) {
n->data.push_back(T(*it)); n->data.push_back(T(*it));
} }
...@@ -257,7 +257,7 @@ class Array : public NodeRef { ...@@ -257,7 +257,7 @@ class Array : public NodeRef {
*/ */
inline ArrayNode* CopyOnWrite() { inline ArrayNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) { if (data_.get() == nullptr || !data_.unique()) {
NodePtr<ArrayNode> n = make_node<ArrayNode>(); ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
n->data = static_cast<ArrayNode*>(data_.get())->data; n->data = static_cast<ArrayNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_); ObjectPtr<Object>(std::move(n)).swap(data_);
} }
...@@ -333,13 +333,13 @@ template<typename K, ...@@ -333,13 +333,13 @@ template<typename K,
std::is_base_of<ObjectRef, K>::value || std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value >::type, std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type> typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public NodeRef { class Map : public ObjectRef {
public: public:
/*! /*!
* \brief default constructor * \brief default constructor
*/ */
Map() { Map() {
data_ = make_node<MapNode>(); data_ = make_object<MapNode>();
} }
/*! /*!
* \brief move constructor * \brief move constructor
...@@ -352,13 +352,13 @@ class Map : public NodeRef { ...@@ -352,13 +352,13 @@ class Map : public NodeRef {
* \brief copy constructor * \brief copy constructor
* \param other source * \param other source
*/ */
Map(const Map<K, V> &other) : NodeRef(other.data_) { // NOLINT(*) Map(const Map<K, V> &other) : ObjectRef(other.data_) { // NOLINT(*)
} }
/*! /*!
* \brief constructor from pointer * \brief constructor from pointer
* \param n the container pointer * \param n the container pointer
*/ */
explicit Map(ObjectPtr<Object> n) : NodeRef(n) {} explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief constructor from iterator * \brief constructor from iterator
* \param begin begin of iterator * \param begin begin of iterator
...@@ -410,7 +410,7 @@ class Map : public NodeRef { ...@@ -410,7 +410,7 @@ class Map : public NodeRef {
*/ */
template<typename IterType> template<typename IterType>
void assign(IterType begin, IterType end) { void assign(IterType begin, IterType end) {
NodePtr<MapNode> n = make_node<MapNode>(); ObjectPtr<MapNode> n = make_object<MapNode>();
for (IterType i = begin; i != end; ++i) { for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, i->second)); n->data.emplace(std::make_pair(i->first, i->second));
} }
...@@ -454,7 +454,7 @@ class Map : public NodeRef { ...@@ -454,7 +454,7 @@ class Map : public NodeRef {
*/ */
inline MapNode* CopyOnWrite() { inline MapNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) { if (data_.get() == nullptr || !data_.unique()) {
NodePtr<MapNode> n = make_node<MapNode>(); ObjectPtr<MapNode> n = make_object<MapNode>();
n->data = static_cast<const MapNode*>(data_.get())->data; n->data = static_cast<const MapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_); ObjectPtr<Object>(std::move(n)).swap(data_);
} }
...@@ -507,18 +507,18 @@ class Map : public NodeRef { ...@@ -507,18 +507,18 @@ class Map : public NodeRef {
// specialize of string map // specialize of string map
template<typename V, typename T1, typename T2> template<typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public NodeRef { class Map<std::string, V, T1, T2> : public ObjectRef {
public: public:
// for code reuse // for code reuse
Map() { Map() {
data_ = make_node<StrMapNode>(); data_ = make_object<StrMapNode>();
} }
Map(Map<std::string, V> && other) { // NOLINT(*) Map(Map<std::string, V> && other) { // NOLINT(*)
data_ = std::move(other.data_); data_ = std::move(other.data_);
} }
Map(const Map<std::string, V> &other) : NodeRef(other.data_) { // NOLINT(*) Map(const Map<std::string, V> &other) : ObjectRef(other.data_) { // NOLINT(*)
} }
explicit Map(ObjectPtr<Object> n) : NodeRef(n) {} explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
template<typename IterType> template<typename IterType>
Map(IterType begin, IterType end) { Map(IterType begin, IterType end) {
assign(begin, end); assign(begin, end);
...@@ -541,7 +541,7 @@ class Map<std::string, V, T1, T2> : public NodeRef { ...@@ -541,7 +541,7 @@ class Map<std::string, V, T1, T2> : public NodeRef {
} }
template<typename IterType> template<typename IterType>
void assign(IterType begin, IterType end) { void assign(IterType begin, IterType end) {
auto n = make_node<StrMapNode>(); auto n = make_object<StrMapNode>();
for (IterType i = begin; i != end; ++i) { for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, i->second)); n->data.emplace(std::make_pair(i->first, i->second));
} }
...@@ -565,7 +565,7 @@ class Map<std::string, V, T1, T2> : public NodeRef { ...@@ -565,7 +565,7 @@ class Map<std::string, V, T1, T2> : public NodeRef {
} }
inline StrMapNode* CopyOnWrite() { inline StrMapNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) { if (data_.get() == nullptr || !data_.unique()) {
NodePtr<StrMapNode> n = make_node<StrMapNode>(); ObjectPtr<StrMapNode> n = make_object<StrMapNode>();
n->data = static_cast<const StrMapNode*>(data_.get())->data; n->data = static_cast<const StrMapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_); ObjectPtr<Object>(std::move(n)).swap(data_);
} }
......
...@@ -56,105 +56,5 @@ using runtime::ObjectHash; ...@@ -56,105 +56,5 @@ using runtime::ObjectHash;
using runtime::ObjectEqual; using runtime::ObjectEqual;
using runtime::make_object; using runtime::make_object;
using NodeHash = ObjectHash;
using NodeEqual = ObjectEqual;
using Node = Object;
/*!
* \brief Base class of all references to AST/IR nodes.
*/
class NodeRef : public ObjectRef {
public:
NodeRef() {}
explicit NodeRef(ObjectPtr<Object> n) : ObjectRef(n) {}
};
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
* \note This function is an alias of make_object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
return runtime::make_object<T>(std::forward<Args>(args)...);
}
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent)
/*!
* \brief helper macro to declare type information in a terminal node
*/
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent);
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_NODE_H_ #endif // TVM_NODE_NODE_H_
...@@ -60,7 +60,7 @@ class OperationNode : public ir::FunctionBaseNode { ...@@ -60,7 +60,7 @@ class OperationNode : public ir::FunctionBaseNode {
/*! \brief optional tag of the operation */ /*! \brief optional tag of the operation */
std::string tag; std::string tag;
/*! \brief additional attributes of the operation*/ /*! \brief additional attributes of the operation*/
Map<std::string, NodeRef> attrs; Map<std::string, ObjectRef> attrs;
/*! \return name of the operation */ /*! \return name of the operation */
const std::string& func_name() const final { const std::string& func_name() const final {
return name; return name;
...@@ -149,7 +149,7 @@ class OperationNode : public ir::FunctionBaseNode { ...@@ -149,7 +149,7 @@ class OperationNode : public ir::FunctionBaseNode {
static constexpr const char* _type_key = "Operation"; static constexpr const char* _type_key = "Operation";
TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
}; };
/*! /*!
...@@ -200,7 +200,7 @@ class PlaceholderOpNode : public OperationNode { ...@@ -200,7 +200,7 @@ class PlaceholderOpNode : public OperationNode {
DataType dtype); DataType dtype);
static constexpr const char* _type_key = "PlaceholderOp"; static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
}; };
/*! /*!
...@@ -228,7 +228,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { ...@@ -228,7 +228,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
virtual size_t num_schedulable_dims() const = 0; virtual size_t num_schedulable_dims() const = 0;
static constexpr const char* _type_key = "BaseComputeOp"; static constexpr const char* _type_key = "BaseComputeOp";
TVM_DECLARE_BASE_NODE_INFO(BaseComputeOpNode, OperationNode); TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
}; };
...@@ -269,12 +269,12 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { ...@@ -269,12 +269,12 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
} }
static Operation make(std::string name, static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
Array<IterVar> axis, Array<IterVar> axis,
Array<Expr> body); Array<Expr> body);
static constexpr const char* _type_key = "ComputeOp"; static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, BaseComputeOpNode); TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
}; };
/*! /*!
...@@ -334,7 +334,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { ...@@ -334,7 +334,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
Array<Expr> scalar_inputs); Array<Expr> scalar_inputs);
static constexpr const char* _type_key = "TensorComputeOp"; static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode); TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
}; };
/*! /*!
...@@ -407,7 +407,7 @@ class ScanOpNode : public OperationNode { ...@@ -407,7 +407,7 @@ class ScanOpNode : public OperationNode {
} }
static Operation make(std::string name, static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
IterVar axis, IterVar axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
...@@ -415,7 +415,7 @@ class ScanOpNode : public OperationNode { ...@@ -415,7 +415,7 @@ class ScanOpNode : public OperationNode {
Array<Tensor> input); Array<Tensor> input);
static constexpr const char* _type_key = "ScanOp"; static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode); TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
}; };
/*! /*!
...@@ -472,14 +472,14 @@ class ExternOpNode : public OperationNode { ...@@ -472,14 +472,14 @@ class ExternOpNode : public OperationNode {
} }
TVM_DLL static Operation make(std::string name, TVM_DLL static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> inputs,
Array<Buffer> input_placeholders, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Array<Buffer> output_placeholders,
Stmt body); Stmt body);
static constexpr const char* _type_key = "ExternOp"; static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode); TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
}; };
/*! /*!
...@@ -540,13 +540,13 @@ class HybridOpNode : public OperationNode { ...@@ -540,13 +540,13 @@ class HybridOpNode : public OperationNode {
} }
TVM_DLL static Operation make(std::string name, TVM_DLL static Operation make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> inputs,
Array<Tensor> outputs, Array<Tensor> outputs,
Stmt body); Stmt body);
static constexpr const char* _type_key = "HybridOp"; static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode); TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
}; };
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
...@@ -578,7 +578,7 @@ TVM_DLL Tensor compute(Array<Expr> shape, ...@@ -578,7 +578,7 @@ TVM_DLL Tensor compute(Array<Expr> shape,
FCompute fcompute, FCompute fcompute,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}); Map<std::string, ObjectRef> attrs = {});
/*! /*!
* \brief Construct a new tensor by computing over shape, * \brief Construct a new tensor by computing over shape,
...@@ -593,7 +593,7 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape, ...@@ -593,7 +593,7 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute, FBatchCompute fcompute,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}); Map<std::string, ObjectRef> attrs = {});
/*! /*!
* \brief Construct new tensors by scan. * \brief Construct new tensors by scan.
...@@ -613,14 +613,14 @@ TVM_DLL Array<Tensor> scan(Array<Tensor> init, ...@@ -613,14 +613,14 @@ TVM_DLL Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> inputs = Array<Tensor>(), Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan", std::string name = "scan",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}); Map<std::string, ObjectRef> attrs = {});
// same as compute, specialized for different fcompute function // same as compute, specialized for different fcompute function
inline Tensor compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f, std::function<Expr(Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}) { Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return compute(shape, fc, name, tag, attrs); return compute(shape, fc, name, tag, attrs);
} }
...@@ -628,7 +628,7 @@ inline Tensor compute(Array<Expr> shape, ...@@ -628,7 +628,7 @@ inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f, std::function<Expr(Var, Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}) { Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return compute(shape, fc, name, tag, attrs); return compute(shape, fc, name, tag, attrs);
} }
...@@ -636,7 +636,7 @@ inline Tensor compute(Array<Expr> shape, ...@@ -636,7 +636,7 @@ inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f, std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}) { Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return compute(shape, fc, name, tag, attrs); return compute(shape, fc, name, tag, attrs);
} }
...@@ -644,7 +644,7 @@ inline Tensor compute(Array<Expr> shape, ...@@ -644,7 +644,7 @@ inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f, std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "", std::string tag = "",
Map<std::string, NodeRef> attrs = {}) { Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return compute(shape, fc, name, tag, attrs); return compute(shape, fc, name, tag, attrs);
} }
......
...@@ -115,15 +115,15 @@ inline TVMPODValue_::operator tvm::Expr() const { ...@@ -115,15 +115,15 @@ inline TVMPODValue_::operator tvm::Expr() const {
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
if (ptr->IsInstance<IterVarNode>()) { if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Node>(ptr))->var; return IterVar(ObjectPtr<Object>(ptr))->var;
} }
if (ptr->IsInstance<TensorNode>()) { if (ptr->IsInstance<TensorNode>()) {
return Tensor(ObjectPtr<Node>(ptr))(); return Tensor(ObjectPtr<Object>(ptr))();
} }
CHECK(ObjectTypeChecker<Expr>::Check(ptr)) CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName() << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr)); return Expr(ObjectPtr<Object>(ptr));
} }
inline TVMPODValue_::operator tvm::Integer() const { inline TVMPODValue_::operator tvm::Integer() const {
...@@ -138,7 +138,7 @@ inline TVMPODValue_::operator tvm::Integer() const { ...@@ -138,7 +138,7 @@ inline TVMPODValue_::operator tvm::Integer() const {
CHECK(ObjectTypeChecker<Integer>::Check(ptr)) CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName() << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr)); return Integer(ObjectPtr<Object>(ptr));
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -38,7 +38,7 @@ namespace relay { ...@@ -38,7 +38,7 @@ namespace relay {
class PatternNode : public RelayNode { class PatternNode : public RelayNode {
public: public:
static constexpr const char* _type_key = "relay.Pattern"; static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
}; };
/*! /*!
...@@ -49,10 +49,10 @@ class PatternNode : public RelayNode { ...@@ -49,10 +49,10 @@ class PatternNode : public RelayNode {
* *
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value. * ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
*/ */
class Pattern : public NodeRef { class Pattern : public ObjectRef {
public: public:
Pattern() {} Pattern() {}
explicit Pattern(ObjectPtr<tvm::Object> p) : NodeRef(p) {} explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
using ContainerType = PatternNode; using ContainerType = PatternNode;
}; };
...@@ -71,10 +71,13 @@ class PatternWildcardNode : public PatternNode { ...@@ -71,10 +71,13 @@ class PatternWildcardNode : public PatternNode {
} }
static constexpr const char* _type_key = "relay.PatternWildcard"; static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
}; };
RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); class PatternWildcard : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode);
};
/*! \brief A var pattern. Accept all input and bind to a var. */ /*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar; class PatternVar;
...@@ -94,10 +97,13 @@ class PatternVarNode : public PatternNode { ...@@ -94,10 +97,13 @@ class PatternVarNode : public PatternNode {
} }
static constexpr const char* _type_key = "relay.PatternVar"; static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
}; };
RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); class PatternVar : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};
/*! /*!
* \brief ADT constructor. * \brief ADT constructor.
...@@ -132,10 +138,13 @@ class ConstructorNode : public ExprNode { ...@@ -132,10 +138,13 @@ class ConstructorNode : public ExprNode {
} }
static constexpr const char* _type_key = "relay.Constructor"; static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); class Constructor : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode);
};
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor; class PatternConstructor;
...@@ -158,10 +167,13 @@ class PatternConstructorNode : public PatternNode { ...@@ -158,10 +167,13 @@ class PatternConstructorNode : public PatternNode {
} }
static constexpr const char* _type_key = "relay.PatternConstructor"; static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
}; };
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); class PatternConstructor : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
};
/*! \brief A tuple pattern. Matches a tuple, binds recursively. */ /*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple; class PatternTuple;
...@@ -181,10 +193,13 @@ class PatternTupleNode : public PatternNode { ...@@ -181,10 +193,13 @@ class PatternTupleNode : public PatternNode {
} }
static constexpr const char* _type_key = "relay.PatternTuple"; static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
}; };
RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern); class PatternTuple : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};
/*! /*!
* \brief Stores all data for an Algebraic Data Type (ADT). * \brief Stores all data for an Algebraic Data Type (ADT).
...@@ -225,15 +240,18 @@ class TypeDataNode : public TypeNode { ...@@ -225,15 +240,18 @@ class TypeDataNode : public TypeNode {
tvm::Array<Constructor> constructors); tvm::Array<Constructor> constructors);
static constexpr const char* _type_key = "relay.TypeData"; static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); class TypeData : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
/*! \brief A clause in a match expression. */ /*! \brief A clause in a match expression. */
class Clause; class Clause;
/*! \brief Clause container node. */ /*! \brief Clause container node. */
class ClauseNode : public Node { class ClauseNode : public Object {
public: public:
/*! \brief The pattern the clause matches. */ /*! \brief The pattern the clause matches. */
Pattern lhs; Pattern lhs;
...@@ -248,10 +266,13 @@ class ClauseNode : public Node { ...@@ -248,10 +266,13 @@ class ClauseNode : public Node {
TVM_DLL static Clause make(Pattern lhs, Expr rhs); TVM_DLL static Clause make(Pattern lhs, Expr rhs);
static constexpr const char* _type_key = "relay.Clause"; static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
}; };
RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); class Clause : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
};
/*! \brief ADT pattern matching exression. */ /*! \brief ADT pattern matching exression. */
class Match; class Match;
...@@ -280,10 +301,13 @@ class MatchNode : public ExprNode { ...@@ -280,10 +301,13 @@ class MatchNode : public ExprNode {
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true); TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);
static constexpr const char* _type_key = "relay.Match"; static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); class Match : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode);
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -196,7 +196,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> { ...@@ -196,7 +196,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}; // struct SqueezeAttrs }; // struct SqueezeAttrs
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> { struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
NodeRef indices_or_sections; ObjectRef indices_or_sections;
int axis; int axis;
TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
......
...@@ -54,53 +54,11 @@ namespace relay { ...@@ -54,53 +54,11 @@ namespace relay {
} }
/*! /*!
* \brief We always used NodeRef for referencing nodes.
*
* By default, NodeRef is a std::shared_ptr of node
*/
using NodeRef = tvm::NodeRef;
/*!
* \brief Content data type.
*/
using DataType = ::tvm::DataType;
/*!
* \brief Symbolic expression for tensor shape. * \brief Symbolic expression for tensor shape.
*/ */
using IndexExpr = ::tvm::Expr; using IndexExpr = ::tvm::Expr;
/*! /*!
* \brief Hash function for nodes.
* e.g. std::unordered_map<Expr, Value, NodeHash, NodeEqual>
*/
using NodeHash = ::tvm::NodeHash;
/*!
* \brief Equality check function for nodes.
*/
using NodeEqual = ::tvm::NodeEqual;
/*!
* \brief Macro to make it easy to define node ref type given node
* \param TypeName The name of the reference type.
* \param NodeName The internal container name.
* \param NodeRefBase The base type.
*/
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \
public: \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: NodeRefBase(n) { \
} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(get()); \
} \
operator bool() { return this->defined(); } \
using ContainerType = NodeName; \
};
/*!
* \brief The source name in the Span * \brief The source name in the Span
* \sa SourceNameNode, Span * \sa SourceNameNode, Span
*/ */
...@@ -108,7 +66,7 @@ class SourceName; ...@@ -108,7 +66,7 @@ class SourceName;
/*! /*!
* \brief The name of a source fragment. * \brief The name of a source fragment.
*/ */
class SourceNameNode : public Node { class SourceNameNode : public Object {
public: public:
/*! \brief The source name. */ /*! \brief The source name. */
std::string name; std::string name;
...@@ -116,20 +74,20 @@ class SourceNameNode : public Node { ...@@ -116,20 +74,20 @@ class SourceNameNode : public Node {
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName"; static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
}; };
/*! /*!
* \brief The source name of a file span. * \brief The source name of a file span.
* \sa SourceNameNode, Span * \sa SourceNameNode, Span
*/ */
class SourceName : public NodeRef { class SourceName : public ObjectRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
SourceName() {} SourceName() {}
/*! \brief constructor from node pointer */ /*! \brief constructor from node pointer */
explicit SourceName(NodePtr<Node> n) : NodeRef(n) {} explicit SourceName(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -157,7 +115,7 @@ class Span; ...@@ -157,7 +115,7 @@ class Span;
/*! /*!
* \brief Stores locations in frontend source that generated a node. * \brief Stores locations in frontend source that generated a node.
*/ */
class SpanNode : public Node { class SpanNode : public Object {
public: public:
/*! \brief The source name */ /*! \brief The source name */
SourceName source; SourceName source;
...@@ -175,22 +133,25 @@ class SpanNode : public Node { ...@@ -175,22 +133,25 @@ class SpanNode : public Node {
TVM_DLL static Span make(SourceName source, int lineno, int col_offset); TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span"; static constexpr const char* _type_key = "relay.Span";
TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
}; };
RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); class Span : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
/*! /*!
* \brief This is the base node container of all relay structures. * \brief This is the base node container of all relay structures.
*/ */
class RelayNode : public Node { class RelayNode : public Object {
public: public:
/*! \brief The location of the program in a SourceFragment can be null, /*! \brief The location of the program in a SourceFragment can be null,
* check with span.defined() */ * check with span.defined() */
mutable Span span; mutable Span span;
static constexpr const char* _type_key = "relay.Node"; static constexpr const char* _type_key = "relay.Node";
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(RelayNode, Object);
}; };
/*! /*!
...@@ -201,7 +162,7 @@ class RelayNode : public Node { ...@@ -201,7 +162,7 @@ class RelayNode : public Node {
* *
* \note Do not create Id directly, they are created in Var. * \note Do not create Id directly, they are created in Var.
*/ */
class IdNode : public Node { class IdNode : public Object {
public: public:
/*! /*!
* \brief The name of the variable, * \brief The name of the variable,
...@@ -215,10 +176,13 @@ class IdNode : public Node { ...@@ -215,10 +176,13 @@ class IdNode : public Node {
} }
static constexpr const char* _type_key = "relay.Id"; static constexpr const char* _type_key = "relay.Id";
TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
}; };
RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef); class Id : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};
struct Module; struct Module;
......
...@@ -118,7 +118,7 @@ class ErrorReporter { ...@@ -118,7 +118,7 @@ class ErrorReporter {
* \param node The expression or type to report the error at. * \param node The expression or type to report the error at.
* \param err The error message to report. * \param err The error message to report.
*/ */
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { inline void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) {
std::string err_msg = err.str(); std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg)); this->ReportAt(global, node, Error(err_msg));
} }
...@@ -134,7 +134,7 @@ class ErrorReporter { ...@@ -134,7 +134,7 @@ class ErrorReporter {
* \param node The expression or type to report the error at. * \param node The expression or type to report the error at.
* \param err The error to report. * \param err The error to report.
*/ */
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err);
/*! \brief Render all reported errors and exit the program. /*! \brief Render all reported errors and exit the program.
* *
...@@ -154,8 +154,8 @@ class ErrorReporter { ...@@ -154,8 +154,8 @@ class ErrorReporter {
private: private:
std::vector<Error> errors_; std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_; std::unordered_map<ObjectRef, std::vector<size_t>, ObjectHash, ObjectEqual> node_to_error_;
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_; std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_;
}; };
} // namespace relay } // namespace relay
......
...@@ -67,10 +67,13 @@ class ExprNode : public RelayNode { ...@@ -67,10 +67,13 @@ class ExprNode : public RelayNode {
inline const TTypeNode* type_as() const; inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr"; static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode);
}; };
RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); class Expr : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode);
};
/*! /*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device. * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
...@@ -104,10 +107,13 @@ class ConstantNode : public ExprNode { ...@@ -104,10 +107,13 @@ class ConstantNode : public ExprNode {
TVM_DLL static Constant make(runtime::NDArray data); TVM_DLL static Constant make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.Constant"; static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr); class Constant : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode);
};
/*! \brief Tuple of multiple Exprs */ /*! \brief Tuple of multiple Exprs */
class Tuple; class Tuple;
...@@ -126,10 +132,13 @@ class TupleNode : public ExprNode { ...@@ -126,10 +132,13 @@ class TupleNode : public ExprNode {
TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields); TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);
static constexpr const char* _type_key = "relay.Tuple"; static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); class Tuple : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
};
/*! /*!
* \brief Local variables used in the let expression. * \brief Local variables used in the let expression.
...@@ -179,10 +188,13 @@ class VarNode : public ExprNode { ...@@ -179,10 +188,13 @@ class VarNode : public ExprNode {
Type type_annotation); Type type_annotation);
static constexpr const char* _type_key = "relay.Var"; static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); class Var : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
};
/*! /*!
* \brief Global variable that leaves in the top-level module. * \brief Global variable that leaves in the top-level module.
...@@ -206,10 +218,13 @@ class GlobalVarNode : public ExprNode { ...@@ -206,10 +218,13 @@ class GlobalVarNode : public ExprNode {
TVM_DLL static GlobalVar make(std::string name_hint); TVM_DLL static GlobalVar make(std::string name_hint);
static constexpr const char* _type_key = "relay.GlobalVar"; static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); class GlobalVar : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode);
};
/*! /*!
* \brief Function (subgraph in computational graph) * \brief Function (subgraph in computational graph)
...@@ -297,14 +312,19 @@ class FunctionNode : public ExprNode { ...@@ -297,14 +312,19 @@ class FunctionNode : public ExprNode {
tvm::Map<Var, Constant> GetParams() const; tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function"; static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); class Function : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
};
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); TVM_DLL Function FunctionSetAttr(const Function& func,
const std::string& key,
const ObjectRef& data);
/*! /*!
* \brief Call corresponds to operator invocation. * \brief Call corresponds to operator invocation.
...@@ -363,10 +383,13 @@ class CallNode : public ExprNode { ...@@ -363,10 +383,13 @@ class CallNode : public ExprNode {
Array<Type> type_args = Array<Type>()); Array<Type> type_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call"; static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Call, CallNode, Expr); class Call : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode);
};
/*! /*!
* \brief Let binding that binds a local var and optionally a type annotation. * \brief Let binding that binds a local var and optionally a type annotation.
...@@ -401,10 +424,13 @@ class LetNode : public ExprNode { ...@@ -401,10 +424,13 @@ class LetNode : public ExprNode {
TVM_DLL static Let make(Var var, Expr value, Expr body); TVM_DLL static Let make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "relay.Let"; static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); class Let : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode);
};
/*! /*!
* \brief Condition expression * \brief Condition expression
...@@ -439,10 +465,13 @@ class IfNode : public ExprNode { ...@@ -439,10 +465,13 @@ class IfNode : public ExprNode {
TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch); TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);
static constexpr const char* _type_key = "relay.If"; static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(If, IfNode, Expr); class If : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
};
/*! \brief Get index-th field out of a tuple. */ /*! \brief Get index-th field out of a tuple. */
class TupleGetItem; class TupleGetItem;
...@@ -463,10 +492,13 @@ class TupleGetItemNode : public ExprNode { ...@@ -463,10 +492,13 @@ class TupleGetItemNode : public ExprNode {
TVM_DLL static TupleGetItem make(Expr tuple, int index); TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char* _type_key = "relay.TupleGetItem"; static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); class TupleGetItem : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode);
};
/*! \brief Create a new Reference out of initial value. */ /*! \brief Create a new Reference out of initial value. */
class RefCreate; class RefCreate;
...@@ -484,10 +516,13 @@ class RefCreateNode : public ExprNode { ...@@ -484,10 +516,13 @@ class RefCreateNode : public ExprNode {
TVM_DLL static RefCreate make(Expr value); TVM_DLL static RefCreate make(Expr value);
static constexpr const char* _type_key = "relay.RefCreate"; static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr); class RefCreate : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode);
};
/*! \brief Get value out of Reference. */ /*! \brief Get value out of Reference. */
class RefRead; class RefRead;
...@@ -505,10 +540,13 @@ class RefReadNode : public ExprNode { ...@@ -505,10 +540,13 @@ class RefReadNode : public ExprNode {
TVM_DLL static RefRead make(Expr ref); TVM_DLL static RefRead make(Expr ref);
static constexpr const char* _type_key = "relay.RefRead"; static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr); class RefRead : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode);
};
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite; class RefWrite;
class RefWriteNode : public ExprNode { class RefWriteNode : public ExprNode {
...@@ -528,10 +566,13 @@ class RefWriteNode : public ExprNode { ...@@ -528,10 +566,13 @@ class RefWriteNode : public ExprNode {
TVM_DLL static RefWrite make(Expr ref, Expr value); TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite"; static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); class RefWrite : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode);
};
/*! /*!
* \brief Base class of the temporary expression. * \brief Base class of the temporary expression.
...@@ -554,10 +595,13 @@ class TempExprNode : public ExprNode { ...@@ -554,10 +595,13 @@ class TempExprNode : public ExprNode {
virtual Expr Realize() const = 0; virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr"; static constexpr const char* _type_key = "relay.TempExpr";
TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode); TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); class TempExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode);
};
// implementataions // implementataions
inline const Type& ExprNode::checked_type() const { inline const Type& ExprNode::checked_type() const {
...@@ -583,7 +627,7 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -583,7 +627,7 @@ inline const TTypeNode* ExprNode::type_as() const {
} }
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node); std::string PrettyPrint(const ObjectRef& node);
/*! /*!
* \brief Render the node as a string in the Relay text format. * \brief Render the node as a string in the Relay text format.
...@@ -593,7 +637,7 @@ std::string PrettyPrint(const NodeRef& node); ...@@ -593,7 +637,7 @@ std::string PrettyPrint(const NodeRef& node);
* additional comment block to an expr. * additional comment block to an expr.
* \return The text representation. * \return The text representation.
*/ */
std::string AsText(const NodeRef& node, std::string AsText(const ObjectRef& node,
bool show_meta_data = true, bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr); runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
......
...@@ -116,7 +116,7 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -116,7 +116,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) { virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; throw;
} }
...@@ -177,7 +177,7 @@ class ExprVisitor ...@@ -177,7 +177,7 @@ class ExprVisitor
protected: protected:
// Internal visiting counter // Internal visiting counter
std::unordered_map<const Node*, size_t> visit_counter_; std::unordered_map<const Object*, size_t> visit_counter_;
}; };
/*! /*!
...@@ -227,7 +227,7 @@ class ExprMutator ...@@ -227,7 +227,7 @@ class ExprMutator
protected: protected:
/*! \brief Internal map used for memoization. */ /*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
}; };
/*! /*!
......
...@@ -72,13 +72,13 @@ CreateInterpreter(Module mod, DLContext context, Target target); ...@@ -72,13 +72,13 @@ CreateInterpreter(Module mod, DLContext context, Target target);
class ValueNode : public RelayNode { class ValueNode : public RelayNode {
public: public:
static constexpr const char* _type_key = "relay.Value"; static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
}; };
class Value : public NodeRef { class Value : public ObjectRef {
public: public:
Value() {} Value() {}
explicit Value(ObjectPtr<Object> n) : NodeRef(n) {} explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const { const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get()); return static_cast<const ValueNode*>(get());
} }
...@@ -114,10 +114,13 @@ class ClosureNode : public ValueNode { ...@@ -114,10 +114,13 @@ class ClosureNode : public ValueNode {
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func); TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
static constexpr const char* _type_key = "relay.Closure"; static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
}; };
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); class Closure : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
};
/*! \brief A Relay Recursive Closure. A closure that has a name. */ /*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure; class RecClosure;
...@@ -140,10 +143,13 @@ class RecClosureNode : public ValueNode { ...@@ -140,10 +143,13 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind); TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure"; static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
}; };
RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value); class RecClosure : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
};
/*! \brief A tuple value. */ /*! \brief A tuple value. */
class TupleValue; class TupleValue;
...@@ -159,10 +165,13 @@ struct TupleValueNode : ValueNode { ...@@ -159,10 +165,13 @@ struct TupleValueNode : ValueNode {
TVM_DLL static TupleValue make(tvm::Array<Value> value); TVM_DLL static TupleValue make(tvm::Array<Value> value);
static constexpr const char* _type_key = "relay.TupleValue"; static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
}; };
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};
/*! \brief A tensor value. */ /*! \brief A tensor value. */
class TensorValue; class TensorValue;
...@@ -179,10 +188,13 @@ struct TensorValueNode : ValueNode { ...@@ -179,10 +188,13 @@ struct TensorValueNode : ValueNode {
TVM_DLL static TensorValue make(runtime::NDArray data); TVM_DLL static TensorValue make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.TensorValue"; static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
}; };
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); class TensorValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
};
/*! \brief A reference value. */ /*! \brief A reference value. */
class RefValue; class RefValue;
...@@ -199,10 +211,13 @@ struct RefValueNode : ValueNode { ...@@ -199,10 +211,13 @@ struct RefValueNode : ValueNode {
TVM_DLL static RefValue make(Value val); TVM_DLL static RefValue make(Value val);
static constexpr const char* _type_key = "relay.RefValue"; static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
}; };
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class RefValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
};
/*! \brief An ADT constructor value. */ /*! \brief An ADT constructor value. */
class ConstructorValue; class ConstructorValue;
...@@ -226,10 +241,13 @@ struct ConstructorValueNode : ValueNode { ...@@ -226,10 +241,13 @@ struct ConstructorValueNode : ValueNode {
Constructor construtor = {}); Constructor construtor = {});
static constexpr const char* _type_key = "relay.ConstructorValue"; static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
}; };
RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value); class ConstructorValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -258,7 +258,7 @@ class ModuleNode : public RelayNode { ...@@ -258,7 +258,7 @@ class ModuleNode : public RelayNode {
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {}); const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module"; static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
private: private:
/*! \brief Helper function for registering a typedef's constructors */ /*! \brief Helper function for registering a typedef's constructors */
...@@ -285,9 +285,9 @@ class ModuleNode : public RelayNode { ...@@ -285,9 +285,9 @@ class ModuleNode : public RelayNode {
std::unordered_set<std::string> import_set_; std::unordered_set<std::string> import_set_;
}; };
struct Module : public NodeRef { struct Module : public ObjectRef {
Module() {} Module() {}
explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {}
ModuleNode* operator->() const { ModuleNode* operator->() const {
return static_cast<ModuleNode*>(get_mutable()); return static_cast<ModuleNode*>(get_mutable());
......
...@@ -106,7 +106,7 @@ class OpNode : public relay::ExprNode { ...@@ -106,7 +106,7 @@ class OpNode : public relay::ExprNode {
} }
static constexpr const char* _type_key = "relay.Op"; static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode);
private: private:
// friend class // friend class
...@@ -431,7 +431,7 @@ inline OpRegistry& OpRegistry::describe( ...@@ -431,7 +431,7 @@ inline OpRegistry& OpRegistry::describe(
inline OpRegistry& OpRegistry::add_argument(const std::string& name, inline OpRegistry& OpRegistry::add_argument(const std::string& name,
const std::string& type, const std::string& type,
const std::string& description) { const std::string& description) {
auto n = make_node<AttrFieldInfoNode>(); auto n = make_object<AttrFieldInfoNode>();
n->name = name; n->name = name;
n->type_info = type; n->type_info = type;
n->description = description; n->description = description;
......
...@@ -180,7 +180,7 @@ using FTVMLegalize = runtime::TypedPackedFunc< ...@@ -180,7 +180,7 @@ using FTVMLegalize = runtime::TypedPackedFunc<
using FForwardRewrite = runtime::TypedPackedFunc< using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call, Expr(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
const NodeRef& ctx)>; const ObjectRef& ctx)>;
/*! /*!
* \brief Gradient for a specific op. * \brief Gradient for a specific op.
......
...@@ -102,7 +102,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> { ...@@ -102,7 +102,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
Args... args) PATTERN_FUNCTOR_DEFAULT; Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternTupleNode* op, virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT; Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) { virtual R VisitPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; throw;
} }
...@@ -162,7 +162,7 @@ class PatternMutator ...@@ -162,7 +162,7 @@ class PatternMutator
/*! \brief Used to visit the vars inside of patterns. */ /*! \brief Used to visit the vars inside of patterns. */
virtual Constructor VisitConstructor(const Constructor& c); virtual Constructor VisitConstructor(const Constructor& c);
private: private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> var_map_; std::unordered_map<Var, Var, ObjectHash, ObjectEqual> var_map_;
}; };
} // namespace relay } // namespace relay
......
...@@ -109,7 +109,7 @@ class PassContextNode : public RelayNode { ...@@ -109,7 +109,7 @@ class PassContextNode : public RelayNode {
} }
static constexpr const char* _type_key = "relay.PassContext"; static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, RelayNode);
}; };
/*! /*!
...@@ -125,10 +125,10 @@ class PassContextNode : public RelayNode { ...@@ -125,10 +125,10 @@ class PassContextNode : public RelayNode {
* *
* \endcode * \endcode
*/ */
class PassContext : public NodeRef { class PassContext : public ObjectRef {
public: public:
PassContext() {} PassContext() {}
explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {} explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
/*! /*!
* \brief const accessor. * \brief const accessor.
* \return const access pointer. * \return const access pointer.
...@@ -207,10 +207,13 @@ class PassInfoNode : public RelayNode { ...@@ -207,10 +207,13 @@ class PassInfoNode : public RelayNode {
tvm::Array<tvm::Expr> required); tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo"; static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode);
}; };
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) class PassInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
class Pass; class Pass;
...@@ -251,10 +254,10 @@ class PassNode : public RelayNode { ...@@ -251,10 +254,10 @@ class PassNode : public RelayNode {
void VisitAttrs(tvm::AttrVisitor* v) {} void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass"; static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); TVM_DECLARE_BASE_OBJECT_INFO(PassNode, RelayNode);
}; };
class Pass : public NodeRef { class Pass : public ObjectRef {
public: public:
/*! /*!
* \brief Transform mod using the default PassContext in the current scope. * \brief Transform mod using the default PassContext in the current scope.
...@@ -283,7 +286,7 @@ class Pass : public NodeRef { ...@@ -283,7 +286,7 @@ class Pass : public NodeRef {
return node->operator()(mod, pass_ctx); return node->operator()(mod, pass_ctx);
} }
TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode); TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
}; };
class SequentialNode; class SequentialNode;
...@@ -309,7 +312,7 @@ class Sequential : public Pass { ...@@ -309,7 +312,7 @@ class Sequential : public Pass {
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential"); TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
Sequential() = default; Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
const SequentialNode* operator->() const; const SequentialNode* operator->() const;
using ContainerType = Sequential; using ContainerType = Sequential;
...@@ -638,7 +641,7 @@ TVM_DLL Function InferType(const Function& f, ...@@ -638,7 +641,7 @@ TVM_DLL Function InferType(const Function& f,
*/ */
TVM_DLL Expr ForwardRewrite(const Expr& expr, TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name, const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*! /*!
...@@ -655,7 +658,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, ...@@ -655,7 +658,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr,
*/ */
TVM_DLL Expr ForwardRewrite(const Expr& expr, TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func, const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*! /*!
......
...@@ -41,7 +41,7 @@ using Any = tvm::ir::Any; ...@@ -41,7 +41,7 @@ using Any = tvm::ir::Any;
class TypeNode : public RelayNode { class TypeNode : public RelayNode {
public: public:
static constexpr const char* _type_key = "relay.Type"; static constexpr const char* _type_key = "relay.Type";
TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
}; };
/*! /*!
...@@ -55,10 +55,10 @@ class TypeNode : public RelayNode { ...@@ -55,10 +55,10 @@ class TypeNode : public RelayNode {
* There are also advanced types to support generic(polymorphic types), * There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base. * which can be ignored when first reading the code base.
*/ */
class Type : public NodeRef { class Type : public ObjectRef {
public: public:
Type() {} Type() {}
explicit Type(ObjectPtr<tvm::Object> p) : NodeRef(p) {} explicit Type(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
using ContainerType = TypeNode; using ContainerType = TypeNode;
}; };
...@@ -70,10 +70,13 @@ class Type : public NodeRef { ...@@ -70,10 +70,13 @@ class Type : public NodeRef {
class BaseTensorTypeNode : public TypeNode { class BaseTensorTypeNode : public TypeNode {
public: public:
static constexpr const char* _type_key = "relay.BaseTensorType"; static constexpr const char* _type_key = "relay.BaseTensorType";
TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode); TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type); class BaseTensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};
/*! /*!
* \brief This is the most commonly used type in relay. * \brief This is the most commonly used type in relay.
...@@ -113,10 +116,13 @@ class TensorTypeNode : public BaseTensorTypeNode { ...@@ -113,10 +116,13 @@ class TensorTypeNode : public BaseTensorTypeNode {
TVM_DLL static TensorType Scalar(DataType dtype); TVM_DLL static TensorType Scalar(DataType dtype);
static constexpr const char* _type_key = "relay.TensorType"; static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
}; };
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); class TensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};
/*! \brief Possible kinds of Type. */ /*! \brief Possible kinds of Type. */
enum Kind : int { enum Kind : int {
...@@ -168,10 +174,13 @@ class TypeVarNode : public TypeNode { ...@@ -168,10 +174,13 @@ class TypeVarNode : public TypeNode {
TVM_DLL static TypeVar make(std::string name, Kind kind); TVM_DLL static TypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.TypeVar"; static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); class TypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
/*! /*!
* \brief A global type variable that is used for defining new types or type aliases. * \brief A global type variable that is used for defining new types or type aliases.
...@@ -197,10 +206,13 @@ class GlobalTypeVarNode : public TypeNode { ...@@ -197,10 +206,13 @@ class GlobalTypeVarNode : public TypeNode {
TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar"; static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type); class GlobalTypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
/*! /*!
* \brief Type application. * \brief Type application.
...@@ -225,10 +237,13 @@ class TypeCallNode : public TypeNode { ...@@ -225,10 +237,13 @@ class TypeCallNode : public TypeNode {
TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args); TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);
static constexpr const char* _type_key = "relay.TypeCall"; static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); class TypeCall : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
};
/*! /*!
* \brief IncompleteType. * \brief IncompleteType.
...@@ -253,10 +268,13 @@ class IncompleteTypeNode : public TypeNode { ...@@ -253,10 +268,13 @@ class IncompleteTypeNode : public TypeNode {
TVM_DLL static IncompleteType make(Kind kind); TVM_DLL static IncompleteType make(Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType"; static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); class IncompleteType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
/*! /*!
* \brief Potential Constraints in the type. * \brief Potential Constraints in the type.
...@@ -267,10 +285,13 @@ class TypeConstraint; ...@@ -267,10 +285,13 @@ class TypeConstraint;
class TypeConstraintNode : public TypeNode { class TypeConstraintNode : public TypeNode {
public: public:
static constexpr const char* _type_key = "relay.TypeConstraint"; static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode); TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type); class TypeConstraint : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
class FuncType; class FuncType;
/*! /*!
...@@ -311,10 +332,13 @@ class FuncTypeNode : public TypeNode { ...@@ -311,10 +332,13 @@ class FuncTypeNode : public TypeNode {
tvm::Array<TypeConstraint> type_constraints); tvm::Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType"; static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); class FuncType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
/*! /*!
* \brief The type of tuple values. * \brief The type of tuple values.
...@@ -338,10 +362,13 @@ class TupleTypeNode : public TypeNode { ...@@ -338,10 +362,13 @@ class TupleTypeNode : public TypeNode {
TVM_DLL static TupleType make(tvm::Array<Type> fields); TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TupleType"; static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); class TupleType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*! /*!
* \brief The type of reference values. * \brief The type of reference values.
...@@ -365,10 +392,13 @@ class RefTypeNode : public TypeNode { ...@@ -365,10 +392,13 @@ class RefTypeNode : public TypeNode {
TVM_DLL static RefType make(Type value); TVM_DLL static RefType make(Type value);
static constexpr const char* _type_key = "relay.RefType"; static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode);
}; };
RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type); class RefType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
};
class TypeReporter; class TypeReporter;
...@@ -376,7 +406,7 @@ class TypeReporter; ...@@ -376,7 +406,7 @@ class TypeReporter;
* \brief reporter that reports back to the * \brief reporter that reports back to the
* type resolution information. * type resolution information.
*/ */
class TypeReporterNode : public Node { class TypeReporterNode : public Object {
public: public:
/*! /*!
* \brief Create a type equality constraint. * \brief Create a type equality constraint.
...@@ -408,7 +438,7 @@ class TypeReporterNode : public Node { ...@@ -408,7 +438,7 @@ class TypeReporterNode : public Node {
* \brief Set the location at which to report unification errors. * \brief Set the location at which to report unification errors.
* \param ref The program node to report the error. * \param ref The program node to report the error.
*/ */
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0;
/*! /*!
* \brief Retrieve the current global module. * \brief Retrieve the current global module.
...@@ -420,17 +450,17 @@ class TypeReporterNode : public Node { ...@@ -420,17 +450,17 @@ class TypeReporterNode : public Node {
void VisitAttrs(tvm::AttrVisitor* v) {} void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter"; static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
}; };
/*! /*!
* \brief Container class of TypeReporter. * \brief Container class of TypeReporter.
* \sa TypeReporterNode * \sa TypeReporterNode
*/ */
class TypeReporter : public NodeRef { class TypeReporter : public ObjectRef {
public: public:
TypeReporter() {} TypeReporter() {}
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
} }
TypeReporterNode* operator->() const { TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>( return const_cast<TypeReporterNode*>(
...@@ -502,10 +532,13 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -502,10 +532,13 @@ class TypeRelationNode : public TypeConstraintNode {
Attrs attrs); Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation"; static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
}; };
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); class TypeRelation : public TypeConstraint {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
};
// The following fields contains advanced typing // The following fields contains advanced typing
// Only keep the class name and reserved for future usage. // Only keep the class name and reserved for future usage.
......
...@@ -700,7 +700,12 @@ struct ObjectEqual { ...@@ -700,7 +700,12 @@ struct ObjectEqual {
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex() TypeName::_GetOrAllocRuntimeTypeIndex()
/*
* \brief Define object reference methods.
* \param TypeName The object type name
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \ TypeName() {} \
explicit TypeName( \ explicit TypeName( \
...@@ -712,17 +717,54 @@ struct ObjectEqual { ...@@ -712,17 +717,54 @@ struct ObjectEqual {
operator bool() const { return data_ != nullptr; } \ operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ /*
* \brief Define object reference methods of whose content is mutable.
* \param TypeName The object type name
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
* \note We recommend making objects immutable when possible.
* This macro is only reserved for objects that stores runtime states.
*/
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \ TypeName() {} \
explicit TypeName( \ explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \ : ParentType(n) {} \
ObjectName* operator->() { \ ObjectName* operator->() const { \
return static_cast<ObjectName*>(data_.get()); \ return static_cast<ObjectName*>(data_.get()); \
} \ } \
operator bool() const { return data_ != nullptr; } \ operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
/*!
* \brief Define CopyOnWrite function in an ObjectRef.
* \param ObjectName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWObjectRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
ObjectName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
auto n = make_object<ObjectName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<ObjectName*>(data_.get()); \
}
// Implementations details below // Implementations details below
// Object reference counting. // Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER #if TVM_OBJECT_ATOMIC_REF_COUNTER
...@@ -832,10 +874,6 @@ inline SubRef Downcast(BaseRef ref) { ...@@ -832,10 +874,6 @@ inline SubRef Downcast(BaseRef ref) {
} }
} // namespace runtime } // namespace runtime
template<typename T>
using NodePtr = runtime::ObjectPtr<T>;
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_ #endif // TVM_RUNTIME_OBJECT_H_
...@@ -53,10 +53,10 @@ enum AttachType : int { ...@@ -53,10 +53,10 @@ enum AttachType : int {
}; };
/*! \brief Stage, contains scheduling for a stage of computation. */ /*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef { class Stage : public ObjectRef {
public: public:
Stage() {} Stage() {}
explicit Stage(ObjectPtr<Object> n) : NodeRef(n) {} explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief create a new schedule for op. * \brief create a new schedule for op.
* \param op The operator in the schedule * \param op The operator in the schedule
...@@ -277,10 +277,10 @@ class Stage : public NodeRef { ...@@ -277,10 +277,10 @@ class Stage : public NodeRef {
* For operations and all the operations they depend on. * For operations and all the operations they depend on.
* The schedule per Operation is named as stage. * The schedule per Operation is named as stage.
*/ */
class Schedule : public NodeRef { class Schedule : public ObjectRef {
public: public:
Schedule() {} Schedule() {}
explicit Schedule(ObjectPtr<Object> n) : NodeRef(n) {} explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief Get a copy of current schedule. * \brief Get a copy of current schedule.
* \return The copied schedule. * \return The copied schedule.
...@@ -400,10 +400,10 @@ class Schedule : public NodeRef { ...@@ -400,10 +400,10 @@ class Schedule : public NodeRef {
* \brief The schedule relation between IterVars * \brief The schedule relation between IterVars
* can be Split, Fuse. * can be Split, Fuse.
*/ */
class IterVarRelation : public NodeRef { class IterVarRelation : public ObjectRef {
public: public:
IterVarRelation() {} IterVarRelation() {}
explicit IterVarRelation(ObjectPtr<Object> n) : NodeRef(n) {} explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -414,10 +414,10 @@ class IterVarRelation : public NodeRef { ...@@ -414,10 +414,10 @@ class IterVarRelation : public NodeRef {
/*! /*!
* \brief Additional scheduable attributes about IterVar. * \brief Additional scheduable attributes about IterVar.
*/ */
class IterVarAttr : public NodeRef { class IterVarAttr : public ObjectRef {
public: public:
IterVarAttr() {} IterVarAttr() {}
explicit IterVarAttr(ObjectPtr<Object> n) : NodeRef(n) {} explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -440,7 +440,7 @@ class IterVarAttr : public NodeRef { ...@@ -440,7 +440,7 @@ class IterVarAttr : public NodeRef {
* *
* The group stage node can be attached to IterVars as in normal stage. * The group stage node can be attached to IterVars as in normal stage.
*/ */
class StageNode : public Node { class StageNode : public Object {
public: public:
/*! /*!
* \brief The operation of stage, can be different from original op. * \brief The operation of stage, can be different from original op.
...@@ -515,11 +515,11 @@ class StageNode : public Node { ...@@ -515,11 +515,11 @@ class StageNode : public Node {
} }
static constexpr const char* _type_key = "Stage"; static constexpr const char* _type_key = "Stage";
TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
}; };
/*! \brief node container for schedule */ /*! \brief node container for schedule */
class ScheduleNode : public Node { class ScheduleNode : public Object {
public: public:
/*! \brief The output operations in original data flow graph */ /*! \brief The output operations in original data flow graph */
Array<Operation> outputs; Array<Operation> outputs;
...@@ -538,7 +538,7 @@ class ScheduleNode : public Node { ...@@ -538,7 +538,7 @@ class ScheduleNode : public Node {
* \brief Internal stage map to map internal ops to stages. * \brief Internal stage map to map internal ops to stages.
* This is created on demand and can be invalidated. * This is created on demand and can be invalidated.
*/ */
std::unordered_map<const Node*, Stage> op2stage_cache_; std::unordered_map<const Object*, Stage> op2stage_cache_;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("outputs", &outputs); v->Visit("outputs", &outputs);
...@@ -576,7 +576,7 @@ class ScheduleNode : public Node { ...@@ -576,7 +576,7 @@ class ScheduleNode : public Node {
TVM_DLL static Schedule make(Array<Operation> ops); TVM_DLL static Schedule make(Array<Operation> ops);
static constexpr const char* _type_key = "Schedule"; static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
}; };
/*! /*!
...@@ -589,7 +589,7 @@ inline Schedule create_schedule(Array<Operation> ops) { ...@@ -589,7 +589,7 @@ inline Schedule create_schedule(Array<Operation> ops) {
} }
/*! \brief node container for IterVar attr */ /*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node { class IterVarAttrNode : public Object {
public: public:
/*! \brief The iteration type. */ /*! \brief The iteration type. */
IterVarType iter_type{kDataPar}; IterVarType iter_type{kDataPar};
...@@ -630,14 +630,14 @@ class IterVarAttrNode : public Node { ...@@ -630,14 +630,14 @@ class IterVarAttrNode : public Node {
} }
static constexpr const char* _type_key = "IterVarAttr"; static constexpr const char* _type_key = "IterVarAttr";
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
}; };
/*! \brief base node of iteration var */ /*! \brief base node of iteration var */
class IterVarRelationNode : public Node { class IterVarRelationNode : public Object {
public: public:
static constexpr const char* _type_key = "IterVarRelation"; static constexpr const char* _type_key = "IterVarRelation";
TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node); TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
}; };
/*! /*!
...@@ -672,7 +672,7 @@ class SplitNode : public IterVarRelationNode { ...@@ -672,7 +672,7 @@ class SplitNode : public IterVarRelationNode {
Expr nparts); Expr nparts);
static constexpr const char* _type_key = "Split"; static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode); TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
}; };
/*! /*!
...@@ -697,7 +697,7 @@ class FuseNode : public IterVarRelationNode { ...@@ -697,7 +697,7 @@ class FuseNode : public IterVarRelationNode {
IterVar outer, IterVar inner, IterVar fused); IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse"; static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode); TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
}; };
/*! /*!
...@@ -720,7 +720,7 @@ class RebaseNode : public IterVarRelationNode { ...@@ -720,7 +720,7 @@ class RebaseNode : public IterVarRelationNode {
static IterVarRelation make(IterVar parent, IterVar rebased); static IterVarRelation make(IterVar parent, IterVar rebased);
static constexpr const char* _type_key = "Rebase"; static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode); TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
}; };
...@@ -739,7 +739,7 @@ class SingletonNode : public IterVarRelationNode { ...@@ -739,7 +739,7 @@ class SingletonNode : public IterVarRelationNode {
static IterVarRelation make(IterVar iter); static IterVarRelation make(IterVar iter);
static constexpr const char* _type_key = "Singleton"; static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode); TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
}; };
......
...@@ -34,7 +34,7 @@ namespace tvm { ...@@ -34,7 +34,7 @@ namespace tvm {
* \brief Memory information of special memory region. * \brief Memory information of special memory region.
* Use MemoryInfo as its container type * Use MemoryInfo as its container type
*/ */
struct MemoryInfoNode : public Node { struct MemoryInfoNode : public Object {
/*! \brief The addressable unit */ /*! \brief The addressable unit */
int unit_bits; int unit_bits;
/*! \brief Maximum number of bits supported in the memory */ /*! \brief Maximum number of bits supported in the memory */
...@@ -55,11 +55,14 @@ struct MemoryInfoNode : public Node { ...@@ -55,11 +55,14 @@ struct MemoryInfoNode : public Node {
} }
static constexpr const char* _type_key = "MemoryInfo"; static constexpr const char* _type_key = "MemoryInfo";
TVM_DECLARE_NODE_TYPE_INFO(MemoryInfoNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object);
}; };
/*! \brief Defines memory info */ /*! \brief Defines memory info */
TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode); class MemoryInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(MemoryInfo, ObjectRef, MemoryInfoNode);
};
/*! /*!
* \brief get memory info given scope * \brief get memory info given scope
......
...@@ -46,11 +46,11 @@ class OperationNode; ...@@ -46,11 +46,11 @@ class OperationNode;
* \brief Tensor structure representing a possible input, * \brief Tensor structure representing a possible input,
* or intermediate computation result. * or intermediate computation result.
*/ */
class Tensor : public NodeRef { class Tensor : public ObjectRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
explicit Tensor(ObjectPtr<Object> n) : NodeRef(n) {} explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -158,7 +158,7 @@ class Operation : public ir::FunctionRef { ...@@ -158,7 +158,7 @@ class Operation : public ir::FunctionRef {
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public Node { class TensorNode : public Object {
public: public:
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
...@@ -183,7 +183,7 @@ class TensorNode : public Node { ...@@ -183,7 +183,7 @@ class TensorNode : public Node {
int value_index); int value_index);
static constexpr const char* _type_key = "Tensor"; static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
}; };
...@@ -250,13 +250,13 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) ...@@ -250,13 +250,13 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
namespace std { namespace std {
template <> template <>
struct hash<::tvm::Operation> : public ::tvm::NodeHash { struct hash<::tvm::Operation> : public ::tvm::ObjectHash {
}; };
template <> template <>
struct hash<::tvm::Tensor> { struct hash<::tvm::Tensor> {
std::size_t operator()(const ::tvm::Tensor& k) const { std::size_t operator()(const ::tvm::Tensor& k) const {
::tvm::NodeHash hasher; ::tvm::ObjectHash hasher;
if (k.defined() && k->op.defined()) { if (k.defined() && k->op.defined()) {
return hasher(k->op); return hasher(k->op);
} else{ } else{
......
...@@ -34,10 +34,10 @@ namespace tvm { ...@@ -34,10 +34,10 @@ namespace tvm {
class TensorIntrinNode; class TensorIntrinNode;
/*! \brief Tensor intrinsic node. */ /*! \brief Tensor intrinsic node. */
class TensorIntrin : public NodeRef { class TensorIntrin : public ObjectRef {
public: public:
TensorIntrin() {} TensorIntrin() {}
explicit TensorIntrin(NodePtr<Node> n) : NodeRef(n) {} explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -49,7 +49,7 @@ class TensorIntrin : public NodeRef { ...@@ -49,7 +49,7 @@ class TensorIntrin : public NodeRef {
}; };
/*! \brief Node to represent a Tensor intrinsic operator */ /*! \brief Node to represent a Tensor intrinsic operator */
class TensorIntrinNode : public Node { class TensorIntrinNode : public Object {
public: public:
/*! \brief The name of the intrinsic */ /*! \brief The name of the intrinsic */
std::string name; std::string name;
...@@ -108,7 +108,7 @@ class TensorIntrinNode : public Node { ...@@ -108,7 +108,7 @@ class TensorIntrinNode : public Node {
Stmt reduce_update); Stmt reduce_update);
static constexpr const char* _type_key = "TensorIntrin"; static constexpr const char* _type_key = "TensorIntrin";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
}; };
inline const TensorIntrinNode* TensorIntrin::operator->() const { inline const TensorIntrinNode* TensorIntrin::operator->() const {
...@@ -119,10 +119,10 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const { ...@@ -119,10 +119,10 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const {
class TensorIntrinCallNode; class TensorIntrinCallNode;
/*! \brief Tensor intrinsic calling node. */ /*! \brief Tensor intrinsic calling node. */
class TensorIntrinCall : public NodeRef { class TensorIntrinCall : public ObjectRef {
public: public:
TensorIntrinCall() {} TensorIntrinCall() {}
explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {} explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -133,7 +133,7 @@ class TensorIntrinCall : public NodeRef { ...@@ -133,7 +133,7 @@ class TensorIntrinCall : public NodeRef {
using ContainerType = TensorIntrinCallNode; using ContainerType = TensorIntrinCallNode;
}; };
class TensorIntrinCallNode : public Node { class TensorIntrinCallNode : public Object {
public: public:
/*! \brief the tensor intrinsic */ /*! \brief the tensor intrinsic */
TensorIntrin intrin; TensorIntrin intrin;
...@@ -166,7 +166,7 @@ class TensorIntrinCallNode : public Node { ...@@ -166,7 +166,7 @@ class TensorIntrinCallNode : public Node {
Array<Expr> scalar_inputs); Array<Expr> scalar_inputs);
static constexpr const char* _type_key = "TensorIntrinCall"; static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
}; };
inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -306,7 +306,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads, ...@@ -306,7 +306,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
template<typename FVisit> template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) { FVisit fvisit) {
typedef const NodePtr* GNode; typedef const ObjectPtr* GNode;
std::vector<GNode> head_nodes(heads.size()); std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(), std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode { [](const NodeEntry& e)->GNode {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -40,22 +40,22 @@ class Node; ...@@ -40,22 +40,22 @@ class Node;
class Symbol; class Symbol;
/*! /*!
* \brief we always used NodePtr for a reference pointer * \brief we always used ObjectPtr for a reference pointer
* to the node, so this alias can be changed in case. * to the node, so this alias can be changed in case.
* *
* By default, NodePtr is a std::shared_ptr of node * By default, ObjectPtr is a std::shared_ptr of node
*/ */
using NodePtr = std::shared_ptr<Node>; using ObjectPtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */ /*! \brief an entry that represents output data from a node */
struct NodeEntry { struct NodeEntry {
NodeEntry(NodePtr node, uint32_t index, uint32_t version): NodeEntry(ObjectPtr node, uint32_t index, uint32_t version):
node(std::move(node)), node(std::move(node)),
index(index), index(index),
version(version) version(version)
{} {}
explicit NodeEntry(NodePtr node): explicit NodeEntry(ObjectPtr node):
node(std::move(node)), node(std::move(node)),
index(), index(),
version() version()
...@@ -72,7 +72,7 @@ struct NodeEntry { ...@@ -72,7 +72,7 @@ struct NodeEntry {
{} {}
/*! \brief the source node of this data */ /*! \brief the source node of this data */
NodePtr node; ObjectPtr node;
/*! \brief index of output from the source. */ /*! \brief index of output from the source. */
uint32_t index; uint32_t index;
/*! /*!
...@@ -167,7 +167,7 @@ class NNVM_DLL Node { ...@@ -167,7 +167,7 @@ class NNVM_DLL Node {
* \brief Optional control flow dependencies * \brief Optional control flow dependencies
* Gives operation must be performed before this operation. * Gives operation must be performed before this operation.
*/ */
std::vector<NodePtr> control_deps; std::vector<ObjectPtr> control_deps;
/*! \brief additional fields for this node */ /*! \brief additional fields for this node */
any info; any info;
/*! \brief destructor of node */ /*! \brief destructor of node */
...@@ -189,7 +189,7 @@ class NNVM_DLL Node { ...@@ -189,7 +189,7 @@ class NNVM_DLL Node {
* \return a created empty node. * \return a created empty node.
*/ */
template<class ...Args> template<class ...Args>
static NodePtr Create(Args&&... args) { static ObjectPtr Create(Args&&... args) {
return std::make_shared<Node>(std::forward<Args>(args)...); return std::make_shared<Node>(std::forward<Args>(args)...);
} }
}; };
...@@ -208,7 +208,7 @@ inline NodeEntry MakeNode( ...@@ -208,7 +208,7 @@ inline NodeEntry MakeNode(
std::vector<NodeEntry> inputs, std::vector<NodeEntry> inputs,
std::unordered_map<std::string, std::string> attrs = std::unordered_map<std::string, std::string> attrs =
std::unordered_map<std::string, std::string>()) { std::unordered_map<std::string, std::string>()) {
NodePtr p = Node::Create(); ObjectPtr p = Node::Create();
p->attrs.op = nnvm::Op::Get(op_name); p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name); p->attrs.name = std::move(node_name);
p->attrs.dict = attrs; p->attrs.dict = attrs;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -192,7 +192,7 @@ using FIgnoreInputs = std::function< ...@@ -192,7 +192,7 @@ using FIgnoreInputs = std::function<
* \note Register under "FGradient" * \note Register under "FGradient"
*/ */
using FGradient = std::function<std::vector<NodeEntry>( using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr, const ObjectPtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>; const std::vector<NodeEntry>& out_grads)>;
/*! /*!
...@@ -204,7 +204,7 @@ using FGradient = std::function<std::vector<NodeEntry>( ...@@ -204,7 +204,7 @@ using FGradient = std::function<std::vector<NodeEntry>(
*/ */
using FSetInputVarAttrOnCompose = std::function<void( using FSetInputVarAttrOnCompose = std::function<void(
const NodeAttrs& attrs, const NodeAttrs& attrs,
NodePtr var, ObjectPtr var,
const int index)>; const int index)>;
/*! /*!
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -97,7 +97,7 @@ class NNVM_DLL Symbol { ...@@ -97,7 +97,7 @@ class NNVM_DLL Symbol {
* \return The arguments list of this symbol, they can be either named or unnamed (empty string). * \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption * \sa ListInputOption
*/ */
std::vector<NodePtr> ListInputs(ListInputOption option) const; std::vector<ObjectPtr> ListInputs(ListInputOption option) const;
/*! /*!
* \brief List the input names. * \brief List the input names.
* *
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -259,7 +259,7 @@ int NNSymbolListInputVariables(SymbolHandle symbol, ...@@ -259,7 +259,7 @@ int NNSymbolListInputVariables(SymbolHandle symbol,
Symbol *s = static_cast<Symbol*>(symbol); Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option)); std::vector<ObjectPtr> vs = s->ListInputs(Symbol::ListInputOption(option));
ret->ret_handles.resize(0); ret->ret_handles.resize(0);
ret->ret_handles.reserve(vs.size()); ret->ret_handles.reserve(vs.size());
for (size_t i = 0; i < vs.size(); ++i) { for (size_t i = 0; i < vs.size(); ++i) {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -50,7 +50,7 @@ static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subg ...@@ -50,7 +50,7 @@ static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subg
next_level.clear(); next_level.clear();
for (const std::vector<NodeEntry> *graph_ptr : curr_level) { for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
const std::vector<NodeEntry> &graph = *graph_ptr; const std::vector<NodeEntry> &graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) { DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) {
nnvm::Node *node = n.get(); nnvm::Node *node = n.get();
// if the node is visited, but on a different level, then check failed // if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error // if check failed here or before, we stop doing anything, but raise an error
...@@ -74,7 +74,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -74,7 +74,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
std::vector<std::shared_ptr<Symbol>> subgraphs; std::vector<std::shared_ptr<Symbol>> subgraphs;
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) { (const ObjectPtr& n) {
const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost"); const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
if (!n->is_variable() && is_ghost.get(n->op(), false)) return; if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max()); CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -30,7 +30,7 @@ Node::~Node() { ...@@ -30,7 +30,7 @@ Node::~Node() {
// explicit deletion via DFS // explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions // this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this}; std::vector<Node*> stack{this};
std::vector<NodePtr> to_delete; std::vector<ObjectPtr> to_delete;
while (!stack.empty()) { while (!stack.empty()) {
Node* n = stack.back(); Node* n = stack.back();
stack.pop_back(); stack.pop_back();
...@@ -42,7 +42,7 @@ Node::~Node() { ...@@ -42,7 +42,7 @@ Node::~Node() {
e.node.reset(); e.node.reset();
} }
} }
for (NodePtr& sp : n->control_deps) { for (ObjectPtr& sp : n->control_deps) {
if (sp.unique()) { if (sp.unique()) {
stack.push_back(sp.get()); stack.push_back(sp.get());
to_delete.emplace_back(std::move(sp)); to_delete.emplace_back(std::move(sp));
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -36,8 +36,8 @@ struct VariableParam { ...@@ -36,8 +36,8 @@ struct VariableParam {
uint32_t version{0}; uint32_t version{0};
}; };
NodePtr CreateVariableNode(const std::string& name) { ObjectPtr CreateVariableNode(const std::string& name) {
NodePtr n = Node::Create(); ObjectPtr n = Node::Create();
n->attrs.op = nullptr; n->attrs.op = nullptr;
n->attrs.name = name; n->attrs.name = name;
n->attrs.parsed = VariableParam(); n->attrs.parsed = VariableParam();
...@@ -114,10 +114,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) { ...@@ -114,10 +114,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
// public functions // public functions
Symbol Symbol::Copy() const { Symbol Symbol::Copy() const {
std::unordered_map<Node*, NodePtr> old_new; std::unordered_map<Node*, ObjectPtr> old_new;
// use DFSVisit to copy all the nodes // use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const NodePtr& node) { DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) {
NodePtr np = Node::Create(); ObjectPtr np = Node::Create();
np->attrs = node->attrs; np->attrs = node->attrs;
old_new[node.get()] = std::move(np); old_new[node.get()] = std::move(np);
}); });
...@@ -127,7 +127,7 @@ Symbol Symbol::Copy() const { ...@@ -127,7 +127,7 @@ Symbol Symbol::Copy() const {
Node *ptr = e.node.get(); Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
} }
for (const NodePtr& p : kv.first->control_deps) { for (const ObjectPtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(old_new[p.get()]); kv.second->control_deps.emplace_back(old_new[p.get()]);
} }
} }
...@@ -155,7 +155,7 @@ void Symbol::Print(std::ostream &os) const { ...@@ -155,7 +155,7 @@ void Symbol::Print(std::ostream &os) const {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n"; << '(' << outputs[i].index << ")\n";
} }
DFSVisit(this->outputs, [&os](const NodePtr& node) { DFSVisit(this->outputs, [&os](const ObjectPtr& node) {
if (node->is_variable()) { if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n'; os << "Variable:" << node->attrs.name << '\n';
} else { } else {
...@@ -204,21 +204,21 @@ Symbol Symbol::operator[] (size_t index) const { ...@@ -204,21 +204,21 @@ Symbol Symbol::operator[] (size_t index) const {
} }
} }
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const { std::vector<ObjectPtr> Symbol::ListInputs(ListInputOption option) const {
std::vector<NodePtr> ret; std::vector<ObjectPtr> ret;
if (option == kAll) { if (option == kAll) {
ret.reserve(this->outputs.size()); ret.reserve(this->outputs.size());
DFSVisit(this->outputs, [&ret](const NodePtr &node) { DFSVisit(this->outputs, [&ret](const ObjectPtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
ret.push_back(node); ret.push_back(node);
} }
}); });
} else { } else {
std::unordered_set<Node*> mutable_set; std::unordered_set<Node*> mutable_set;
std::vector<NodePtr> vlist; std::vector<ObjectPtr> vlist;
vlist.reserve(this->outputs.size()); vlist.reserve(this->outputs.size());
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
DFSVisit(this->outputs, [&mutable_set, &vlist](const NodePtr &node) { DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
vlist.push_back(node); vlist.push_back(node);
} else if (fmutate_inputs.count(node->op())) { } else if (fmutate_inputs.count(node->op())) {
...@@ -228,7 +228,7 @@ std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const { ...@@ -228,7 +228,7 @@ std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
} }
}); });
ret.reserve(vlist.size()); ret.reserve(vlist.size());
for (const NodePtr& node : vlist) { for (const ObjectPtr& node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
(option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) { (option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
ret.emplace_back(node); ret.emplace_back(node);
...@@ -239,7 +239,7 @@ std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const { ...@@ -239,7 +239,7 @@ std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
} }
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const { std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<NodePtr> inputs = ListInputs(option); std::vector<ObjectPtr> inputs = ListInputs(option);
std::vector<std::string> ret(inputs.size()); std::vector<std::string> ret(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
ret[i] = inputs[i]->attrs.name; ret[i] = inputs[i]->attrs.name;
...@@ -416,7 +416,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -416,7 +416,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::unordered_map<Node *, const NodeEntry*> replace_map; std::unordered_map<Node *, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node // replace map stores the existing replacement plan for arguments node
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
(const NodePtr &node) { (const ObjectPtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
if (arg_counter < args.size()) { if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]); replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
...@@ -437,7 +437,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -437,7 +437,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::vector<Node*> update_nodes; std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan; std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const NodePtr &node) { (const ObjectPtr &node) {
// visit all the childs, find possible replacement // visit all the childs, find possible replacement
bool repl = false; bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) { for (size_t i = 0; i < node->inputs.size(); ++i) {
...@@ -499,7 +499,7 @@ void Symbol::AddControlDeps(const Symbol& src) { ...@@ -499,7 +499,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
Symbol Symbol::GetInternals() const { Symbol Symbol::GetInternals() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs"); static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret; Symbol ret;
DFSVisit(this->outputs, [&ret](const NodePtr& node) { DFSVisit(this->outputs, [&ret](const ObjectPtr& node) {
Node* n = node.get(); Node* n = node.get();
if (n->is_variable()) { if (n->is_variable()) {
// grab version from variable. // grab version from variable.
...@@ -582,7 +582,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const { ...@@ -582,7 +582,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const {
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const { std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
if (option == kRecursive) { if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret; std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const NodePtr& n) { DFSVisit(this->outputs, [&ret](const ObjectPtr& n) {
for (const auto& it : n->attrs.dict) { for (const auto& it : n->attrs.dict) {
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
} }
...@@ -596,7 +596,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op ...@@ -596,7 +596,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
std::vector<std::tuple<std::string, std::string, std::string> > std::vector<std::tuple<std::string, std::string, std::string> >
Symbol::ListAttrsRecursive() const { Symbol::ListAttrsRecursive() const {
std::vector<std::tuple<std::string, std::string, std::string> > ret; std::vector<std::tuple<std::string, std::string, std::string> > ret;
DFSVisit(this->outputs, [&ret](const NodePtr& n) { DFSVisit(this->outputs, [&ret](const ObjectPtr& n) {
for (const auto& it : n->attrs.dict) { for (const auto& it : n->attrs.dict) {
ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second));
} }
...@@ -608,7 +608,7 @@ Symbol Symbol::CreateFunctor(const Op* op, ...@@ -608,7 +608,7 @@ Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) { std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs"); static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s; Symbol s;
NodePtr n = Node::Create(); ObjectPtr n = Node::Create();
n->attrs.op = op; n->attrs.op = op;
n->attrs.dict = std::move(attrs); n->attrs.dict = std::move(attrs);
if (n->op()->attr_parser != nullptr) { if (n->op()->attr_parser != nullptr) {
...@@ -628,7 +628,7 @@ Symbol Symbol::CreateFunctor(const Op* op, ...@@ -628,7 +628,7 @@ Symbol Symbol::CreateFunctor(const Op* op,
Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs"); static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s; Symbol s;
NodePtr n = Node::Create(); ObjectPtr n = Node::Create();
n->attrs = attrs; n->attrs = attrs;
uint32_t nout = n->num_outputs(); uint32_t nout = n->num_outputs();
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -30,11 +30,11 @@ ...@@ -30,11 +30,11 @@
namespace nnvm { namespace nnvm {
namespace pass { namespace pass {
nnvm::NodePtr CreateLayoutTransformNode(const Layout& src, nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src,
const Layout& dst) { const Layout& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__");
static int count = 0; static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create(); nnvm::ObjectPtr n = nnvm::Node::Create();
n->attrs.op = trans_op; n->attrs.op = trans_op;
n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++); n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++);
n->attrs.dict["src_layout"] = src.name(); n->attrs.dict["src_layout"] = src.name();
...@@ -54,14 +54,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { ...@@ -54,14 +54,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout"); nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr); std::vector<nnvm::ObjectPtr> mirror_vec(idx.num_nodes(), nullptr);
// (new) NodePtr -> output_layouts // (new) ObjectPtr -> output_layouts
LayoutAttrDict new_layouts; LayoutAttrDict new_layouts;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create(); nnvm::ObjectPtr new_node = nnvm::Node::Create();
*new_node = *(inode.source); *new_node = *(inode.source);
if (new_node->is_variable()) { if (new_node->is_variable()) {
// Variable node. No operator. Only one output entry. // Variable node. No operator. Only one output entry.
...@@ -85,7 +85,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { ...@@ -85,7 +85,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef()); std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef());
for (size_t i = 0; i < num_inputs; ++i) { for (size_t i = 0; i < num_inputs; ++i) {
const IndexedGraph::NodeEntry& input_entry = inode.inputs[i]; const IndexedGraph::NodeEntry& input_entry = inode.inputs[i];
const NodePtr& new_input_node = mirror_vec[input_entry.node_id]; const ObjectPtr& new_input_node = mirror_vec[input_entry.node_id];
CHECK(new_input_node != nullptr); CHECK(new_input_node != nullptr);
// fill inputs by previous node (DFS order) inferred layouts. // fill inputs by previous node (DFS order) inferred layouts.
...@@ -122,14 +122,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { ...@@ -122,14 +122,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
for (uint32_t i = 0; i < inode.inputs.size(); ++i) { for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i]; const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id]; const nnvm::ObjectPtr& in = mirror_vec[e.node_id];
new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version}; new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version};
// insert layout_transform if necessary // insert layout_transform if necessary
const Layout& produce = produce_ilayouts[i]; const Layout& produce = produce_ilayouts[i];
const Layout& request = request_ilayouts[i]; const Layout& request = request_ilayouts[i];
if (produce != request && produce.defined()) { if (produce != request && produce.defined()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request); nnvm::ObjectPtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name(); tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
tnode->inputs.emplace_back(new_node->inputs[i]); tnode->inputs.emplace_back(new_node->inputs[i]);
nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0); nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0);
......
...@@ -37,13 +37,13 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) { ...@@ -37,13 +37,13 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
if (v.size() == 1) { if (v.size() == 1) {
return std::move(v[0]); return std::move(v[0]);
} else if (v.size() == 0) { } else if (v.size() == 0) {
NodePtr zero_node = Node::Create(); ObjectPtr zero_node = Node::Create();
zero_node->attrs.op = Op::Get("zeros"); zero_node->attrs.op = Op::Get("zeros");
zero_node->attrs.name = "zero_grad"; zero_node->attrs.name = "zero_grad";
zero_node->attrs.op->attr_parser(&(zero_node->attrs)); zero_node->attrs.op->attr_parser(&(zero_node->attrs));
return NodeEntry{zero_node, 0, 0}; return NodeEntry{zero_node, 0, 0};
} else { } else {
NodePtr sum_node = Node::Create(); ObjectPtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("elemwise_sum"); sum_node->attrs.op = Op::Get("elemwise_sum");
sum_node->inputs = std::move(v); sum_node->inputs = std::move(v);
sum_node->attrs.name = "grad_sum"; sum_node->attrs.name = "grad_sum";
...@@ -119,10 +119,10 @@ Graph Gradient(Graph src) { ...@@ -119,10 +119,10 @@ Graph Gradient(Graph src) {
nullptr; nullptr;
// topo sort // topo sort
std::vector<NodePtr> topo_order; std::vector<ObjectPtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads; std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const NodePtr& node) { DFSVisit(ys, [&](const ObjectPtr& node) {
if (output_grads.count(node.get()) == 0) { if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs()); output_grads[node.get()].resize(node->num_outputs());
} }
...@@ -143,11 +143,11 @@ Graph Gradient(Graph src) { ...@@ -143,11 +143,11 @@ Graph Gradient(Graph src) {
} }
// construct mirror as memory reduction strategy if needed // construct mirror as memory reduction strategy if needed
std::unordered_map<Node*, NodePtr> mirror_map; std::unordered_map<Node*, ObjectPtr> mirror_map;
if (mirror_fun != nullptr) { if (mirror_fun != nullptr) {
for (const NodePtr& node_ptr : topo_order) { for (const ObjectPtr& node_ptr : topo_order) {
if (mirror_fun(*node_ptr)) { if (mirror_fun(*node_ptr)) {
NodePtr new_node = Node::Create(); ObjectPtr new_node = Node::Create();
*new_node = *node_ptr; *new_node = *node_ptr;
new_node->attrs.name += "_mirror"; new_node->attrs.name += "_mirror";
for (auto& e : new_node->inputs) { for (auto& e : new_node->inputs) {
...@@ -169,7 +169,7 @@ Graph Gradient(Graph src) { ...@@ -169,7 +169,7 @@ Graph Gradient(Graph src) {
std::vector<NodeEntry> out_agg_grads; std::vector<NodeEntry> out_agg_grads;
for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
const NodePtr& ptr = *rit; const ObjectPtr& ptr = *rit;
if (ptr->is_variable()) continue; if (ptr->is_variable()) continue;
out_agg_grads.clear(); out_agg_grads.clear();
auto& out_grad_vec = output_grads.at(ptr.get()); auto& out_grad_vec = output_grads.at(ptr.get());
...@@ -182,7 +182,7 @@ Graph Gradient(Graph src) { ...@@ -182,7 +182,7 @@ Graph Gradient(Graph src) {
out_agg_grads.push_back(e.sum); out_agg_grads.push_back(e.sum);
} }
if ((*rit)->inputs.size() != 0) { if ((*rit)->inputs.size() != 0) {
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); ObjectPtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
std::vector<NodeEntry> input_grads; std::vector<NodeEntry> input_grads;
// Check for FGradient // Check for FGradient
if (grad_fun_map.contains(ptr->op())) { if (grad_fun_map.contains(ptr->op())) {
...@@ -244,7 +244,7 @@ Graph Gradient(Graph src) { ...@@ -244,7 +244,7 @@ Graph Gradient(Graph src) {
if (kv == unique_grads.end()) { if (kv == unique_grads.end()) {
unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter)); unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
} else { } else {
NodePtr copy_node = Node::Create(); ObjectPtr copy_node = Node::Create();
std::ostringstream os; std::ostringstream os;
os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy"; os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
kv->second.first++; kv->second.first++;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -112,7 +112,7 @@ Graph InferAttr(Graph &&ret, ...@@ -112,7 +112,7 @@ Graph InferAttr(Graph &&ret,
CHECK_GE(inode.control_deps.size(), 1U) CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op"; << "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
NodePtr fwd_ptr = inode.source->control_deps[0]; ObjectPtr fwd_ptr = inode.source->control_deps[0];
CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
// use gradient function to find out the correspondence. // use gradient function to find out the correspondence.
std::vector<NodeEntry> ograd(fwd_ptr->num_outputs()); std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -45,7 +45,7 @@ inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) { ...@@ -45,7 +45,7 @@ inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
Graph OrderMutation(const Graph& src) { Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist; std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) { DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) {
for (const NodeEntry& e : n->inputs) { for (const NodeEntry& e : n->inputs) {
if (e.node->is_variable()) { if (e.node->is_variable()) {
if (e.version != 0 && version_hist.count(e.node.get()) == 0) { if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
...@@ -57,8 +57,8 @@ Graph OrderMutation(const Graph& src) { ...@@ -57,8 +57,8 @@ Graph OrderMutation(const Graph& src) {
// no mutation happens, everything if fine. // no mutation happens, everything if fine.
if (version_hist.size() == 0) return src; if (version_hist.size() == 0) return src;
// start preparing for remapping the nodes. // start preparing for remapping the nodes.
std::unordered_map<Node*, NodePtr> old_new; std::unordered_map<Node*, ObjectPtr> old_new;
auto prepare = [&version_hist, &old_new] (const NodePtr& n) { auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs; std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op())) { if (!n->is_variable() && fmutate_inputs.count(n->op())) {
...@@ -80,11 +80,11 @@ Graph OrderMutation(const Graph& src) { ...@@ -80,11 +80,11 @@ Graph OrderMutation(const Graph& src) {
if (old_new.count(e.node.get()) != 0) need_repl = true; if (old_new.count(e.node.get()) != 0) need_repl = true;
} }
} }
for (const NodePtr& p : n->control_deps) { for (const ObjectPtr& p : n->control_deps) {
if (old_new.count(p.get()) != 0) need_repl = true; if (old_new.count(p.get()) != 0) need_repl = true;
} }
if (need_repl) { if (need_repl) {
NodePtr np = Node::Create(); ObjectPtr np = Node::Create();
np->attrs = n->attrs; np->attrs = n->attrs;
old_new[n.get()] = std::move(np); old_new[n.get()] = std::move(np);
} }
...@@ -111,7 +111,7 @@ Graph OrderMutation(const Graph& src) { ...@@ -111,7 +111,7 @@ Graph OrderMutation(const Graph& src) {
kv.second->inputs.push_back(e); kv.second->inputs.push_back(e);
} }
} }
for (const NodePtr& p : kv.first->control_deps) { for (const ObjectPtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back( kv.second->control_deps.emplace_back(
get_with_default(old_new, p.get(), p)); get_with_default(old_new, p.get(), p));
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -105,8 +105,8 @@ Graph PlaceDevice(Graph src) { ...@@ -105,8 +105,8 @@ Graph PlaceDevice(Graph src) {
src.attrs["device"] = std::make_shared<any>(std::move(device)); src.attrs["device"] = std::make_shared<any>(std::move(device));
return src; return src;
} }
std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map; std::map<std::tuple<uint32_t, uint32_t, int>, ObjectPtr> copy_map;
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr); std::vector<ObjectPtr> new_node_map(idx.num_nodes(), nullptr);
std::unordered_map<const Node*, int> new_device_map; std::unordered_map<const Node*, int> new_device_map;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
...@@ -142,7 +142,7 @@ Graph PlaceDevice(Graph src) { ...@@ -142,7 +142,7 @@ Graph PlaceDevice(Graph src) {
CHECK(!need_mutate) << "consistency check"; CHECK(!need_mutate) << "consistency check";
} }
if (need_mutate) { if (need_mutate) {
NodePtr new_node = Node::Create(); ObjectPtr new_node = Node::Create();
new_node->attrs = inode.source->attrs; new_node->attrs = inode.source->attrs;
new_node->inputs.reserve(inode.inputs.size()); new_node->inputs.reserve(inode.inputs.size());
for (size_t i = 0; i < inode.inputs.size(); ++i) { for (size_t i = 0; i < inode.inputs.size(); ++i) {
...@@ -154,7 +154,7 @@ Graph PlaceDevice(Graph src) { ...@@ -154,7 +154,7 @@ Graph PlaceDevice(Graph src) {
new_node->inputs.emplace_back( new_node->inputs.emplace_back(
NodeEntry{it->second, 0, 0}); NodeEntry{it->second, 0, 0});
} else { } else {
NodePtr copy_node = Node::Create(); ObjectPtr copy_node = Node::Create();
std::ostringstream os; std::ostringstream os;
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op; copy_node->attrs.op = copy_op;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -86,7 +86,7 @@ struct JSONNode { ...@@ -86,7 +86,7 @@ struct JSONNode {
}; };
// pointer to the graph node // pointer to the graph node
NodePtr node; ObjectPtr node;
// inputs // inputs
std::vector<Entry> inputs; std::vector<Entry> inputs;
// control flow dependencies // control flow dependencies
...@@ -190,7 +190,7 @@ struct JSONGraph { ...@@ -190,7 +190,7 @@ struct JSONGraph {
void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) { void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
std::unordered_map<Node*, uint32_t> node2index; std::unordered_map<Node*, uint32_t> node2index;
jgraph->node_row_ptr.push_back(0); jgraph->node_row_ptr.push_back(0);
DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) { DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size()); uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
node2index[n.get()] = nid; node2index[n.get()] = nid;
if (n->is_variable()) { if (n->is_variable()) {
...@@ -202,7 +202,7 @@ void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) { ...@@ -202,7 +202,7 @@ void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
for (const NodeEntry& e : n->inputs) { for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version); jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
} }
for (const NodePtr& c : n->control_deps) { for (const ObjectPtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get())); jnode.control_deps.push_back(node2index.at(c.get()));
} }
jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs()); jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
......
...@@ -32,7 +32,7 @@ TVM_REGISTER_API("_format_str") ...@@ -32,7 +32,7 @@ TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kObjectHandle); CHECK(args[0].type_code() == kObjectHandle);
std::ostringstream os; std::ostringstream os;
os << args[0].operator NodeRef(); os << args[0].operator ObjectRef();
*ret = os.str(); *ret = os.str();
}); });
......
...@@ -65,7 +65,7 @@ TVM_REGISTER_API("_Array") ...@@ -65,7 +65,7 @@ TVM_REGISTER_API("_Array")
data.push_back(ObjectRef(nullptr)); data.push_back(ObjectRef(nullptr));
} }
} }
auto node = make_node<ArrayNode>(); auto node = make_object<ArrayNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = Array<ObjectRef>(node); *ret = Array<ObjectRef>(node);
}); });
...@@ -105,7 +105,7 @@ TVM_REGISTER_API("_Map") ...@@ -105,7 +105,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].operator std::string(), data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef())); args[i + 1].operator ObjectRef()));
} }
auto node = make_node<StrMapNode>(); auto node = make_object<StrMapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node); *ret = Map<ObjectRef, ObjectRef>(node);
} else { } else {
...@@ -119,7 +119,7 @@ TVM_REGISTER_API("_Map") ...@@ -119,7 +119,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].operator ObjectRef(), data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef())); args[i + 1].operator ObjectRef()));
} }
auto node = make_node<MapNode>(); auto node = make_object<MapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node); *ret = Map<ObjectRef, ObjectRef>(node);
} }
...@@ -186,7 +186,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -186,7 +186,7 @@ TVM_REGISTER_API("_MapItems")
if (ptr->IsInstance<MapNode>()) { if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr); auto* n = static_cast<const MapNode*>(ptr);
auto rkvs = make_node<ArrayNode>(); auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
...@@ -194,7 +194,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -194,7 +194,7 @@ TVM_REGISTER_API("_MapItems")
*ret = Array<ObjectRef>(rkvs); *ret = Array<ObjectRef>(rkvs);
} else { } else {
auto* n = static_cast<const StrMapNode*>(ptr); auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_node<ArrayNode>(); auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
......
...@@ -100,12 +100,13 @@ TVM_REGISTER_API("ir_pass.RewriteForTensorCore") ...@@ -100,12 +100,13 @@ TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
}); });
TVM_REGISTER_API("ir_pass.AttrsEqual") TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) { .set_body_typed<bool(const ObjectRef&, const ObjectRef&)>(
[](const ObjectRef& lhs, const ObjectRef& rhs) {
return AttrsEqual()(lhs, rhs); return AttrsEqual()(lhs, rhs);
}); });
TVM_REGISTER_API("ir_pass.AttrsHash") TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) { .set_body_typed<int64_t(const ObjectRef&)>([](const ObjectRef &node) {
return AttrsHash()(node); return AttrsHash()(node);
}); });
...@@ -118,7 +119,7 @@ TVM_REGISTER_API("ir_pass.ExprUseVar") ...@@ -118,7 +119,7 @@ TVM_REGISTER_API("ir_pass.ExprUseVar")
TVM_REGISTER_API("ir_pass.PostOrderVisit") TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1]; PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) { ir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
f(n); f(n);
}); });
}); });
...@@ -126,7 +127,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") ...@@ -126,7 +127,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
TVM_REGISTER_API("ir_pass.LowerStorageAccess") TVM_REGISTER_API("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
LoweredFunc f = args[0]; LoweredFunc f = args[0];
auto n = make_node<LoweredFuncNode>(*f.operator->()); auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body); n->body = LowerStorageAccessInfo(f->body);
*ret = LoweredFunc(n); *ret = LoweredFunc(n);
}); });
......
...@@ -42,7 +42,7 @@ class VariablePathFinder: public IRVisitor { ...@@ -42,7 +42,7 @@ class VariablePathFinder: public IRVisitor {
public: public:
explicit VariablePathFinder(Expr target) : target_(target) {} explicit VariablePathFinder(Expr target) : target_(target) {}
void Visit(const NodeRef& node) final { void Visit(const ObjectRef& node) final {
if (visited_.count(node.get()) != 0) return; if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get()); visited_.insert(node.get());
...@@ -82,7 +82,7 @@ class BoundDeducer: public IRVisitor { ...@@ -82,7 +82,7 @@ class BoundDeducer: public IRVisitor {
void Deduce(); void Deduce();
void Visit(const NodeRef& e) final { void Visit(const ObjectRef& e) final {
if (!success_) return; if (!success_) return;
if (e.get() == path_[iter_++]) { if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e); IRVisitor::Visit(e);
...@@ -202,7 +202,7 @@ class BoundDeduceInputChecker: public IRVisitor { ...@@ -202,7 +202,7 @@ class BoundDeduceInputChecker: public IRVisitor {
return target_count == 1; return target_count == 1;
} }
void Visit(const NodeRef& e) final { void Visit(const ObjectRef& e) final {
if (e.same_as(deducer_->target_)) ++target_count; if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e); IRVisitor::Visit(e);
} }
......
...@@ -56,7 +56,7 @@ class CanonicalExprNode : public BaseExprNode { ...@@ -56,7 +56,7 @@ class CanonicalExprNode : public BaseExprNode {
} }
static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const char* _type_key = "arith.CanonicalExpr";
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode); TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode);
}; };
enum DivMode { enum DivMode {
...@@ -147,10 +147,14 @@ class SplitExprNode : public CanonicalExprNode { ...@@ -147,10 +147,14 @@ class SplitExprNode : public CanonicalExprNode {
/*! \brief positive infty */ /*! \brief positive infty */
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static constexpr const char* _type_key = "arith.SplitExpr"; static constexpr const char* _type_key = "arith.SplitExpr";
TVM_DECLARE_NODE_TYPE_INFO(SplitExprNode, CanonicalExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode);
}; };
TVM_DEFINE_COW_NODE_REF(SplitExpr, Expr, SplitExprNode); class SplitExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};
inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true; if (index.same_as(other->index)) return true;
...@@ -272,7 +276,7 @@ class SumExprNode : public CanonicalExprNode { ...@@ -272,7 +276,7 @@ class SumExprNode : public CanonicalExprNode {
void AddToSelf(const SumExpr& other, int64_t scale); void AddToSelf(const SumExpr& other, int64_t scale);
static constexpr const char* _type_key = "arith.SumExpr"; static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_NODE_TYPE_INFO(SumExprNode, CanonicalExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);
private: private:
/*! /*!
...@@ -405,7 +409,11 @@ class SumExprNode : public CanonicalExprNode { ...@@ -405,7 +409,11 @@ class SumExprNode : public CanonicalExprNode {
} }
}; };
TVM_DEFINE_COW_NODE_REF(SumExpr, Expr, SumExprNode); class SumExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};
void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) {
// NOTE: it is rare to have a balanced long expression, // NOTE: it is rare to have a balanced long expression,
...@@ -507,7 +515,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -507,7 +515,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<CanonicalExprNode>()) { if (const auto* op = expr.as<CanonicalExprNode>()) {
expr = op->Normalize(); expr = op->Normalize();
} }
NodePtr<SplitExprNode> n = make_node<SplitExprNode>(); ObjectPtr<SplitExprNode> n = make_object<SplitExprNode>();
n->dtype = expr.dtype(); n->dtype = expr.dtype();
n->index = std::move(expr); n->index = std::move(expr);
n->div_mode = kTruncDiv; n->div_mode = kTruncDiv;
...@@ -544,7 +552,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -544,7 +552,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SumExprNode>()) { if (const auto* op = expr.as<SumExprNode>()) {
return GetRef<SumExpr>(op); return GetRef<SumExpr>(op);
} }
NodePtr<SumExprNode> n = make_node<SumExprNode>(); ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
n->dtype = expr.dtype(); n->dtype = expr.dtype();
if (const auto* op = expr.as<IntImm>()) { if (const auto* op = expr.as<IntImm>()) {
n->base = op->value; n->base = op->value;
...@@ -655,8 +663,8 @@ SeparateDivisibleParts(const SumExprNode* psum, ...@@ -655,8 +663,8 @@ SeparateDivisibleParts(const SumExprNode* psum,
int64_t coeff, int64_t coeff,
SumExpr* out_divisible, SumExpr* out_divisible,
SumExpr* out_non_divisible) { SumExpr* out_non_divisible) {
auto divisible = make_node<SumExprNode>(); auto divisible = make_object<SumExprNode>();
auto non_divisible = make_node<SumExprNode>(); auto non_divisible = make_object<SumExprNode>();
divisible->dtype = psum->dtype; divisible->dtype = psum->dtype;
non_divisible->dtype = psum->dtype; non_divisible->dtype = psum->dtype;
......
...@@ -35,7 +35,7 @@ TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); ...@@ -35,7 +35,7 @@ TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
ConstIntBound::ConstIntBound( ConstIntBound::ConstIntBound(
int64_t min_value, int64_t max_value) { int64_t min_value, int64_t max_value) {
auto node = make_node<ConstIntBoundNode>(); auto node = make_object<ConstIntBoundNode>();
node->min_value = min_value; node->min_value = min_value;
node->max_value = max_value; node->max_value = max_value;
data_ = std::move(node); data_ = std::move(node);
...@@ -123,7 +123,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -123,7 +123,7 @@ class ConstIntBoundAnalyzer::Impl :
} }
// Override visitor behaviors // Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final { Entry VisitExprDefault_(const Object* op) final {
return Everything( return Everything(
static_cast<const ExprNode*>(op)->dtype); static_cast<const ExprNode*>(op)->dtype);
} }
......
...@@ -106,7 +106,7 @@ class LinearEqDetector ...@@ -106,7 +106,7 @@ class LinearEqDetector
} }
return ret; return ret;
} }
LinearEqEntry VisitExprDefault_(const Node* op, const Expr& e) final { LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final {
if (fail_) return LinearEqEntry(); if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) { if (ExprUseVar(e, var_)) {
fail_ = true; fail_ = true;
...@@ -171,7 +171,7 @@ bool DetectClipBound( ...@@ -171,7 +171,7 @@ bool DetectClipBound(
std::unordered_map<const Variable*, IntervalEntry>* bmap) { std::unordered_map<const Variable*, IntervalEntry>* bmap) {
int flag = 0; int flag = 0;
Var var; Var var;
auto fvisit = [&bmap, &flag, &var](const NodeRef& n) { auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
if (const Variable* v = n.as<Variable>()) { if (const Variable* v = n.as<Variable>()) {
if (bmap->count(v)) { if (bmap->count(v)) {
if (flag == 0) { if (flag == 0) {
......
...@@ -37,7 +37,7 @@ Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); ...@@ -37,7 +37,7 @@ Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
IntervalSet::IntervalSet(Expr min_value, Expr max_value) { IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
auto node = make_node<IntervalSetNode>(); auto node = make_object<IntervalSetNode>();
node->min_value = std::move(min_value); node->min_value = std::move(min_value);
node->max_value = std::move(max_value); node->max_value = std::move(max_value);
data_ = std::move(node); data_ = std::move(node);
...@@ -505,7 +505,7 @@ class IntervalSetEvaluator : ...@@ -505,7 +505,7 @@ class IntervalSetEvaluator :
return Union(analyzer_, false_set, true_set); return Union(analyzer_, false_set, true_set);
} }
IntervalSet VisitExprDefault_(const Node* op) final { IntervalSet VisitExprDefault_(const Object* op) final {
DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
return IntervalSet::Everything(); return IntervalSet::Everything();
} }
......
...@@ -75,7 +75,7 @@ class IntervalSetNode : public IntSetNode { ...@@ -75,7 +75,7 @@ class IntervalSetNode : public IntSetNode {
} }
static constexpr const char* _type_key = "arith.IntervalSet"; static constexpr const char* _type_key = "arith.IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode); TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode);
}; };
/*! /*!
...@@ -116,8 +116,8 @@ class IntervalSet : public IntSet { ...@@ -116,8 +116,8 @@ class IntervalSet : public IntSet {
return IntervalSet(pos_inf(), neg_inf()); return IntervalSet(pos_inf(), neg_inf());
} }
TVM_DEFINE_NODE_REF_COW(IntervalSetNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode);
TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
}; };
/*! /*!
......
...@@ -37,7 +37,7 @@ using namespace ir; ...@@ -37,7 +37,7 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE(ModularSetNode); TVM_REGISTER_NODE_TYPE(ModularSetNode);
ModularSet::ModularSet(int64_t coeff, int64_t base) { ModularSet::ModularSet(int64_t coeff, int64_t base) {
auto node = make_node<ModularSetNode>(); auto node = make_object<ModularSetNode>();
node->coeff = coeff; node->coeff = coeff;
node->base = base; node->base = base;
// finish construction. // finish construction.
...@@ -120,7 +120,7 @@ class ModularSetAnalyzer::Impl : ...@@ -120,7 +120,7 @@ class ModularSetAnalyzer::Impl :
} }
// Override visitor behaviors // Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final { Entry VisitExprDefault_(const Object* op) final {
return Everything(); return Everything();
} }
......
...@@ -250,7 +250,7 @@ class PBinaryExpr : ...@@ -250,7 +250,7 @@ class PBinaryExpr :
b_.InitMatch_(); b_.InitMatch_();
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const NodeType* ptr = node.as<NodeType>()) { if (const NodeType* ptr = node.as<NodeType>()) {
if (!a_.Match_(ptr->a)) return false; if (!a_.Match_(ptr->a)) return false;
if (!b_.Match_(ptr->b)) return false; if (!b_.Match_(ptr->b)) return false;
...@@ -282,7 +282,7 @@ class PConstWithTypeLike : ...@@ -282,7 +282,7 @@ class PConstWithTypeLike :
void InitMatch_() const {} void InitMatch_() const {}
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::IntImm* ptr = node.as<ir::IntImm>()) { if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
return ptr->value == value_; return ptr->value == value_;
} else { } else {
...@@ -364,7 +364,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > { ...@@ -364,7 +364,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
value_.InitMatch_(); value_.InitMatch_();
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::Not* ptr = node.as<ir::Not>()) { if (const ir::Not* ptr = node.as<ir::Not>()) {
if (!value_.Match_(ptr->a)) return false; if (!value_.Match_(ptr->a)) return false;
return true; return true;
...@@ -410,7 +410,7 @@ class PSelectExpr : ...@@ -410,7 +410,7 @@ class PSelectExpr :
false_value_.InitMatch_(); false_value_.InitMatch_();
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::Select* ptr = node.as<ir::Select>()) { if (const ir::Select* ptr = node.as<ir::Select>()) {
if (!condition_.Match_(ptr->condition)) return false; if (!condition_.Match_(ptr->condition)) return false;
if (!true_value_.Match_(ptr->true_value)) return false; if (!true_value_.Match_(ptr->true_value)) return false;
...@@ -472,7 +472,7 @@ class PCastExpr : ...@@ -472,7 +472,7 @@ class PCastExpr :
value_.InitMatch_(); value_.InitMatch_();
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::Cast* ptr = node.as<ir::Cast>()) { if (const ir::Cast* ptr = node.as<ir::Cast>()) {
if (!dtype_.Match_(ptr->dtype)) return false; if (!dtype_.Match_(ptr->dtype)) return false;
if (!value_.Match_(ptr->value)) return false; if (!value_.Match_(ptr->value)) return false;
...@@ -530,7 +530,7 @@ class PRampExpr : ...@@ -530,7 +530,7 @@ class PRampExpr :
lanes_.InitMatch_(); lanes_.InitMatch_();
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::Ramp* ptr = node.as<ir::Ramp>()) { if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
if (!base_.Match_(ptr->base)) return false; if (!base_.Match_(ptr->base)) return false;
if (!stride_.Match_(ptr->stride)) return false; if (!stride_.Match_(ptr->stride)) return false;
...@@ -592,7 +592,7 @@ class PBroadcastExpr : ...@@ -592,7 +592,7 @@ class PBroadcastExpr :
lanes_.InitMatch_(); lanes_.InitMatch_();
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) { if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
if (!value_.Match_(ptr->value)) return false; if (!value_.Match_(ptr->value)) return false;
if (!lanes_.Match_(ptr->lanes)) return false; if (!lanes_.Match_(ptr->lanes)) return false;
...@@ -704,7 +704,7 @@ class PCallExpr : ...@@ -704,7 +704,7 @@ class PCallExpr :
detail::tuple_for_each(finit, args_); detail::tuple_for_each(finit, args_);
} }
bool Match_(const NodeRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::Call* ptr = node.as<ir::Call>()) { if (const ir::Call* ptr = node.as<ir::Call>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false; if (ptr->args.size() != sizeof...(TArgs)) return false;
if (ptr->name != Op::kName) return false; if (ptr->name != Op::kName) return false;
......
...@@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
*/ */
Target CreateTarget(const std::string& target_name, Target CreateTarget(const std::string& target_name,
const std::vector<std::string>& options) { const std::vector<std::string>& options) {
auto t = make_node<TargetNode>(); auto t = make_object<TargetNode>();
t->target_name = target_name; t->target_name = target_name;
std::string libs_flag = "-libs="; std::string libs_flag = "-libs=";
...@@ -366,7 +366,7 @@ void GetBinds(const Array<Tensor>& args, ...@@ -366,7 +366,7 @@ void GetBinds(const Array<Tensor>& args,
bool compact, bool compact,
const std::unordered_map<Tensor, Buffer>& binds, const std::unordered_map<Tensor, Buffer>& binds,
Map<Tensor, Buffer>* out_binds, Map<Tensor, Buffer>* out_binds,
Array<NodeRef>* out_arg_list, Array<ObjectRef>* out_arg_list,
const BuildConfig& config) { const BuildConfig& config) {
*out_binds = binds; *out_binds = binds;
...@@ -396,7 +396,7 @@ Stmt BuildStmt(Schedule sch, ...@@ -396,7 +396,7 @@ Stmt BuildStmt(Schedule sch,
const Array<Tensor>& args, const Array<Tensor>& args,
const std::unordered_map<Tensor, Buffer>& binds, const std::unordered_map<Tensor, Buffer>& binds,
bool loop_partition, bool loop_partition,
Array<NodeRef> *out_arg_list, Array<ObjectRef> *out_arg_list,
const BuildConfig& config) { const BuildConfig& config) {
sch = sch.normalize(); sch = sch.normalize();
...@@ -445,7 +445,7 @@ Array<LoweredFunc> lower(Schedule sch, ...@@ -445,7 +445,7 @@ Array<LoweredFunc> lower(Schedule sch,
const std::string& name, const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds, const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config) { const BuildConfig& config) {
Array<NodeRef> out_arg_list; Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
} }
...@@ -618,7 +618,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -618,7 +618,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
} }
BuildConfig BuildConfig::Create() { BuildConfig BuildConfig::Create() {
return BuildConfig(make_node<BuildConfigNode>()); return BuildConfig(make_object<BuildConfigNode>());
} }
/*! \brief Entry to hold the BuildConfig context stack. */ /*! \brief Entry to hold the BuildConfig context stack. */
...@@ -701,7 +701,7 @@ GenericFunc GenericFunc::Get(const std::string& name) { ...@@ -701,7 +701,7 @@ GenericFunc GenericFunc::Get(const std::string& name) {
std::lock_guard<std::mutex>(m->mutex); std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name); auto it = m->fmap.find(name);
if (it == m->fmap.end()) { if (it == m->fmap.end()) {
auto f = make_node<GenericFuncNode>(); auto f = make_object<GenericFuncNode>();
f->name_ = name; f->name_ = name;
auto gf = GenericFunc(f); auto gf = GenericFunc(f);
m->fmap[name] = gf; m->fmap[name] = gf;
...@@ -825,7 +825,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") ...@@ -825,7 +825,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo")
TVM_REGISTER_API("_GenericFuncCreate") TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_node<GenericFuncNode>()); *ret = GenericFunc(make_object<GenericFuncNode>());
}); });
TVM_REGISTER_API("_GenericFuncGetGlobal") TVM_REGISTER_API("_GenericFuncGetGlobal")
......
...@@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() { ...@@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() {
std::string CodeGenHybrid::GetVarID(const Variable *v) { std::string CodeGenHybrid::GetVarID(const Variable *v) {
if (binds_.count(v)) if (binds_.count(v))
return binds_[v]; return binds_[v];
auto key = std::make_pair(static_cast<const Node*>(v), 0); auto key = std::make_pair(static_cast<const Object*>(v), 0);
if (id_map_.count(key)) { if (id_map_.count(key)) {
return id_map_[key]; return id_map_[key];
} }
...@@ -472,7 +472,7 @@ void CodeGenHybrid::ReserveKeywords() { ...@@ -472,7 +472,7 @@ void CodeGenHybrid::ReserveKeywords() {
} }
void CodeGenHybrid::DumpStmt(const Stmt &stmt, void CodeGenHybrid::DumpStmt(const Stmt &stmt,
const Array<NodeRef> &inputs, const Array<ObjectRef> &inputs,
const Array<Tensor> &outputs, const Array<Tensor> &outputs,
const std::string &name) { const std::string &name) {
ReserveKeywords(); ReserveKeywords();
......
...@@ -56,7 +56,7 @@ class CodeGenHybrid : ...@@ -56,7 +56,7 @@ class CodeGenHybrid :
* \param outputs Output tensors of this schedule. * \param outputs Output tensors of this schedule.
* \param name The name of the function. * \param name The name of the function.
*/ */
void DumpStmt(const Stmt &stmt, const Array<NodeRef> &inputs, const Array<Tensor> &outputs, void DumpStmt(const Stmt &stmt, const Array<ObjectRef> &inputs, const Array<Tensor> &outputs,
const std::string &name = "hybrid_func"); const std::string &name = "hybrid_func");
/*! /*!
* \brief Finalize the compilation and return the code. * \brief Finalize the compilation and return the code.
...@@ -152,7 +152,7 @@ class CodeGenHybrid : ...@@ -152,7 +152,7 @@ class CodeGenHybrid :
/*! /*!
* \brief Keys are either (tensors, value_index) or (variables, 0). * \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/ * Values are the corresponding IDs.*/
std::map<std::pair<const Node *, int>, std::string> id_map_; std::map<std::pair<const Object *, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */ /*! \brief Variables (keys) binded to the threads (values). */
std::map<const Variable *, std::string> binds_; std::map<const Variable *, std::string> binds_;
/*! /*!
......
...@@ -33,7 +33,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -33,7 +33,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
ObjectPtr<Object> CreateEnvNode(const std::string& name) { ObjectPtr<Object> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name); auto* f = runtime::Registry::Get(name);
CHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>(); ObjectPtr<EnvFuncNode> n = make_object<EnvFuncNode>();
n->func = *f; n->func = *f;
n->name = name; n->name = name;
return n; return n;
......
...@@ -39,8 +39,8 @@ void DictAttrsNode::InitByPackedArgs( ...@@ -39,8 +39,8 @@ void DictAttrsNode::InitByPackedArgs(
for (int i = 0; i < args.size(); i += 2) { for (int i = 0; i < args.size(); i += 2) {
std::string key = args[i]; std::string key = args[i];
runtime::TVMArgValue val = args[i + 1]; runtime::TVMArgValue val = args[i + 1];
if (val.type_code() == kObjectHandle) { if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator NodeRef()); dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kStr) { } else if (val.type_code() == kStr) {
dict.Set(key, Expr(val.operator std::string())); dict.Set(key, Expr(val.operator std::string()));
} else { } else {
...@@ -53,8 +53,8 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { ...@@ -53,8 +53,8 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {}; return {};
} }
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) { Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
NodePtr<DictAttrsNode> n = make_node<DictAttrsNode>(); ObjectPtr<DictAttrsNode> n = make_object<DictAttrsNode>();
n->dict = std::move(dict); n->dict = std::move(dict);
return Attrs(n); return Attrs(n);
} }
......
...@@ -334,7 +334,7 @@ Buffer Buffer::MakeStrideView() const { ...@@ -334,7 +334,7 @@ Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this; if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this; if ((*this)->shape.size() == 0) return *this;
std::vector<Expr> temp; std::vector<Expr> temp;
auto n = make_node<BufferNode>(*operator->()); auto n = make_object<BufferNode>(*operator->());
Expr acc = make_const(n->DefaultIndexType(), 1); Expr acc = make_const(n->DefaultIndexType(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) { for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc); temp.push_back(acc);
...@@ -419,7 +419,7 @@ Buffer BufferNode::make(Var data, ...@@ -419,7 +419,7 @@ Buffer BufferNode::make(Var data,
int data_alignment, int data_alignment,
int offset_factor, int offset_factor,
BufferType buffer_type) { BufferType buffer_type) {
auto n = make_node<BufferNode>(); auto n = make_object<BufferNode>();
n->data = std::move(data); n->data = std::move(data);
n->dtype = dtype; n->dtype = dtype;
n->shape = std::move(shape); n->shape = std::move(shape);
......
...@@ -68,7 +68,7 @@ const LayoutAxis& LayoutAxis::make(const std::string& name) { ...@@ -68,7 +68,7 @@ const LayoutAxis& LayoutAxis::make(const std::string& name) {
} }
Layout::Layout(const Array<IterVar>& axes) { Layout::Layout(const Array<IterVar>& axes) {
auto node = make_node<LayoutNode>(); auto node = make_object<LayoutNode>();
node->axes = axes; node->axes = axes;
std::ostringstream repr; std::ostringstream repr;
for (const IterVar& axis : axes) { for (const IterVar& axis : axes) {
...@@ -89,7 +89,7 @@ Layout::Layout(const Array<IterVar>& axes) { ...@@ -89,7 +89,7 @@ Layout::Layout(const Array<IterVar>& axes) {
Layout::Layout(const std::string& name) { // NOLINT(*) Layout::Layout(const std::string& name) { // NOLINT(*)
if (name == "__undef__") return; if (name == "__undef__") return;
auto node = make_node<LayoutNode>(); auto node = make_object<LayoutNode>();
node->name = name; node->name = name;
if (name.empty()) return; // scalar if (name.empty()) return; // scalar
...@@ -347,7 +347,7 @@ Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const { ...@@ -347,7 +347,7 @@ Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
const Layout& dst_layout) { const Layout& dst_layout) {
auto n = make_node<BijectiveLayoutNode>(); auto n = make_object<BijectiveLayoutNode>();
n->src_layout = src_layout; n->src_layout = src_layout;
n->dst_layout = dst_layout; n->dst_layout = dst_layout;
......
...@@ -42,14 +42,14 @@ Var::Var(std::string name_hint, DataType t) ...@@ -42,14 +42,14 @@ Var::Var(std::string name_hint, DataType t)
: Var(Variable::make(t, name_hint)) {} : Var(Variable::make(t, name_hint)) {}
Var Variable::make(DataType t, std::string name_hint) { Var Variable::make(DataType t, std::string name_hint) {
NodePtr<Variable> node = make_node<Variable>(); ObjectPtr<Variable> node = make_object<Variable>();
node->dtype = t; node->dtype = t;
node->name_hint = std::move(name_hint); node->name_hint = std::move(name_hint);
return Var(node); return Var(node);
} }
Range::Range(Expr begin, Expr end) Range::Range(Expr begin, Expr end)
: Range(make_node<RangeNode>( : Range(make_object<RangeNode>(
begin, begin,
is_zero(begin) ? end : (end - begin))) { is_zero(begin) ? end : (end - begin))) {
} }
...@@ -57,21 +57,21 @@ Range::Range(Expr begin, Expr end) ...@@ -57,21 +57,21 @@ Range::Range(Expr begin, Expr end)
Integer IntImm::make(DataType t, int64_t value) { Integer IntImm::make(DataType t, int64_t value) {
CHECK(t.is_int() && t.is_scalar()) CHECK(t.is_int() && t.is_scalar())
<< "ValueError: IntImm can only take scalar."; << "ValueError: IntImm can only take scalar.";
NodePtr<IntImm> node = make_node<IntImm>(); ObjectPtr<IntImm> node = make_object<IntImm>();
node->dtype = t; node->dtype = t;
node->value = value; node->value = value;
return Integer(node); return Integer(node);
} }
Range Range::make_by_min_extent(Expr min, Expr extent) { Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(make_node<RangeNode>(min, extent)); return Range(make_object<RangeNode>(min, extent));
} }
IterVar IterVarNode::make(Range dom, IterVar IterVarNode::make(Range dom,
Var var, Var var,
IterVarType t, IterVarType t,
std::string thread_tag) { std::string thread_tag) {
NodePtr<IterVarNode> n = make_node<IterVarNode>(); ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
n->dom = dom; n->dom = dom;
n->var = var; n->var = var;
n->iter_type = t; n->iter_type = t;
...@@ -89,7 +89,7 @@ IterVar reduce_axis(Range dom, std::string name) { ...@@ -89,7 +89,7 @@ IterVar reduce_axis(Range dom, std::string name) {
dom, Var(name), kCommReduce); dom, Var(name), kCommReduce);
} }
void Dump(const NodeRef& n) { void Dump(const ObjectRef& n) {
std::cerr << n << "\n"; std::cerr << n << "\n";
} }
......
...@@ -47,7 +47,7 @@ Expr Tensor::operator()(Array<Expr> indices) const { ...@@ -47,7 +47,7 @@ Expr Tensor::operator()(Array<Expr> indices) const {
} }
Tensor Operation::output(size_t i) const { Tensor Operation::output(size_t i) const {
auto node = make_node<TensorNode>(); auto node = make_object<TensorNode>();
node->op = *this; node->op = *this;
node->value_index = i; node->value_index = i;
node->dtype = (*this)->output_dtype(i); node->dtype = (*this)->output_dtype(i);
...@@ -59,7 +59,7 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -59,7 +59,7 @@ Tensor TensorNode::make(Array<Expr> shape,
DataType dtype, DataType dtype,
Operation op, Operation op,
int value_index) { int value_index) {
auto n = make_node<TensorNode>(); auto n = make_object<TensorNode>();
n->shape = std::move(shape); n->shape = std::move(shape);
n->dtype = dtype; n->dtype = dtype;
n->op = op; n->op = op;
...@@ -87,7 +87,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, ...@@ -87,7 +87,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Stmt body, Stmt body,
Stmt reduce_init, Stmt reduce_init,
Stmt reduce_update) { Stmt reduce_update) {
auto n = make_node<TensorIntrinNode>(); auto n = make_object<TensorIntrinNode>();
n->name = std::move(name); n->name = std::move(name);
n->op = std::move(op); n->op = std::move(op);
n->inputs = std::move(inputs); n->inputs = std::move(inputs);
...@@ -115,7 +115,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, ...@@ -115,7 +115,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
Array<Region> regions, Array<Region> regions,
Array<IterVar> reduce_axis, Array<IterVar> reduce_axis,
Array<Expr> scalar_inputs) { Array<Expr> scalar_inputs) {
auto n = make_node<TensorIntrinCallNode>(); auto n = make_object<TensorIntrinCallNode>();
n->intrin = std::move(intrin); n->intrin = std::move(intrin);
n->tensors = std::move(tensors); n->tensors = std::move(tensors);
n->regions = std::move(regions); n->regions = std::move(regions);
......
...@@ -79,7 +79,7 @@ class NodeIndexer : public AttrVisitor { ...@@ -79,7 +79,7 @@ class NodeIndexer : public AttrVisitor {
// make index of all the children of node // make index of all the children of node
void MakeIndex(Object* node) { void MakeIndex(Object* node) {
if (node == nullptr) return; if (node == nullptr) return;
CHECK(node->IsInstance<Node>()); CHECK(node->IsInstance<Object>());
if (node_index_.count(node)) return; if (node_index_.count(node)) return;
CHECK_EQ(node_index_.size(), node_list_.size()); CHECK_EQ(node_index_.size(), node_list_.size());
......
...@@ -90,8 +90,8 @@ Tensor compute(Array<Expr> shape, ...@@ -90,8 +90,8 @@ Tensor compute(Array<Expr> shape,
FCompute fcompute, FCompute fcompute,
std::string name, std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs) { Map<std::string, ObjectRef> attrs) {
auto op_node = make_node<ComputeOpNode>(); auto op_node = make_object<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
std::vector<IterVar> axis; std::vector<IterVar> axis;
...@@ -112,8 +112,8 @@ Array<Tensor> compute(Array<Expr> shape, ...@@ -112,8 +112,8 @@ Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute, FBatchCompute fcompute,
std::string name, std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs) { Map<std::string, ObjectRef> attrs) {
auto op_node = make_node<ComputeOpNode>(); auto op_node = make_object<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
std::vector<IterVar> axis; std::vector<IterVar> axis;
...@@ -136,13 +136,13 @@ Array<Tensor> compute(Array<Expr> shape, ...@@ -136,13 +136,13 @@ Array<Tensor> compute(Array<Expr> shape,
Operation ComputeOpNode::make(std::string name, Operation ComputeOpNode::make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
Array<IterVar> axis, Array<IterVar> axis,
Array<Expr> body) { Array<Expr> body) {
if (!attrs.defined()) { if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>(); attrs = Map<std::string, ObjectRef>();
} }
auto n = make_node<ComputeOpNode>(); auto n = make_object<ComputeOpNode>();
n->name = std::move(name); n->name = std::move(name);
n->tag = std::move(tag); n->tag = std::move(tag);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
...@@ -161,7 +161,7 @@ Array<Tensor> ComputeOpNode::InputTensors() const { ...@@ -161,7 +161,7 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret; Array<Tensor> ret;
std::unordered_set<Tensor> visited; std::unordered_set<Tensor> visited;
for (auto& e : body) { for (auto& e : body) {
ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) { ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
const ir::Call *call = n.as<ir::Call>(); const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index); Tensor t = Downcast<Operation>(call->func).output(call->value_index);
...@@ -188,7 +188,7 @@ Operation ComputeOpNode::ReplaceInputs( ...@@ -188,7 +188,7 @@ Operation ComputeOpNode::ReplaceInputs(
if (!new_reduce.same_as(this->body[0])) { if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>(); const ir::Reduce* r = new_reduce.as<ir::Reduce>();
for (size_t k = 0; k < this->body.size(); ++k) { for (size_t k = 0; k < this->body.size(); ++k) {
auto n = make_node<ir::Reduce>(*r); auto n = make_object<ir::Reduce>(*r);
n->value_index = static_cast<int>(k); n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype(); n->dtype = r->source[k].dtype();
arr.push_back(Expr(n)); arr.push_back(Expr(n));
...@@ -215,7 +215,7 @@ void ComputeOpNode::PropBoundToInputs( ...@@ -215,7 +215,7 @@ void ComputeOpNode::PropBoundToInputs(
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const { std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) { auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
auto *call = n.as<ir::Call>(); auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index); Tensor t = Downcast<Operation>(call->func).output(call->value_index);
...@@ -574,7 +574,7 @@ class ComputeVerifier final : protected ir::IRVisitor { ...@@ -574,7 +574,7 @@ class ComputeVerifier final : protected ir::IRVisitor {
protected: protected:
/// Visitor implementation /// Visitor implementation
//@{ //@{
void Visit(const NodeRef& n) final { void Visit(const ObjectRef& n) final {
++level_; ++level_;
ir::IRVisitor::Visit(n); ir::IRVisitor::Visit(n);
--level_; --level_;
......
...@@ -57,15 +57,15 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const { ...@@ -57,15 +57,15 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
Operation ExternOpNode::make(std::string name, Operation ExternOpNode::make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> inputs,
Array<Buffer> input_placeholders, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Array<Buffer> output_placeholders,
Stmt body) { Stmt body) {
if (!attrs.defined()) { if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>(); attrs = Map<std::string, ObjectRef>();
} }
auto n = make_node<ExternOpNode>(); auto n = make_object<ExternOpNode>();
n->name = std::move(name); n->name = std::move(name);
n->tag = std::move(tag); n->tag = std::move(tag);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
...@@ -93,7 +93,7 @@ Operation ExternOpNode::ReplaceInputs( ...@@ -93,7 +93,7 @@ Operation ExternOpNode::ReplaceInputs(
const Operation& self, const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const { const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
auto n = make_node<ExternOpNode>(*this); auto n = make_object<ExternOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap); n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) { for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i]; Tensor t = n->inputs[i];
...@@ -161,7 +161,7 @@ Stmt ExternOpNode::BuildProvide( ...@@ -161,7 +161,7 @@ Stmt ExternOpNode::BuildProvide(
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec; Array<ObjectRef> bind_spec;
Array<Expr> tuple; Array<Expr> tuple;
bind_spec.push_back(buffer); bind_spec.push_back(buffer);
bind_spec.push_back(tensor); bind_spec.push_back(tensor);
......
...@@ -63,14 +63,14 @@ Array<Expr> HybridOpNode::output_shape(size_t i) const { ...@@ -63,14 +63,14 @@ Array<Expr> HybridOpNode::output_shape(size_t i) const {
Operation HybridOpNode::make(std::string name, Operation HybridOpNode::make(std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs, Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> inputs,
Array<Tensor> outputs, Array<Tensor> outputs,
Stmt body) { Stmt body) {
if (!attrs.defined()) { if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>(); attrs = Map<std::string, ObjectRef>();
} }
auto n = make_node<HybridOpNode>(); auto n = make_object<HybridOpNode>();
n->name = std::move(name); n->name = std::move(name);
n->tag = std::move(tag); n->tag = std::move(tag);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
...@@ -91,7 +91,7 @@ Array<Tensor> HybridOpNode::InputTensors() const { ...@@ -91,7 +91,7 @@ Array<Tensor> HybridOpNode::InputTensors() const {
} }
std::unordered_set<Tensor> visited; std::unordered_set<Tensor> visited;
Array<Tensor> curr_inputs; Array<Tensor> curr_inputs;
ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) { ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
const ir::Call *call = n.as<ir::Call>(); const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index); Tensor t = Downcast<Operation>(call->func).output(call->value_index);
...@@ -108,7 +108,7 @@ Operation HybridOpNode::ReplaceInputs( ...@@ -108,7 +108,7 @@ Operation HybridOpNode::ReplaceInputs(
const Operation &self, const Operation &self,
const std::unordered_map<Tensor, Tensor> &rmap) const { const std::unordered_map<Tensor, Tensor> &rmap) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this); auto n = make_object<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap); n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) { for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i]; Tensor t = n->inputs[i];
...@@ -185,7 +185,7 @@ Stmt HybridOpNode::BuildProvide( ...@@ -185,7 +185,7 @@ Stmt HybridOpNode::BuildProvide(
for (int i = 0; i < this->num_outputs(); ++i) { for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i); rmap[outputs[i]] = stage->op.output(i);
} }
auto n = make_node<HybridOpNode>(*this); auto n = make_object<HybridOpNode>(*this);
/* This is a story little bit complicated. /* This is a story little bit complicated.
* The following two lines of codes replace output tensors' usage. * The following two lines of codes replace output tensors' usage.
* This is the simplest way I (@were) can come up with to glue * This is the simplest way I (@were) can come up with to glue
...@@ -369,7 +369,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, ...@@ -369,7 +369,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
expected = IterVarTypeToForType(attr->iter_type); expected = IterVarTypeToForType(attr->iter_type);
} }
PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) { PostOrderVisit(stmt,
[&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
if (const For *op = node.as<For>()) { if (const For *op = node.as<For>()) {
if (op->loop_var.get() == var) { if (op->loop_var.get() == var) {
++found; ++found;
...@@ -390,7 +391,7 @@ Stmt ApplyLoopOrder(const Stage &stage, ...@@ -390,7 +391,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) { const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
std::vector<const Variable*> current_order; std::vector<const Variable*> current_order;
PostOrderVisit(stmt, [&current_order](const NodeRef &node) { PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
if (const For *op = node.as<For>()) if (const For *op = node.as<For>())
current_order.push_back(op->loop_var.get()); current_order.push_back(op->loop_var.get());
}); });
...@@ -466,7 +467,7 @@ Stmt ApplySchedule(const Stage &stage, ...@@ -466,7 +467,7 @@ Stmt ApplySchedule(const Stage &stage,
std::vector<IterVar> GatherLoopVars(Stmt stmt) { std::vector<IterVar> GatherLoopVars(Stmt stmt) {
// TODO(@were): Write a comprehensive pass to analyze iter var types // TODO(@were): Write a comprehensive pass to analyze iter var types
std::vector<IterVar> res_; std::vector<IterVar> res_;
PostOrderVisit(stmt, [&res_](const NodeRef &node) { PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
if (const For *op = node.as<For>()) { if (const For *op = node.as<For>()) {
Var loop_var(op->loop_var); Var loop_var(op->loop_var);
Range dom = Range::make_by_min_extent(op->min, op->extent); Range dom = Range::make_by_min_extent(op->min, op->extent);
......
...@@ -55,7 +55,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const { ...@@ -55,7 +55,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
Operation PlaceholderOpNode::make(std::string name, Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape, Array<Expr> shape,
DataType dtype) { DataType dtype) {
auto n = make_node<PlaceholderOpNode>(); auto n = make_object<PlaceholderOpNode>();
n->name = name; n->name = name;
n->shape = shape; n->shape = shape;
n->dtype = dtype; n->dtype = dtype;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment