Unverified Commit a8c36921 by Tianqi Chen Committed by GitHub

[REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal to Object (#4603)

* [REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal and macros to Object.

Historically, we have classes like NodePtr/Ref/HashEqual.
After unified object protocol, these names are just alias of the object counterpart.
Moreover, there are helper macros defined over the places for defining these object.

This PR consoldiate the terminologies into the corresponding ones
in the Object system so we have a clean and consistent API moving forward.

* Update include/tvm/attrs.h

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

* fix compilation

Co-authored-by: Wei Chen <ipondering.weic@gmail.com>
parent 475158f6
......@@ -49,7 +49,7 @@ namespace tvm {
* \brief Node container of EnvFunc
* \sa EnvFunc
*/
class EnvFuncNode : public Node {
class EnvFuncNode : public Object {
public:
/*! \brief Unique name of the global function */
std::string name;
......@@ -63,7 +63,7 @@ class EnvFuncNode : public Node {
}
static constexpr const char* _type_key = "EnvFunc";
TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};
/*!
......@@ -73,10 +73,10 @@ class EnvFuncNode : public Node {
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
*/
class EnvFunc : public NodeRef {
class EnvFunc : public ObjectRef {
public:
EnvFunc() {}
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<const EnvFuncNode*>(get());
......@@ -119,12 +119,12 @@ class TypedEnvFunc;
* \sa EnvFunc
*/
template<typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public NodeRef {
class TypedEnvFunc<R(Args...)> : public ObjectRef {
public:
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : NodeRef(n) {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
......
......@@ -55,7 +55,7 @@ class Analyzer;
*
* set = [min_value, max_value]
*/
class ConstIntBoundNode : public Node {
class ConstIntBoundNode : public Object {
public:
int64_t min_value;
int64_t max_value;
......@@ -74,14 +74,14 @@ class ConstIntBoundNode : public Node {
static const constexpr int64_t kNegInf = -kPosInf;
static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
};
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound : public NodeRef {
class ConstIntBound : public ObjectRef {
public:
/*!
* \brief constructor by fields.
......@@ -92,7 +92,7 @@ class ConstIntBound : public NodeRef {
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode);
};
/*!
......@@ -155,7 +155,7 @@ class ConstIntBoundAnalyzer {
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
class ModularSetNode : public Node {
class ModularSetNode : public Object {
public:
/*! \brief linear co-efficient */
int64_t coeff;
......@@ -168,18 +168,18 @@ class ModularSetNode : public Node {
}
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet : public NodeRef {
class ModularSet : public ObjectRef {
public:
TVM_DLL ModularSet(int64_t coeff, int64_t base);
TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode);
};
/*!
......@@ -349,20 +349,20 @@ enum SignType {
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
struct IntSetNode : public Object {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
/*!
* \brief Integer set class, represent a set of integers in one dimension.
*/
class IntSet : public NodeRef {
class IntSet : public ObjectRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(ObjectPtr<Object> n) : NodeRef(n) {}
explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -598,7 +598,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
......
......@@ -65,7 +65,7 @@ namespace tvm {
*/
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode) \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
......@@ -83,9 +83,9 @@ namespace tvm {
* \tparam TNodeRef the type to be created.
* \return A instance that will represent None.
*/
template<typename TNodeRef>
inline TNodeRef NullValue() {
return TNodeRef(NodePtr<Node>(nullptr));
template<typename TObjectRef>
inline TObjectRef NullValue() {
return TObjectRef(ObjectPtr<Object>(nullptr));
}
template<>
......@@ -106,7 +106,7 @@ struct AttrError : public dmlc::Error {
/*!
* \brief Information about attribute fields in string representations.
*/
class AttrFieldInfoNode : public Node {
class AttrFieldInfoNode : public Object {
public:
/*! \brief name of the field */
std::string name;
......@@ -121,11 +121,14 @@ class AttrFieldInfoNode : public Node {
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
/*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
class AttrFieldInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
class AttrsHashHandler;
class AttrsEqualHandler;
......@@ -217,7 +220,7 @@ class AttrsHash {
* subclass AttrsNode instead.
* \sa AttrsNode
*/
class BaseAttrsNode : public Node {
class BaseAttrsNode : public Object {
public:
using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue;
......@@ -271,16 +274,16 @@ class BaseAttrsNode : public Node {
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
/*! \brief Base attribute container for all attributes */
class Attrs : public NodeRef {
class Attrs : public ObjectRef {
public:
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(NodePtr<Node> n) : NodeRef(n) {}
explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
......@@ -305,13 +308,13 @@ class Attrs : public NodeRef {
class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, NodeRef> dict;
Map<std::string, ObjectRef> dict;
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
......@@ -321,7 +324,7 @@ class DictAttrsNode : public BaseAttrsNode {
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};
......@@ -639,7 +642,7 @@ class AttrDocEntry {
public:
using TSelf = AttrDocEntry;
explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info)
explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info)
: info_(info) {
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
......@@ -663,15 +666,15 @@ class AttrDocEntry {
}
private:
NodePtr<AttrFieldInfoNode> info_;
ObjectPtr<AttrFieldInfoNode> info_;
};
class AttrDocVisitor {
public:
template<typename T>
AttrDocEntry operator()(const char* key, T* v) {
NodePtr<AttrFieldInfoNode> info
= make_node<AttrFieldInfoNode>();
ObjectPtr<AttrFieldInfoNode> info
= make_object<AttrFieldInfoNode>();
info->name = key;
info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info));
......
......@@ -48,10 +48,10 @@ enum BufferType : int {
* It is a composition of primitive symbolic types,
* used to specify the memory layout of the Tensor used in program input.
*/
class Buffer : public NodeRef {
class Buffer : public ObjectRef {
public:
Buffer() {}
explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Buffer(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
......@@ -101,7 +101,7 @@ class Buffer : public NodeRef {
};
/*! \brief Node to represent a buffer */
class BufferNode : public Node {
class BufferNode : public Object {
public:
// Data fields.
/*!
......@@ -169,7 +169,7 @@ class BufferNode : public Node {
BufferType buffer_type);
static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
};
inline const BufferNode* Buffer::operator->() const {
......
......@@ -39,7 +39,7 @@ namespace tvm {
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
class TargetNode : public Node {
class TargetNode : public Object {
public:
/*! \brief The name of the target device */
std::string target_name;
......@@ -82,7 +82,7 @@ class TargetNode : public Node {
TVM_DLL std::unordered_set<std::string> libs() const;
static constexpr const char* _type_key = "Target";
TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
private:
/*! \brief Internal string repr. */
......@@ -90,10 +90,10 @@ class TargetNode : public Node {
};
/*! \brief reference cpass to the target. */
class Target : public NodeRef {
class Target : public ObjectRef {
public:
Target() {}
explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
......@@ -178,7 +178,7 @@ TVM_DLL Target ext_dev(const std::vector<std::string>& options =
/*!
* \brief Container for build configuration options
*/
class BuildConfigNode : public Node {
class BuildConfigNode : public Object {
public:
/*!
* \brief The data alignment to use when constructing buffers. If this is set to
......@@ -254,16 +254,16 @@ class BuildConfigNode : public Node {
}
static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object);
};
/*!
* \brief Build configuration for compilations.
*/
class BuildConfig : public ::tvm::NodeRef {
class BuildConfig : public ::tvm::ObjectRef {
public:
BuildConfig() {}
explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
explicit BuildConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(get());
}
......@@ -375,10 +375,10 @@ class GenericFuncNode;
/*!
* \brief Generic function that can be specialized on a per-target basis.
*/
class GenericFunc : public NodeRef {
class GenericFunc : public ObjectRef {
public:
GenericFunc() {}
explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}
explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Set the default function implementaiton.
......@@ -471,7 +471,7 @@ inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
/*!
* \brief Represents a generic function that can be specialized on a per-target basis.
*/
class GenericFuncNode : public Node {
class GenericFuncNode : public Object {
public:
/*! \brief name of the function */
std::string name_;
......@@ -483,7 +483,7 @@ class GenericFuncNode : public Node {
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
};
inline GenericFuncNode* GenericFunc::operator->() {
......
......@@ -92,7 +92,7 @@ class LayoutAxis {
class Layout;
// Internal node container Buffer
class LayoutNode : public Node {
class LayoutNode : public Object {
public:
/*! \brief string representation of layout, "" for scalar. */
std::string name;
......@@ -112,7 +112,7 @@ class LayoutNode : public Node {
TVM_DLL static Layout make(const std::string& layout);
static constexpr const char* _type_key = "Layout";
TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
};
/*!
......@@ -125,9 +125,9 @@ class LayoutNode : public Node {
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
* Layout for scalar is defined, while both its name and axes have size 0.
*/
class Layout : public NodeRef {
class Layout : public ObjectRef {
public:
explicit Layout(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Layout(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \brief default constructor */
Layout() = default;
......@@ -311,7 +311,7 @@ class Layout : public NodeRef {
class BijectiveLayout;
// Internal node container BijectiveLayout
class BijectiveLayoutNode : public Node {
class BijectiveLayoutNode : public Object {
public:
/*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
......@@ -333,7 +333,7 @@ class BijectiveLayoutNode : public Node {
}
static constexpr const char* _type_key = "BijectiveLayout";
TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
TVM_DLL static BijectiveLayout make(const Layout& src_layout,
const Layout& dst_layout);
......@@ -344,10 +344,10 @@ class BijectiveLayoutNode : public Node {
* provides API to transform N-dimention tensor from the source indices (i0, i1, …, im)
* to the destination indices (j0, j1, … jm).
*/
class BijectiveLayout : public NodeRef {
class BijectiveLayout : public ObjectRef {
public:
BijectiveLayout() = default;
explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {}
explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
// Given the source shape, infer the destination shape.
TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
......
......@@ -38,20 +38,20 @@
namespace tvm {
/*! \brief Base node of all expressions. */
class ExprNode : public Node {
class ExprNode : public Object {
public:
/*! \brief The data type of the expression. */
DataType dtype;
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object);
};
/*! \brief Container of all expressions. */
class Expr : public NodeRef {
class Expr : public ObjectRef {
public:
Expr() {}
explicit Expr(ObjectPtr<Object> ptr) : NodeRef(ptr) {}
explicit Expr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
......@@ -78,16 +78,16 @@ class Expr : public NodeRef {
};
/*! \brief Base node of all statements. */
class StmtNode : public Node {
class StmtNode : public Object {
public:
static constexpr const char* _type_key = "Stmt";
TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};
/*! \brief Container of all statements */
class Stmt : public NodeRef {
class Stmt : public ObjectRef {
public:
TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode);
};
class Var;
......@@ -118,7 +118,7 @@ class Variable : public ExprNode {
}
static constexpr const char* _type_key = "Variable";
TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode);
};
/*! \brief a named variable in TVM */
......@@ -156,8 +156,8 @@ class Var : public Expr {
// Backward compatibility, will be removed later.
using VarExpr = Var;
using BaseExprNode = ExprNode;
using ExprHash = NodeHash;
using ExprEqual = NodeEqual;
using ExprHash = ObjectHash;
using ExprEqual = ObjectEqual;
class Integer;
/*! \brief ExprNode: constant integer. */
......@@ -174,7 +174,7 @@ class IntImm : public ExprNode {
TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode);
};
/*!
......@@ -222,7 +222,7 @@ class Integer : public Expr {
};
/*! \brief range over one dimension */
class RangeNode : public Node {
class RangeNode : public Object {
public:
/*! \brief beginning of the node */
Expr min;
......@@ -238,11 +238,11 @@ class RangeNode : public Node {
}
static constexpr const char* _type_key = "Range";
TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
/*! \brief Range constainer */
class Range : public NodeRef {
class Range : public ObjectRef {
public:
/*!
* \brief constructor by begin and end
......@@ -261,7 +261,7 @@ class Range : public NodeRef {
*/
static Range make_by_min_extent(Expr min, Expr extent);
// declare range.
TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode);
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
/*! \brief container class of iteration variable. */
......@@ -343,12 +343,12 @@ enum IterVarType : int {
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*/
class IterVar : public NodeRef {
class IterVar : public ObjectRef {
public:
// construct a new iter var without a domain
IterVar() {}
// construct from shared ptr.
explicit IterVar(ObjectPtr<Object> n) : NodeRef(n) {}
explicit IterVar(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -384,14 +384,14 @@ using Domain = Array<Range>;
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const NodeRef& node);
TVM_DLL void Dump(const ObjectRef& node);
// definition of Node.
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
*/
class IterVarNode : public Node {
class IterVarNode : public Object {
public:
/*!
* \brief the domain of iteration, if known, can be None
......@@ -420,7 +420,7 @@ class IterVarNode : public Node {
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
};
// inline implementations
......@@ -490,17 +490,22 @@ class IRPrinter {
using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm
// default print function for all nodes
namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::IterVar> : public ::tvm::NodeHash {
struct hash<::tvm::IterVar> : public ::tvm::ObjectHash {
};
}
#endif // TVM_EXPR_H_
......@@ -164,7 +164,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args ...) {
virtual R VisitExprDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
}
......@@ -255,7 +255,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) {
virtual R VisitStmtDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
}
......
......@@ -418,7 +418,7 @@ Stmt HoistIfThenElse(Stmt stmt);
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
Array<ObjectRef> api_args,
int num_unpacked_args,
bool is_restricted);
......
......@@ -87,7 +87,7 @@ class TVM_DLL IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual void Visit(const NodeRef& node) {
virtual void Visit(const ObjectRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
......@@ -152,7 +152,7 @@ class TVM_DLL IRVisitor {
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
} // namespace ir
} // namespace tvm
......
......@@ -131,7 +131,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
}
static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
};
// Implementations of inline functions
......@@ -143,7 +143,7 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const {
namespace std {
template <>
struct hash<::tvm::LoweredFunc> : public tvm::NodeHash {
struct hash<::tvm::LoweredFunc> : public tvm::ObjectHash {
};
}
......
......@@ -35,7 +35,7 @@
namespace tvm {
/*! \brief array node content in array */
class ArrayNode : public Node {
class ArrayNode : public Object {
public:
/*! \brief the data content */
std::vector<ObjectRef> data;
......@@ -44,11 +44,11 @@ class ArrayNode : public Node {
}
static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};
/*! \brief map node content */
class MapNode : public Node {
class MapNode : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) {
}
......@@ -63,12 +63,12 @@ class MapNode : public Node {
ContainerType data;
static constexpr const char* _type_key = "Map";
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
};
/*! \brief specialized map node with string as key */
class StrMapNode : public Node {
class StrMapNode : public Object {
public:
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;
......@@ -80,7 +80,7 @@ class StrMapNode : public Node {
ContainerType data;
static constexpr const char* _type_key = "StrMap";
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object);
};
/*!
......@@ -138,13 +138,13 @@ class IterAdapter {
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
class Array : public NodeRef {
class Array : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Array() {
data_ = make_node<ArrayNode>();
data_ = make_object<ArrayNode>();
}
/*!
* \brief move constructor
......@@ -164,7 +164,7 @@ class Array : public NodeRef {
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Array(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
......@@ -195,7 +195,7 @@ class Array : public NodeRef {
* \param val The init value
*/
explicit Array(size_t n, const T& val) {
auto tmp_node = make_node<ArrayNode>();
auto tmp_node = make_object<ArrayNode>();
for (size_t i = 0; i < n; ++i) {
tmp_node->data.push_back(val);
}
......@@ -227,7 +227,7 @@ class Array : public NodeRef {
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_node<ArrayNode>();
auto n = make_object<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back(T(*it));
}
......@@ -257,7 +257,7 @@ class Array : public NodeRef {
*/
inline ArrayNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
NodePtr<ArrayNode> n = make_node<ArrayNode>();
ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
n->data = static_cast<ArrayNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
......@@ -333,13 +333,13 @@ template<typename K,
std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public NodeRef {
class Map : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Map() {
data_ = make_node<MapNode>();
data_ = make_object<MapNode>();
}
/*!
* \brief move constructor
......@@ -352,13 +352,13 @@ class Map : public NodeRef {
* \brief copy constructor
* \param other source
*/
Map(const Map<K, V> &other) : NodeRef(other.data_) { // NOLINT(*)
Map(const Map<K, V> &other) : ObjectRef(other.data_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Map(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
......@@ -410,7 +410,7 @@ class Map : public NodeRef {
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
NodePtr<MapNode> n = make_node<MapNode>();
ObjectPtr<MapNode> n = make_object<MapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, i->second));
}
......@@ -454,7 +454,7 @@ class Map : public NodeRef {
*/
inline MapNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
NodePtr<MapNode> n = make_node<MapNode>();
ObjectPtr<MapNode> n = make_object<MapNode>();
n->data = static_cast<const MapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
......@@ -507,18 +507,18 @@ class Map : public NodeRef {
// specialize of string map
template<typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public NodeRef {
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
Map() {
data_ = make_node<StrMapNode>();
data_ = make_object<StrMapNode>();
}
Map(Map<std::string, V> && other) { // NOLINT(*)
data_ = std::move(other.data_);
}
Map(const Map<std::string, V> &other) : NodeRef(other.data_) { // NOLINT(*)
Map(const Map<std::string, V> &other) : ObjectRef(other.data_) { // NOLINT(*)
}
explicit Map(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
......@@ -541,7 +541,7 @@ class Map<std::string, V, T1, T2> : public NodeRef {
}
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_node<StrMapNode>();
auto n = make_object<StrMapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, i->second));
}
......@@ -565,7 +565,7 @@ class Map<std::string, V, T1, T2> : public NodeRef {
}
inline StrMapNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
NodePtr<StrMapNode> n = make_node<StrMapNode>();
ObjectPtr<StrMapNode> n = make_object<StrMapNode>();
n->data = static_cast<const StrMapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
......
......@@ -56,105 +56,5 @@ using runtime::ObjectHash;
using runtime::ObjectEqual;
using runtime::make_object;
using NodeHash = ObjectHash;
using NodeEqual = ObjectEqual;
using Node = Object;
/*!
* \brief Base class of all references to AST/IR nodes.
*/
class NodeRef : public ObjectRef {
public:
NodeRef() {}
explicit NodeRef(ObjectPtr<Object> n) : ObjectRef(n) {}
};
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
* \note This function is an alias of make_object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
return runtime::make_object<T>(std::forward<Args>(args)...);
}
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent)
/*!
* \brief helper macro to declare type information in a terminal node
*/
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent);
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
} // namespace tvm
#endif // TVM_NODE_NODE_H_
......@@ -60,7 +60,7 @@ class OperationNode : public ir::FunctionBaseNode {
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation*/
Map<std::string, NodeRef> attrs;
Map<std::string, ObjectRef> attrs;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
......@@ -149,7 +149,7 @@ class OperationNode : public ir::FunctionBaseNode {
static constexpr const char* _type_key = "Operation";
TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
};
/*!
......@@ -200,7 +200,7 @@ class PlaceholderOpNode : public OperationNode {
DataType dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};
/*!
......@@ -228,7 +228,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
virtual size_t num_schedulable_dims() const = 0;
static constexpr const char* _type_key = "BaseComputeOp";
TVM_DECLARE_BASE_NODE_INFO(BaseComputeOpNode, OperationNode);
TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
};
......@@ -269,12 +269,12 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
}
static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
Array<IterVar> axis,
Array<Expr> body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, BaseComputeOpNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};
/*!
......@@ -334,7 +334,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
Array<Expr> scalar_inputs);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};
/*!
......@@ -407,7 +407,7 @@ class ScanOpNode : public OperationNode {
}
static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
......@@ -415,7 +415,7 @@ class ScanOpNode : public OperationNode {
Array<Tensor> input);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};
/*!
......@@ -472,14 +472,14 @@ class ExternOpNode : public OperationNode {
}
TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body);
static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};
/*!
......@@ -540,13 +540,13 @@ class HybridOpNode : public OperationNode {
}
TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};
/*! \brief The compute function to specify the input source of a Tensor */
......@@ -578,7 +578,7 @@ TVM_DLL Tensor compute(Array<Expr> shape,
FCompute fcompute,
std::string name = "tensor",
std::string tag = "",
Map<std::string, NodeRef> attrs = {});
Map<std::string, ObjectRef> attrs = {});
/*!
* \brief Construct a new tensor by computing over shape,
......@@ -593,7 +593,7 @@ TVM_DLL Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute,
std::string name = "tensor",
std::string tag = "",
Map<std::string, NodeRef> attrs = {});
Map<std::string, ObjectRef> attrs = {});
/*!
* \brief Construct new tensors by scan.
......@@ -613,14 +613,14 @@ TVM_DLL Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan",
std::string tag = "",
Map<std::string, NodeRef> attrs = {});
Map<std::string, ObjectRef> attrs = {});
// same as compute, specialized for different fcompute function
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return compute(shape, fc, name, tag, attrs);
}
......@@ -628,7 +628,7 @@ inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return compute(shape, fc, name, tag, attrs);
}
......@@ -636,7 +636,7 @@ inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return compute(shape, fc, name, tag, attrs);
}
......@@ -644,7 +644,7 @@ inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, NodeRef> attrs = {}) {
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return compute(shape, fc, name, tag, attrs);
}
......
......@@ -115,15 +115,15 @@ inline TVMPODValue_::operator tvm::Expr() const {
Object* ptr = static_cast<Object*>(value_.v_handle);
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Node>(ptr))->var;
return IterVar(ObjectPtr<Object>(ptr))->var;
}
if (ptr->IsInstance<TensorNode>()) {
return Tensor(ObjectPtr<Node>(ptr))();
return Tensor(ObjectPtr<Object>(ptr))();
}
CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr));
return Expr(ObjectPtr<Object>(ptr));
}
inline TVMPODValue_::operator tvm::Integer() const {
......@@ -138,7 +138,7 @@ inline TVMPODValue_::operator tvm::Integer() const {
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr));
return Integer(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
......
......@@ -38,7 +38,7 @@ namespace relay {
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};
/*!
......@@ -49,10 +49,10 @@ class PatternNode : public RelayNode {
*
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
*/
class Pattern : public NodeRef {
class Pattern : public ObjectRef {
public:
Pattern() {}
explicit Pattern(ObjectPtr<tvm::Object> p) : NodeRef(p) {}
explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
using ContainerType = PatternNode;
};
......@@ -71,10 +71,13 @@ class PatternWildcardNode : public PatternNode {
}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);
class PatternWildcard : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode);
};
/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
......@@ -94,10 +97,13 @@ class PatternVarNode : public PatternNode {
}
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);
class PatternVar : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};
/*!
* \brief ADT constructor.
......@@ -132,10 +138,13 @@ class ConstructorNode : public ExprNode {
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);
class Constructor : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode);
};
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
......@@ -158,10 +167,13 @@ class PatternConstructorNode : public PatternNode {
}
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
class PatternConstructor : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
};
/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;
......@@ -181,10 +193,13 @@ class PatternTupleNode : public PatternNode {
}
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);
class PatternTuple : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
......@@ -225,15 +240,18 @@ class TypeDataNode : public TypeNode {
tvm::Array<Constructor> constructors);
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);
class TypeData : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
class ClauseNode : public Object {
public:
/*! \brief The pattern the clause matches. */
Pattern lhs;
......@@ -248,10 +266,13 @@ class ClauseNode : public Node {
TVM_DLL static Clause make(Pattern lhs, Expr rhs);
static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);
class Clause : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
};
/*! \brief ADT pattern matching exression. */
class Match;
......@@ -280,10 +301,13 @@ class MatchNode : public ExprNode {
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);
class Match : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode);
};
} // namespace relay
} // namespace tvm
......
......@@ -196,7 +196,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}; // struct SqueezeAttrs
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
NodeRef indices_or_sections;
ObjectRef indices_or_sections;
int axis;
TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
......
......@@ -54,53 +54,11 @@ namespace relay {
}
/*!
* \brief We always used NodeRef for referencing nodes.
*
* By default, NodeRef is a std::shared_ptr of node
*/
using NodeRef = tvm::NodeRef;
/*!
* \brief Content data type.
*/
using DataType = ::tvm::DataType;
/*!
* \brief Symbolic expression for tensor shape.
*/
using IndexExpr = ::tvm::Expr;
/*!
* \brief Hash function for nodes.
* e.g. std::unordered_map<Expr, Value, NodeHash, NodeEqual>
*/
using NodeHash = ::tvm::NodeHash;
/*!
* \brief Equality check function for nodes.
*/
using NodeEqual = ::tvm::NodeEqual;
/*!
* \brief Macro to make it easy to define node ref type given node
* \param TypeName The name of the reference type.
* \param NodeName The internal container name.
* \param NodeRefBase The base type.
*/
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \
public: \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: NodeRefBase(n) { \
} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(get()); \
} \
operator bool() { return this->defined(); } \
using ContainerType = NodeName; \
};
/*!
* \brief The source name in the Span
* \sa SourceNameNode, Span
*/
......@@ -108,7 +66,7 @@ class SourceName;
/*!
* \brief The name of a source fragment.
*/
class SourceNameNode : public Node {
class SourceNameNode : public Object {
public:
/*! \brief The source name. */
std::string name;
......@@ -116,20 +74,20 @@ class SourceNameNode : public Node {
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
/*!
* \brief The source name of a file span.
* \sa SourceNameNode, Span
*/
class SourceName : public NodeRef {
class SourceName : public ObjectRef {
public:
/*! \brief default constructor */
SourceName() {}
/*! \brief constructor from node pointer */
explicit SourceName(NodePtr<Node> n) : NodeRef(n) {}
explicit SourceName(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -157,7 +115,7 @@ class Span;
/*!
* \brief Stores locations in frontend source that generated a node.
*/
class SpanNode : public Node {
class SpanNode : public Object {
public:
/*! \brief The source name */
SourceName source;
......@@ -175,22 +133,25 @@ class SpanNode : public Node {
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span";
TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef);
class Span : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
/*!
* \brief This is the base node container of all relay structures.
*/
class RelayNode : public Node {
class RelayNode : public Object {
public:
/*! \brief The location of the program in a SourceFragment can be null,
* check with span.defined() */
mutable Span span;
static constexpr const char* _type_key = "relay.Node";
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(RelayNode, Object);
};
/*!
......@@ -201,7 +162,7 @@ class RelayNode : public Node {
*
* \note Do not create Id directly, they are created in Var.
*/
class IdNode : public Node {
class IdNode : public Object {
public:
/*!
* \brief The name of the variable,
......@@ -215,10 +176,13 @@ class IdNode : public Node {
}
static constexpr const char* _type_key = "relay.Id";
TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
};
RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef);
class Id : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};
struct Module;
......
......@@ -118,7 +118,7 @@ class ErrorReporter {
* \param node The expression or type to report the error at.
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
inline void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) {
std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg));
}
......@@ -134,7 +134,7 @@ class ErrorReporter {
* \param node The expression or type to report the error at.
* \param err The error to report.
*/
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);
void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err);
/*! \brief Render all reported errors and exit the program.
*
......@@ -154,8 +154,8 @@ class ErrorReporter {
private:
std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
std::unordered_map<ObjectRef, std::vector<size_t>, ObjectHash, ObjectEqual> node_to_error_;
std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_;
};
} // namespace relay
......
......@@ -67,10 +67,13 @@ class ExprNode : public RelayNode {
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode);
};
RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);
class Expr : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode);
};
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
......@@ -104,10 +107,13 @@ class ConstantNode : public ExprNode {
TVM_DLL static Constant make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);
class Constant : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode);
};
/*! \brief Tuple of multiple Exprs */
class Tuple;
......@@ -126,10 +132,13 @@ class TupleNode : public ExprNode {
TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
class Tuple : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
};
/*!
* \brief Local variables used in the let expression.
......@@ -179,10 +188,13 @@ class VarNode : public ExprNode {
Type type_annotation);
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
class Var : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
};
/*!
* \brief Global variable that leaves in the top-level module.
......@@ -206,10 +218,13 @@ class GlobalVarNode : public ExprNode {
TVM_DLL static GlobalVar make(std::string name_hint);
static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
class GlobalVar : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode);
};
/*!
* \brief Function (subgraph in computational graph)
......@@ -297,14 +312,19 @@ class FunctionNode : public ExprNode {
tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
class Function : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
};
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func,
const std::string& key,
const ObjectRef& data);
/*!
* \brief Call corresponds to operator invocation.
......@@ -363,10 +383,13 @@ class CallNode : public ExprNode {
Array<Type> type_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Call, CallNode, Expr);
class Call : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode);
};
/*!
* \brief Let binding that binds a local var and optionally a type annotation.
......@@ -401,10 +424,13 @@ class LetNode : public ExprNode {
TVM_DLL static Let make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
class Let : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode);
};
/*!
* \brief Condition expression
......@@ -439,10 +465,13 @@ class IfNode : public ExprNode {
TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
class If : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
};
/*! \brief Get index-th field out of a tuple. */
class TupleGetItem;
......@@ -463,10 +492,13 @@ class TupleGetItemNode : public ExprNode {
TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
class TupleGetItem : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode);
};
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
......@@ -484,10 +516,13 @@ class RefCreateNode : public ExprNode {
TVM_DLL static RefCreate make(Expr value);
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);
class RefCreate : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode);
};
/*! \brief Get value out of Reference. */
class RefRead;
......@@ -505,10 +540,13 @@ class RefReadNode : public ExprNode {
TVM_DLL static RefRead make(Expr ref);
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);
class RefRead : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode);
};
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
......@@ -528,10 +566,13 @@ class RefWriteNode : public ExprNode {
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
class RefWrite : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode);
};
/*!
* \brief Base class of the temporary expression.
......@@ -554,10 +595,13 @@ class TempExprNode : public ExprNode {
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);
class TempExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode);
};
// implementataions
inline const Type& ExprNode::checked_type() const {
......@@ -583,7 +627,7 @@ inline const TTypeNode* ExprNode::type_as() const {
}
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node);
std::string PrettyPrint(const ObjectRef& node);
/*!
* \brief Render the node as a string in the Relay text format.
......@@ -593,7 +637,7 @@ std::string PrettyPrint(const NodeRef& node);
* additional comment block to an expr.
* \return The text representation.
*/
std::string AsText(const NodeRef& node,
std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
......
......@@ -116,7 +116,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw;
}
......@@ -177,7 +177,7 @@ class ExprVisitor
protected:
// Internal visiting counter
std::unordered_map<const Node*, size_t> visit_counter_;
std::unordered_map<const Object*, size_t> visit_counter_;
};
/*!
......@@ -227,7 +227,7 @@ class ExprMutator
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
};
/*!
......
......@@ -72,13 +72,13 @@ CreateInterpreter(Module mod, DLContext context, Target target);
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
};
class Value : public NodeRef {
class Value : public ObjectRef {
public:
Value() {}
explicit Value(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get());
}
......@@ -114,10 +114,13 @@ class ClosureNode : public ValueNode {
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
class Closure : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
};
/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
......@@ -140,10 +143,13 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);
class RecClosure : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
};
/*! \brief A tuple value. */
class TupleValue;
......@@ -159,10 +165,13 @@ struct TupleValueNode : ValueNode {
TVM_DLL static TupleValue make(tvm::Array<Value> value);
static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);
class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};
/*! \brief A tensor value. */
class TensorValue;
......@@ -179,10 +188,13 @@ struct TensorValueNode : ValueNode {
TVM_DLL static TensorValue make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
class TensorValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
};
/*! \brief A reference value. */
class RefValue;
......@@ -199,10 +211,13 @@ struct RefValueNode : ValueNode {
TVM_DLL static RefValue make(Value val);
static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
class RefValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
};
/*! \brief An ADT constructor value. */
class ConstructorValue;
......@@ -226,10 +241,13 @@ struct ConstructorValueNode : ValueNode {
Constructor construtor = {});
static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value);
class ConstructorValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
};
} // namespace relay
} // namespace tvm
......
......@@ -258,7 +258,7 @@ class ModuleNode : public RelayNode {
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
private:
/*! \brief Helper function for registering a typedef's constructors */
......@@ -285,9 +285,9 @@ class ModuleNode : public RelayNode {
std::unordered_set<std::string> import_set_;
};
struct Module : public NodeRef {
struct Module : public ObjectRef {
Module() {}
explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {}
explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {}
ModuleNode* operator->() const {
return static_cast<ModuleNode*>(get_mutable());
......
......@@ -106,7 +106,7 @@ class OpNode : public relay::ExprNode {
}
static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode);
private:
// friend class
......@@ -431,7 +431,7 @@ inline OpRegistry& OpRegistry::describe(
inline OpRegistry& OpRegistry::add_argument(const std::string& name,
const std::string& type,
const std::string& description) {
auto n = make_node<AttrFieldInfoNode>();
auto n = make_object<AttrFieldInfoNode>();
n->name = name;
n->type_info = type;
n->description = description;
......
......@@ -180,7 +180,7 @@ using FTVMLegalize = runtime::TypedPackedFunc<
using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx)>;
const ObjectRef& ctx)>;
/*!
* \brief Gradient for a specific op.
......
......@@ -102,7 +102,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
virtual R VisitPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw;
}
......@@ -162,7 +162,7 @@ class PatternMutator
/*! \brief Used to visit the vars inside of patterns. */
virtual Constructor VisitConstructor(const Constructor& c);
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> var_map_;
std::unordered_map<Var, Var, ObjectHash, ObjectEqual> var_map_;
};
} // namespace relay
......
......@@ -109,7 +109,7 @@ class PassContextNode : public RelayNode {
}
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, RelayNode);
};
/*!
......@@ -125,10 +125,10 @@ class PassContextNode : public RelayNode {
*
* \endcode
*/
class PassContext : public NodeRef {
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
......@@ -207,10 +207,13 @@ class PassInfoNode : public RelayNode {
tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
class PassInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
class Pass;
......@@ -251,10 +254,10 @@ class PassNode : public RelayNode {
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, RelayNode);
};
class Pass : public NodeRef {
class Pass : public ObjectRef {
public:
/*!
* \brief Transform mod using the default PassContext in the current scope.
......@@ -283,7 +286,7 @@ class Pass : public NodeRef {
return node->operator()(mod, pass_ctx);
}
TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode);
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
};
class SequentialNode;
......@@ -309,7 +312,7 @@ class Sequential : public Pass {
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = Sequential;
......@@ -638,7 +641,7 @@ TVM_DLL Function InferType(const Function& f,
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
......@@ -655,7 +658,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr,
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
......
......@@ -41,7 +41,7 @@ using Any = tvm::ir::Any;
class TypeNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Type";
TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
/*!
......@@ -55,10 +55,10 @@ class TypeNode : public RelayNode {
* There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base.
*/
class Type : public NodeRef {
class Type : public ObjectRef {
public:
Type() {}
explicit Type(ObjectPtr<tvm::Object> p) : NodeRef(p) {}
explicit Type(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
using ContainerType = TypeNode;
};
......@@ -70,10 +70,13 @@ class Type : public NodeRef {
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);
class BaseTensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};
/*!
* \brief This is the most commonly used type in relay.
......@@ -113,10 +116,13 @@ class TensorTypeNode : public BaseTensorTypeNode {
TVM_DLL static TensorType Scalar(DataType dtype);
static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
class TensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};
/*! \brief Possible kinds of Type. */
enum Kind : int {
......@@ -168,10 +174,13 @@ class TypeVarNode : public TypeNode {
TVM_DLL static TypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
class TypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
/*!
* \brief A global type variable that is used for defining new types or type aliases.
......@@ -197,10 +206,13 @@ class GlobalTypeVarNode : public TypeNode {
TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type);
class GlobalTypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
/*!
* \brief Type application.
......@@ -225,10 +237,13 @@ class TypeCallNode : public TypeNode {
TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type);
class TypeCall : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
};
/*!
* \brief IncompleteType.
......@@ -253,10 +268,13 @@ class IncompleteTypeNode : public TypeNode {
TVM_DLL static IncompleteType make(Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
class IncompleteType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
/*!
* \brief Potential Constraints in the type.
......@@ -267,10 +285,13 @@ class TypeConstraint;
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode);
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type);
class TypeConstraint : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
class FuncType;
/*!
......@@ -311,10 +332,13 @@ class FuncTypeNode : public TypeNode {
tvm::Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
class FuncType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
/*!
* \brief The type of tuple values.
......@@ -338,10 +362,13 @@ class TupleTypeNode : public TypeNode {
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
class TupleType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*!
* \brief The type of reference values.
......@@ -365,10 +392,13 @@ class RefTypeNode : public TypeNode {
TVM_DLL static RefType make(Type value);
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);
class RefType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
};
class TypeReporter;
......@@ -376,7 +406,7 @@ class TypeReporter;
* \brief reporter that reports back to the
* type resolution information.
*/
class TypeReporterNode : public Node {
class TypeReporterNode : public Object {
public:
/*!
* \brief Create a type equality constraint.
......@@ -408,7 +438,7 @@ class TypeReporterNode : public Node {
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0;
/*!
* \brief Retrieve the current global module.
......@@ -420,17 +450,17 @@ class TypeReporterNode : public Node {
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};
/*!
* \brief Container class of TypeReporter.
* \sa TypeReporterNode
*/
class TypeReporter : public NodeRef {
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) {
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
}
TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>(
......@@ -502,10 +532,13 @@ class TypeRelationNode : public TypeConstraintNode {
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
class TypeRelation : public TypeConstraint {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
};
// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
......
......@@ -700,7 +700,12 @@ struct ObjectEqual {
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex()
/*
* \brief Define object reference methods.
* \param TypeName The object type name
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
......@@ -712,17 +717,54 @@ struct ObjectEqual {
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
/*
* \brief Define object reference methods of whose content is mutable.
* \param TypeName The object type name
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
* \note We recommend making objects immutable when possible.
* This macro is only reserved for objects that stores runtime states.
*/
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
ObjectName* operator->() const { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
/*!
* \brief Define CopyOnWrite function in an ObjectRef.
* \param ObjectName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWObjectRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
ObjectName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
auto n = make_object<ObjectName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<ObjectName*>(data_.get()); \
}
// Implementations details below
// Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER
......@@ -832,10 +874,6 @@ inline SubRef Downcast(BaseRef ref) {
}
} // namespace runtime
template<typename T>
using NodePtr = runtime::ObjectPtr<T>;
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_
......@@ -53,10 +53,10 @@ enum AttachType : int {
};
/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
class Stage : public ObjectRef {
public:
Stage() {}
explicit Stage(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
......@@ -277,10 +277,10 @@ class Stage : public NodeRef {
* For operations and all the operations they depend on.
* The schedule per Operation is named as stage.
*/
class Schedule : public NodeRef {
class Schedule : public ObjectRef {
public:
Schedule() {}
explicit Schedule(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
......@@ -400,10 +400,10 @@ class Schedule : public NodeRef {
* \brief The schedule relation between IterVars
* can be Split, Fuse.
*/
class IterVarRelation : public NodeRef {
class IterVarRelation : public ObjectRef {
public:
IterVarRelation() {}
explicit IterVarRelation(ObjectPtr<Object> n) : NodeRef(n) {}
explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -414,10 +414,10 @@ class IterVarRelation : public NodeRef {
/*!
* \brief Additional scheduable attributes about IterVar.
*/
class IterVarAttr : public NodeRef {
class IterVarAttr : public ObjectRef {
public:
IterVarAttr() {}
explicit IterVarAttr(ObjectPtr<Object> n) : NodeRef(n) {}
explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -440,7 +440,7 @@ class IterVarAttr : public NodeRef {
*
* The group stage node can be attached to IterVars as in normal stage.
*/
class StageNode : public Node {
class StageNode : public Object {
public:
/*!
* \brief The operation of stage, can be different from original op.
......@@ -515,11 +515,11 @@ class StageNode : public Node {
}
static constexpr const char* _type_key = "Stage";
TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
};
/*! \brief node container for schedule */
class ScheduleNode : public Node {
class ScheduleNode : public Object {
public:
/*! \brief The output operations in original data flow graph */
Array<Operation> outputs;
......@@ -538,7 +538,7 @@ class ScheduleNode : public Node {
* \brief Internal stage map to map internal ops to stages.
* This is created on demand and can be invalidated.
*/
std::unordered_map<const Node*, Stage> op2stage_cache_;
std::unordered_map<const Object*, Stage> op2stage_cache_;
void VisitAttrs(AttrVisitor* v) {
v->Visit("outputs", &outputs);
......@@ -576,7 +576,7 @@ class ScheduleNode : public Node {
TVM_DLL static Schedule make(Array<Operation> ops);
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
/*!
......@@ -589,7 +589,7 @@ inline Schedule create_schedule(Array<Operation> ops) {
}
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
class IterVarAttrNode : public Object {
public:
/*! \brief The iteration type. */
IterVarType iter_type{kDataPar};
......@@ -630,14 +630,14 @@ class IterVarAttrNode : public Node {
}
static constexpr const char* _type_key = "IterVarAttr";
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
};
/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
class IterVarRelationNode : public Object {
public:
static constexpr const char* _type_key = "IterVarRelation";
TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
};
/*!
......@@ -672,7 +672,7 @@ class SplitNode : public IterVarRelationNode {
Expr nparts);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
};
/*!
......@@ -697,7 +697,7 @@ class FuseNode : public IterVarRelationNode {
IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
};
/*!
......@@ -720,7 +720,7 @@ class RebaseNode : public IterVarRelationNode {
static IterVarRelation make(IterVar parent, IterVar rebased);
static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};
......@@ -739,7 +739,7 @@ class SingletonNode : public IterVarRelationNode {
static IterVarRelation make(IterVar iter);
static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};
......
......@@ -34,7 +34,7 @@ namespace tvm {
* \brief Memory information of special memory region.
* Use MemoryInfo as its container type
*/
struct MemoryInfoNode : public Node {
struct MemoryInfoNode : public Object {
/*! \brief The addressable unit */
int unit_bits;
/*! \brief Maximum number of bits supported in the memory */
......@@ -55,11 +55,14 @@ struct MemoryInfoNode : public Node {
}
static constexpr const char* _type_key = "MemoryInfo";
TVM_DECLARE_NODE_TYPE_INFO(MemoryInfoNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object);
};
/*! \brief Defines memory info */
TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode);
class MemoryInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(MemoryInfo, ObjectRef, MemoryInfoNode);
};
/*!
* \brief get memory info given scope
......
......@@ -46,11 +46,11 @@ class OperationNode;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public NodeRef {
class Tensor : public ObjectRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -158,7 +158,7 @@ class Operation : public ir::FunctionRef {
};
/*! \brief Node to represent a tensor */
class TensorNode : public Node {
class TensorNode : public Object {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
......@@ -183,7 +183,7 @@ class TensorNode : public Node {
int value_index);
static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
};
......@@ -250,13 +250,13 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
namespace std {
template <>
struct hash<::tvm::Operation> : public ::tvm::NodeHash {
struct hash<::tvm::Operation> : public ::tvm::ObjectHash {
};
template <>
struct hash<::tvm::Tensor> {
std::size_t operator()(const ::tvm::Tensor& k) const {
::tvm::NodeHash hasher;
::tvm::ObjectHash hasher;
if (k.defined() && k->op.defined()) {
return hasher(k->op);
} else{
......
......@@ -34,10 +34,10 @@ namespace tvm {
class TensorIntrinNode;
/*! \brief Tensor intrinsic node. */
class TensorIntrin : public NodeRef {
class TensorIntrin : public ObjectRef {
public:
TensorIntrin() {}
explicit TensorIntrin(NodePtr<Node> n) : NodeRef(n) {}
explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -49,7 +49,7 @@ class TensorIntrin : public NodeRef {
};
/*! \brief Node to represent a Tensor intrinsic operator */
class TensorIntrinNode : public Node {
class TensorIntrinNode : public Object {
public:
/*! \brief The name of the intrinsic */
std::string name;
......@@ -108,7 +108,7 @@ class TensorIntrinNode : public Node {
Stmt reduce_update);
static constexpr const char* _type_key = "TensorIntrin";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
};
inline const TensorIntrinNode* TensorIntrin::operator->() const {
......@@ -119,10 +119,10 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const {
class TensorIntrinCallNode;
/*! \brief Tensor intrinsic calling node. */
class TensorIntrinCall : public NodeRef {
class TensorIntrinCall : public ObjectRef {
public:
TensorIntrinCall() {}
explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -133,7 +133,7 @@ class TensorIntrinCall : public NodeRef {
using ContainerType = TensorIntrinCallNode;
};
class TensorIntrinCallNode : public Node {
class TensorIntrinCallNode : public Object {
public:
/*! \brief the tensor intrinsic */
TensorIntrin intrin;
......@@ -166,7 +166,7 @@ class TensorIntrinCallNode : public Node {
Array<Expr> scalar_inputs);
static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
};
inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -306,7 +306,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const NodePtr* GNode;
typedef const ObjectPtr* GNode;
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -40,22 +40,22 @@ class Node;
class Symbol;
/*!
* \brief we always used NodePtr for a reference pointer
* \brief we always used ObjectPtr for a reference pointer
* to the node, so this alias can be changed in case.
*
* By default, NodePtr is a std::shared_ptr of node
* By default, ObjectPtr is a std::shared_ptr of node
*/
using NodePtr = std::shared_ptr<Node>;
using ObjectPtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */
struct NodeEntry {
NodeEntry(NodePtr node, uint32_t index, uint32_t version):
NodeEntry(ObjectPtr node, uint32_t index, uint32_t version):
node(std::move(node)),
index(index),
version(version)
{}
explicit NodeEntry(NodePtr node):
explicit NodeEntry(ObjectPtr node):
node(std::move(node)),
index(),
version()
......@@ -72,7 +72,7 @@ struct NodeEntry {
{}
/*! \brief the source node of this data */
NodePtr node;
ObjectPtr node;
/*! \brief index of output from the source. */
uint32_t index;
/*!
......@@ -167,7 +167,7 @@ class NNVM_DLL Node {
* \brief Optional control flow dependencies
* Gives operation must be performed before this operation.
*/
std::vector<NodePtr> control_deps;
std::vector<ObjectPtr> control_deps;
/*! \brief additional fields for this node */
any info;
/*! \brief destructor of node */
......@@ -189,7 +189,7 @@ class NNVM_DLL Node {
* \return a created empty node.
*/
template<class ...Args>
static NodePtr Create(Args&&... args) {
static ObjectPtr Create(Args&&... args) {
return std::make_shared<Node>(std::forward<Args>(args)...);
}
};
......@@ -208,7 +208,7 @@ inline NodeEntry MakeNode(
std::vector<NodeEntry> inputs,
std::unordered_map<std::string, std::string> attrs =
std::unordered_map<std::string, std::string>()) {
NodePtr p = Node::Create();
ObjectPtr p = Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name);
p->attrs.dict = attrs;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -192,7 +192,7 @@ using FIgnoreInputs = std::function<
* \note Register under "FGradient"
*/
using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const ObjectPtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;
/*!
......@@ -204,7 +204,7 @@ using FGradient = std::function<std::vector<NodeEntry>(
*/
using FSetInputVarAttrOnCompose = std::function<void(
const NodeAttrs& attrs,
NodePtr var,
ObjectPtr var,
const int index)>;
/*!
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -97,7 +97,7 @@ class NNVM_DLL Symbol {
* \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
*/
std::vector<NodePtr> ListInputs(ListInputOption option) const;
std::vector<ObjectPtr> ListInputs(ListInputOption option) const;
/*!
* \brief List the input names.
*
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -259,7 +259,7 @@ int NNSymbolListInputVariables(SymbolHandle symbol,
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
std::vector<ObjectPtr> vs = s->ListInputs(Symbol::ListInputOption(option));
ret->ret_handles.resize(0);
ret->ret_handles.reserve(vs.size());
for (size_t i = 0; i < vs.size(); ++i) {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -50,7 +50,7 @@ static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subg
next_level.clear();
for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
const std::vector<NodeEntry> &graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) {
nnvm::Node *node = n.get();
// if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error
......@@ -74,7 +74,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
std::vector<std::shared_ptr<Symbol>> subgraphs;
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) {
(const ObjectPtr& n) {
const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -30,7 +30,7 @@ Node::~Node() {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<NodePtr> to_delete;
std::vector<ObjectPtr> to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
......@@ -42,7 +42,7 @@ Node::~Node() {
e.node.reset();
}
}
for (NodePtr& sp : n->control_deps) {
for (ObjectPtr& sp : n->control_deps) {
if (sp.unique()) {
stack.push_back(sp.get());
to_delete.emplace_back(std::move(sp));
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -36,8 +36,8 @@ struct VariableParam {
uint32_t version{0};
};
NodePtr CreateVariableNode(const std::string& name) {
NodePtr n = Node::Create();
ObjectPtr CreateVariableNode(const std::string& name) {
ObjectPtr n = Node::Create();
n->attrs.op = nullptr;
n->attrs.name = name;
n->attrs.parsed = VariableParam();
......@@ -114,10 +114,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
// public functions
Symbol Symbol::Copy() const {
std::unordered_map<Node*, NodePtr> old_new;
std::unordered_map<Node*, ObjectPtr> old_new;
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
NodePtr np = Node::Create();
DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) {
ObjectPtr np = Node::Create();
np->attrs = node->attrs;
old_new[node.get()] = std::move(np);
});
......@@ -127,7 +127,7 @@ Symbol Symbol::Copy() const {
Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
}
for (const NodePtr& p : kv.first->control_deps) {
for (const ObjectPtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(old_new[p.get()]);
}
}
......@@ -155,7 +155,7 @@ void Symbol::Print(std::ostream &os) const {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n";
}
DFSVisit(this->outputs, [&os](const NodePtr& node) {
DFSVisit(this->outputs, [&os](const ObjectPtr& node) {
if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n';
} else {
......@@ -204,21 +204,21 @@ Symbol Symbol::operator[] (size_t index) const {
}
}
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
std::vector<NodePtr> ret;
std::vector<ObjectPtr> Symbol::ListInputs(ListInputOption option) const {
std::vector<ObjectPtr> ret;
if (option == kAll) {
ret.reserve(this->outputs.size());
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
DFSVisit(this->outputs, [&ret](const ObjectPtr &node) {
if (node->is_variable()) {
ret.push_back(node);
}
});
} else {
std::unordered_set<Node*> mutable_set;
std::vector<NodePtr> vlist;
std::vector<ObjectPtr> vlist;
vlist.reserve(this->outputs.size());
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
DFSVisit(this->outputs, [&mutable_set, &vlist](const NodePtr &node) {
DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) {
if (node->is_variable()) {
vlist.push_back(node);
} else if (fmutate_inputs.count(node->op())) {
......@@ -228,7 +228,7 @@ std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
}
});
ret.reserve(vlist.size());
for (const NodePtr& node : vlist) {
for (const ObjectPtr& node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
(option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
ret.emplace_back(node);
......@@ -239,7 +239,7 @@ std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
}
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<NodePtr> inputs = ListInputs(option);
std::vector<ObjectPtr> inputs = ListInputs(option);
std::vector<std::string> ret(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
ret[i] = inputs[i]->attrs.name;
......@@ -416,7 +416,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::unordered_map<Node *, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
(const NodePtr &node) {
(const ObjectPtr &node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
......@@ -437,7 +437,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const NodePtr &node) {
(const ObjectPtr &node) {
// visit all the childs, find possible replacement
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) {
......@@ -499,7 +499,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
Symbol Symbol::GetInternals() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
DFSVisit(this->outputs, [&ret](const NodePtr& node) {
DFSVisit(this->outputs, [&ret](const ObjectPtr& node) {
Node* n = node.get();
if (n->is_variable()) {
// grab version from variable.
......@@ -582,7 +582,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const {
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const NodePtr& n) {
DFSVisit(this->outputs, [&ret](const ObjectPtr& n) {
for (const auto& it : n->attrs.dict) {
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
}
......@@ -596,7 +596,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
std::vector<std::tuple<std::string, std::string, std::string> >
Symbol::ListAttrsRecursive() const {
std::vector<std::tuple<std::string, std::string, std::string> > ret;
DFSVisit(this->outputs, [&ret](const NodePtr& n) {
DFSVisit(this->outputs, [&ret](const ObjectPtr& n) {
for (const auto& it : n->attrs.dict) {
ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second));
}
......@@ -608,7 +608,7 @@ Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s;
NodePtr n = Node::Create();
ObjectPtr n = Node::Create();
n->attrs.op = op;
n->attrs.dict = std::move(attrs);
if (n->op()->attr_parser != nullptr) {
......@@ -628,7 +628,7 @@ Symbol Symbol::CreateFunctor(const Op* op,
Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s;
NodePtr n = Node::Create();
ObjectPtr n = Node::Create();
n->attrs = attrs;
uint32_t nout = n->num_outputs();
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -30,11 +30,11 @@
namespace nnvm {
namespace pass {
nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src,
const Layout& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
nnvm::ObjectPtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++);
n->attrs.dict["src_layout"] = src.name();
......@@ -54,14 +54,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph();
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
std::vector<nnvm::ObjectPtr> mirror_vec(idx.num_nodes(), nullptr);
// (new) NodePtr -> output_layouts
// (new) ObjectPtr -> output_layouts
LayoutAttrDict new_layouts;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create();
nnvm::ObjectPtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
// Variable node. No operator. Only one output entry.
......@@ -85,7 +85,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef());
for (size_t i = 0; i < num_inputs; ++i) {
const IndexedGraph::NodeEntry& input_entry = inode.inputs[i];
const NodePtr& new_input_node = mirror_vec[input_entry.node_id];
const ObjectPtr& new_input_node = mirror_vec[input_entry.node_id];
CHECK(new_input_node != nullptr);
// fill inputs by previous node (DFS order) inferred layouts.
......@@ -122,14 +122,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
const nnvm::ObjectPtr& in = mirror_vec[e.node_id];
new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version};
// insert layout_transform if necessary
const Layout& produce = produce_ilayouts[i];
const Layout& request = request_ilayouts[i];
if (produce != request && produce.defined()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
nnvm::ObjectPtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
tnode->inputs.emplace_back(new_node->inputs[i]);
nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0);
......
......@@ -37,13 +37,13 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
if (v.size() == 1) {
return std::move(v[0]);
} else if (v.size() == 0) {
NodePtr zero_node = Node::Create();
ObjectPtr zero_node = Node::Create();
zero_node->attrs.op = Op::Get("zeros");
zero_node->attrs.name = "zero_grad";
zero_node->attrs.op->attr_parser(&(zero_node->attrs));
return NodeEntry{zero_node, 0, 0};
} else {
NodePtr sum_node = Node::Create();
ObjectPtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("elemwise_sum");
sum_node->inputs = std::move(v);
sum_node->attrs.name = "grad_sum";
......@@ -119,10 +119,10 @@ Graph Gradient(Graph src) {
nullptr;
// topo sort
std::vector<NodePtr> topo_order;
std::vector<ObjectPtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const NodePtr& node) {
DFSVisit(ys, [&](const ObjectPtr& node) {
if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs());
}
......@@ -143,11 +143,11 @@ Graph Gradient(Graph src) {
}
// construct mirror as memory reduction strategy if needed
std::unordered_map<Node*, NodePtr> mirror_map;
std::unordered_map<Node*, ObjectPtr> mirror_map;
if (mirror_fun != nullptr) {
for (const NodePtr& node_ptr : topo_order) {
for (const ObjectPtr& node_ptr : topo_order) {
if (mirror_fun(*node_ptr)) {
NodePtr new_node = Node::Create();
ObjectPtr new_node = Node::Create();
*new_node = *node_ptr;
new_node->attrs.name += "_mirror";
for (auto& e : new_node->inputs) {
......@@ -169,7 +169,7 @@ Graph Gradient(Graph src) {
std::vector<NodeEntry> out_agg_grads;
for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
const NodePtr& ptr = *rit;
const ObjectPtr& ptr = *rit;
if (ptr->is_variable()) continue;
out_agg_grads.clear();
auto& out_grad_vec = output_grads.at(ptr.get());
......@@ -182,7 +182,7 @@ Graph Gradient(Graph src) {
out_agg_grads.push_back(e.sum);
}
if ((*rit)->inputs.size() != 0) {
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
ObjectPtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
std::vector<NodeEntry> input_grads;
// Check for FGradient
if (grad_fun_map.contains(ptr->op())) {
......@@ -244,7 +244,7 @@ Graph Gradient(Graph src) {
if (kv == unique_grads.end()) {
unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
} else {
NodePtr copy_node = Node::Create();
ObjectPtr copy_node = Node::Create();
std::ostringstream os;
os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
kv->second.first++;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -112,7 +112,7 @@ Graph InferAttr(Graph &&ret,
CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
NodePtr fwd_ptr = inode.source->control_deps[0];
ObjectPtr fwd_ptr = inode.source->control_deps[0];
CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
// use gradient function to find out the correspondence.
std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -45,7 +45,7 @@ inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) {
for (const NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
......@@ -57,8 +57,8 @@ Graph OrderMutation(const Graph& src) {
// no mutation happens, everything if fine.
if (version_hist.size() == 0) return src;
// start preparing for remapping the nodes.
std::unordered_map<Node*, NodePtr> old_new;
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
std::unordered_map<Node*, ObjectPtr> old_new;
auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op())) {
......@@ -80,11 +80,11 @@ Graph OrderMutation(const Graph& src) {
if (old_new.count(e.node.get()) != 0) need_repl = true;
}
}
for (const NodePtr& p : n->control_deps) {
for (const ObjectPtr& p : n->control_deps) {
if (old_new.count(p.get()) != 0) need_repl = true;
}
if (need_repl) {
NodePtr np = Node::Create();
ObjectPtr np = Node::Create();
np->attrs = n->attrs;
old_new[n.get()] = std::move(np);
}
......@@ -111,7 +111,7 @@ Graph OrderMutation(const Graph& src) {
kv.second->inputs.push_back(e);
}
}
for (const NodePtr& p : kv.first->control_deps) {
for (const ObjectPtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(
get_with_default(old_new, p.get(), p));
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -105,8 +105,8 @@ Graph PlaceDevice(Graph src) {
src.attrs["device"] = std::make_shared<any>(std::move(device));
return src;
}
std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
std::map<std::tuple<uint32_t, uint32_t, int>, ObjectPtr> copy_map;
std::vector<ObjectPtr> new_node_map(idx.num_nodes(), nullptr);
std::unordered_map<const Node*, int> new_device_map;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
......@@ -142,7 +142,7 @@ Graph PlaceDevice(Graph src) {
CHECK(!need_mutate) << "consistency check";
}
if (need_mutate) {
NodePtr new_node = Node::Create();
ObjectPtr new_node = Node::Create();
new_node->attrs = inode.source->attrs;
new_node->inputs.reserve(inode.inputs.size());
for (size_t i = 0; i < inode.inputs.size(); ++i) {
......@@ -154,7 +154,7 @@ Graph PlaceDevice(Graph src) {
new_node->inputs.emplace_back(
NodeEntry{it->second, 0, 0});
} else {
NodePtr copy_node = Node::Create();
ObjectPtr copy_node = Node::Create();
std::ostringstream os;
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -86,7 +86,7 @@ struct JSONNode {
};
// pointer to the graph node
NodePtr node;
ObjectPtr node;
// inputs
std::vector<Entry> inputs;
// control flow dependencies
......@@ -190,7 +190,7 @@ struct JSONGraph {
void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
std::unordered_map<Node*, uint32_t> node2index;
jgraph->node_row_ptr.push_back(0);
DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) {
DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
......@@ -202,7 +202,7 @@ void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
}
for (const NodePtr& c : n->control_deps) {
for (const ObjectPtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
}
jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
......
......@@ -32,7 +32,7 @@ TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kObjectHandle);
std::ostringstream os;
os << args[0].operator NodeRef();
os << args[0].operator ObjectRef();
*ret = os.str();
});
......
......@@ -65,7 +65,7 @@ TVM_REGISTER_API("_Array")
data.push_back(ObjectRef(nullptr));
}
}
auto node = make_node<ArrayNode>();
auto node = make_object<ArrayNode>();
node->data = std::move(data);
*ret = Array<ObjectRef>(node);
});
......@@ -105,7 +105,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef()));
}
auto node = make_node<StrMapNode>();
auto node = make_object<StrMapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
......@@ -119,7 +119,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef()));
}
auto node = make_node<MapNode>();
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
}
......@@ -186,7 +186,7 @@ TVM_REGISTER_API("_MapItems")
if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr);
auto rkvs = make_node<ArrayNode>();
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
......@@ -194,7 +194,7 @@ TVM_REGISTER_API("_MapItems")
*ret = Array<ObjectRef>(rkvs);
} else {
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_node<ArrayNode>();
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(kv.second);
......
......@@ -100,12 +100,13 @@ TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
});
TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
.set_body_typed<bool(const ObjectRef&, const ObjectRef&)>(
[](const ObjectRef& lhs, const ObjectRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
.set_body_typed<int64_t(const ObjectRef&)>([](const ObjectRef &node) {
return AttrsHash()(node);
});
......@@ -118,7 +119,7 @@ TVM_REGISTER_API("ir_pass.ExprUseVar")
TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
ir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
f(n);
});
});
......@@ -126,7 +127,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
TVM_REGISTER_API("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) {
LoweredFunc f = args[0];
auto n = make_node<LoweredFuncNode>(*f.operator->());
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
*ret = LoweredFunc(n);
});
......
......@@ -42,7 +42,7 @@ class VariablePathFinder: public IRVisitor {
public:
explicit VariablePathFinder(Expr target) : target_(target) {}
void Visit(const NodeRef& node) final {
void Visit(const ObjectRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
......@@ -82,7 +82,7 @@ class BoundDeducer: public IRVisitor {
void Deduce();
void Visit(const NodeRef& e) final {
void Visit(const ObjectRef& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
......@@ -202,7 +202,7 @@ class BoundDeduceInputChecker: public IRVisitor {
return target_count == 1;
}
void Visit(const NodeRef& e) final {
void Visit(const ObjectRef& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
}
......
......@@ -56,7 +56,7 @@ class CanonicalExprNode : public BaseExprNode {
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode);
};
enum DivMode {
......@@ -147,10 +147,14 @@ class SplitExprNode : public CanonicalExprNode {
/*! \brief positive infty */
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static constexpr const char* _type_key = "arith.SplitExpr";
TVM_DECLARE_NODE_TYPE_INFO(SplitExprNode, CanonicalExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode);
};
TVM_DEFINE_COW_NODE_REF(SplitExpr, Expr, SplitExprNode);
class SplitExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};
inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true;
......@@ -272,7 +276,7 @@ class SumExprNode : public CanonicalExprNode {
void AddToSelf(const SumExpr& other, int64_t scale);
static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_NODE_TYPE_INFO(SumExprNode, CanonicalExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);
private:
/*!
......@@ -405,7 +409,11 @@ class SumExprNode : public CanonicalExprNode {
}
};
TVM_DEFINE_COW_NODE_REF(SumExpr, Expr, SumExprNode);
class SumExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};
void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) {
// NOTE: it is rare to have a balanced long expression,
......@@ -507,7 +515,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<CanonicalExprNode>()) {
expr = op->Normalize();
}
NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
ObjectPtr<SplitExprNode> n = make_object<SplitExprNode>();
n->dtype = expr.dtype();
n->index = std::move(expr);
n->div_mode = kTruncDiv;
......@@ -544,7 +552,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SumExprNode>()) {
return GetRef<SumExpr>(op);
}
NodePtr<SumExprNode> n = make_node<SumExprNode>();
ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
n->dtype = expr.dtype();
if (const auto* op = expr.as<IntImm>()) {
n->base = op->value;
......@@ -655,8 +663,8 @@ SeparateDivisibleParts(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible) {
auto divisible = make_node<SumExprNode>();
auto non_divisible = make_node<SumExprNode>();
auto divisible = make_object<SumExprNode>();
auto non_divisible = make_object<SumExprNode>();
divisible->dtype = psum->dtype;
non_divisible->dtype = psum->dtype;
......
......@@ -35,7 +35,7 @@ TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
ConstIntBound::ConstIntBound(
int64_t min_value, int64_t max_value) {
auto node = make_node<ConstIntBoundNode>();
auto node = make_object<ConstIntBoundNode>();
node->min_value = min_value;
node->max_value = max_value;
data_ = std::move(node);
......@@ -123,7 +123,7 @@ class ConstIntBoundAnalyzer::Impl :
}
// Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final {
Entry VisitExprDefault_(const Object* op) final {
return Everything(
static_cast<const ExprNode*>(op)->dtype);
}
......
......@@ -106,7 +106,7 @@ class LinearEqDetector
}
return ret;
}
LinearEqEntry VisitExprDefault_(const Node* op, const Expr& e) final {
LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
fail_ = true;
......@@ -171,7 +171,7 @@ bool DetectClipBound(
std::unordered_map<const Variable*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
auto fvisit = [&bmap, &flag, &var](const NodeRef& n) {
auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
if (const Variable* v = n.as<Variable>()) {
if (bmap->count(v)) {
if (flag == 0) {
......
......@@ -37,7 +37,7 @@ Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
auto node = make_node<IntervalSetNode>();
auto node = make_object<IntervalSetNode>();
node->min_value = std::move(min_value);
node->max_value = std::move(max_value);
data_ = std::move(node);
......@@ -505,7 +505,7 @@ class IntervalSetEvaluator :
return Union(analyzer_, false_set, true_set);
}
IntervalSet VisitExprDefault_(const Node* op) final {
IntervalSet VisitExprDefault_(const Object* op) final {
DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
return IntervalSet::Everything();
}
......
......@@ -75,7 +75,7 @@ class IntervalSetNode : public IntSetNode {
}
static constexpr const char* _type_key = "arith.IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode);
TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode);
};
/*!
......@@ -116,8 +116,8 @@ class IntervalSet : public IntSet {
return IntervalSet(pos_inf(), neg_inf());
}
TVM_DEFINE_NODE_REF_COW(IntervalSetNode);
TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode);
TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
};
/*!
......
......@@ -37,7 +37,7 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE(ModularSetNode);
ModularSet::ModularSet(int64_t coeff, int64_t base) {
auto node = make_node<ModularSetNode>();
auto node = make_object<ModularSetNode>();
node->coeff = coeff;
node->base = base;
// finish construction.
......@@ -120,7 +120,7 @@ class ModularSetAnalyzer::Impl :
}
// Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final {
Entry VisitExprDefault_(const Object* op) final {
return Everything();
}
......
......@@ -250,7 +250,7 @@ class PBinaryExpr :
b_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const NodeType* ptr = node.as<NodeType>()) {
if (!a_.Match_(ptr->a)) return false;
if (!b_.Match_(ptr->b)) return false;
......@@ -282,7 +282,7 @@ class PConstWithTypeLike :
void InitMatch_() const {}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
return ptr->value == value_;
} else {
......@@ -364,7 +364,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
value_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::Not* ptr = node.as<ir::Not>()) {
if (!value_.Match_(ptr->a)) return false;
return true;
......@@ -410,7 +410,7 @@ class PSelectExpr :
false_value_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::Select* ptr = node.as<ir::Select>()) {
if (!condition_.Match_(ptr->condition)) return false;
if (!true_value_.Match_(ptr->true_value)) return false;
......@@ -472,7 +472,7 @@ class PCastExpr :
value_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::Cast* ptr = node.as<ir::Cast>()) {
if (!dtype_.Match_(ptr->dtype)) return false;
if (!value_.Match_(ptr->value)) return false;
......@@ -530,7 +530,7 @@ class PRampExpr :
lanes_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
if (!base_.Match_(ptr->base)) return false;
if (!stride_.Match_(ptr->stride)) return false;
......@@ -592,7 +592,7 @@ class PBroadcastExpr :
lanes_.InitMatch_();
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
if (!value_.Match_(ptr->value)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
......@@ -704,7 +704,7 @@ class PCallExpr :
detail::tuple_for_each(finit, args_);
}
bool Match_(const NodeRef& node) const {
bool Match_(const ObjectRef& node) const {
if (const ir::Call* ptr = node.as<ir::Call>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false;
if (ptr->name != Op::kName) return false;
......
......@@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
*/
Target CreateTarget(const std::string& target_name,
const std::vector<std::string>& options) {
auto t = make_node<TargetNode>();
auto t = make_object<TargetNode>();
t->target_name = target_name;
std::string libs_flag = "-libs=";
......@@ -366,7 +366,7 @@ void GetBinds(const Array<Tensor>& args,
bool compact,
const std::unordered_map<Tensor, Buffer>& binds,
Map<Tensor, Buffer>* out_binds,
Array<NodeRef>* out_arg_list,
Array<ObjectRef>* out_arg_list,
const BuildConfig& config) {
*out_binds = binds;
......@@ -396,7 +396,7 @@ Stmt BuildStmt(Schedule sch,
const Array<Tensor>& args,
const std::unordered_map<Tensor, Buffer>& binds,
bool loop_partition,
Array<NodeRef> *out_arg_list,
Array<ObjectRef> *out_arg_list,
const BuildConfig& config) {
sch = sch.normalize();
......@@ -445,7 +445,7 @@ Array<LoweredFunc> lower(Schedule sch,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config) {
Array<NodeRef> out_arg_list;
Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
}
......@@ -618,7 +618,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
BuildConfig BuildConfig::Create() {
return BuildConfig(make_node<BuildConfigNode>());
return BuildConfig(make_object<BuildConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
......@@ -701,7 +701,7 @@ GenericFunc GenericFunc::Get(const std::string& name) {
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) {
auto f = make_node<GenericFuncNode>();
auto f = make_object<GenericFuncNode>();
f->name_ = name;
auto gf = GenericFunc(f);
m->fmap[name] = gf;
......@@ -825,7 +825,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo")
TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_node<GenericFuncNode>());
*ret = GenericFunc(make_object<GenericFuncNode>());
});
TVM_REGISTER_API("_GenericFuncGetGlobal")
......
......@@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() {
std::string CodeGenHybrid::GetVarID(const Variable *v) {
if (binds_.count(v))
return binds_[v];
auto key = std::make_pair(static_cast<const Node*>(v), 0);
auto key = std::make_pair(static_cast<const Object*>(v), 0);
if (id_map_.count(key)) {
return id_map_[key];
}
......@@ -472,7 +472,7 @@ void CodeGenHybrid::ReserveKeywords() {
}
void CodeGenHybrid::DumpStmt(const Stmt &stmt,
const Array<NodeRef> &inputs,
const Array<ObjectRef> &inputs,
const Array<Tensor> &outputs,
const std::string &name) {
ReserveKeywords();
......
......@@ -56,7 +56,7 @@ class CodeGenHybrid :
* \param outputs Output tensors of this schedule.
* \param name The name of the function.
*/
void DumpStmt(const Stmt &stmt, const Array<NodeRef> &inputs, const Array<Tensor> &outputs,
void DumpStmt(const Stmt &stmt, const Array<ObjectRef> &inputs, const Array<Tensor> &outputs,
const std::string &name = "hybrid_func");
/*!
* \brief Finalize the compilation and return the code.
......@@ -152,7 +152,7 @@ class CodeGenHybrid :
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
std::map<std::pair<const Node *, int>, std::string> id_map_;
std::map<std::pair<const Object *, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */
std::map<const Variable *, std::string> binds_;
/*!
......
......@@ -33,7 +33,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
ObjectPtr<Object> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name);
CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>();
ObjectPtr<EnvFuncNode> n = make_object<EnvFuncNode>();
n->func = *f;
n->name = name;
return n;
......
......@@ -39,8 +39,8 @@ void DictAttrsNode::InitByPackedArgs(
for (int i = 0; i < args.size(); i += 2) {
std::string key = args[i];
runtime::TVMArgValue val = args[i + 1];
if (val.type_code() == kObjectHandle) {
dict.Set(key, val.operator NodeRef());
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kStr) {
dict.Set(key, Expr(val.operator std::string()));
} else {
......@@ -53,8 +53,8 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {};
}
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
NodePtr<DictAttrsNode> n = make_node<DictAttrsNode>();
Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
ObjectPtr<DictAttrsNode> n = make_object<DictAttrsNode>();
n->dict = std::move(dict);
return Attrs(n);
}
......
......@@ -334,7 +334,7 @@ Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this;
std::vector<Expr> temp;
auto n = make_node<BufferNode>(*operator->());
auto n = make_object<BufferNode>(*operator->());
Expr acc = make_const(n->DefaultIndexType(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc);
......@@ -419,7 +419,7 @@ Buffer BufferNode::make(Var data,
int data_alignment,
int offset_factor,
BufferType buffer_type) {
auto n = make_node<BufferNode>();
auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
n->shape = std::move(shape);
......
......@@ -68,7 +68,7 @@ const LayoutAxis& LayoutAxis::make(const std::string& name) {
}
Layout::Layout(const Array<IterVar>& axes) {
auto node = make_node<LayoutNode>();
auto node = make_object<LayoutNode>();
node->axes = axes;
std::ostringstream repr;
for (const IterVar& axis : axes) {
......@@ -89,7 +89,7 @@ Layout::Layout(const Array<IterVar>& axes) {
Layout::Layout(const std::string& name) { // NOLINT(*)
if (name == "__undef__") return;
auto node = make_node<LayoutNode>();
auto node = make_object<LayoutNode>();
node->name = name;
if (name.empty()) return; // scalar
......@@ -347,7 +347,7 @@ Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
const Layout& dst_layout) {
auto n = make_node<BijectiveLayoutNode>();
auto n = make_object<BijectiveLayoutNode>();
n->src_layout = src_layout;
n->dst_layout = dst_layout;
......
......@@ -42,14 +42,14 @@ Var::Var(std::string name_hint, DataType t)
: Var(Variable::make(t, name_hint)) {}
Var Variable::make(DataType t, std::string name_hint) {
NodePtr<Variable> node = make_node<Variable>();
ObjectPtr<Variable> node = make_object<Variable>();
node->dtype = t;
node->name_hint = std::move(name_hint);
return Var(node);
}
Range::Range(Expr begin, Expr end)
: Range(make_node<RangeNode>(
: Range(make_object<RangeNode>(
begin,
is_zero(begin) ? end : (end - begin))) {
}
......@@ -57,21 +57,21 @@ Range::Range(Expr begin, Expr end)
Integer IntImm::make(DataType t, int64_t value) {
CHECK(t.is_int() && t.is_scalar())
<< "ValueError: IntImm can only take scalar.";
NodePtr<IntImm> node = make_node<IntImm>();
ObjectPtr<IntImm> node = make_object<IntImm>();
node->dtype = t;
node->value = value;
return Integer(node);
}
Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(make_node<RangeNode>(min, extent));
return Range(make_object<RangeNode>(min, extent));
}
IterVar IterVarNode::make(Range dom,
Var var,
IterVarType t,
std::string thread_tag) {
NodePtr<IterVarNode> n = make_node<IterVarNode>();
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
n->dom = dom;
n->var = var;
n->iter_type = t;
......@@ -89,7 +89,7 @@ IterVar reduce_axis(Range dom, std::string name) {
dom, Var(name), kCommReduce);
}
void Dump(const NodeRef& n) {
void Dump(const ObjectRef& n) {
std::cerr << n << "\n";
}
......
......@@ -47,7 +47,7 @@ Expr Tensor::operator()(Array<Expr> indices) const {
}
Tensor Operation::output(size_t i) const {
auto node = make_node<TensorNode>();
auto node = make_object<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
......@@ -59,7 +59,7 @@ Tensor TensorNode::make(Array<Expr> shape,
DataType dtype,
Operation op,
int value_index) {
auto n = make_node<TensorNode>();
auto n = make_object<TensorNode>();
n->shape = std::move(shape);
n->dtype = dtype;
n->op = op;
......@@ -87,7 +87,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
auto n = make_node<TensorIntrinNode>();
auto n = make_object<TensorIntrinNode>();
n->name = std::move(name);
n->op = std::move(op);
n->inputs = std::move(inputs);
......@@ -115,7 +115,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
Array<Region> regions,
Array<IterVar> reduce_axis,
Array<Expr> scalar_inputs) {
auto n = make_node<TensorIntrinCallNode>();
auto n = make_object<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
n->tensors = std::move(tensors);
n->regions = std::move(regions);
......
......@@ -79,7 +79,7 @@ class NodeIndexer : public AttrVisitor {
// make index of all the children of node
void MakeIndex(Object* node) {
if (node == nullptr) return;
CHECK(node->IsInstance<Node>());
CHECK(node->IsInstance<Object>());
if (node_index_.count(node)) return;
CHECK_EQ(node_index_.size(), node_list_.size());
......
......@@ -90,8 +90,8 @@ Tensor compute(Array<Expr> shape,
FCompute fcompute,
std::string name,
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = make_node<ComputeOpNode>();
Map<std::string, ObjectRef> attrs) {
auto op_node = make_object<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
......@@ -112,8 +112,8 @@ Array<Tensor> compute(Array<Expr> shape,
FBatchCompute fcompute,
std::string name,
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = make_node<ComputeOpNode>();
Map<std::string, ObjectRef> attrs) {
auto op_node = make_object<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
......@@ -136,13 +136,13 @@ Array<Tensor> compute(Array<Expr> shape,
Operation ComputeOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
Array<IterVar> axis,
Array<Expr> body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
attrs = Map<std::string, ObjectRef>();
}
auto n = make_node<ComputeOpNode>();
auto n = make_object<ComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
......@@ -161,7 +161,7 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
std::unordered_set<Tensor> visited;
for (auto& e : body) {
ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
......@@ -188,7 +188,7 @@ Operation ComputeOpNode::ReplaceInputs(
if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>();
for (size_t k = 0; k < this->body.size(); ++k) {
auto n = make_node<ir::Reduce>(*r);
auto n = make_object<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
arr.push_back(Expr(n));
......@@ -215,7 +215,7 @@ void ComputeOpNode::PropBoundToInputs(
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) {
auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
......@@ -574,7 +574,7 @@ class ComputeVerifier final : protected ir::IRVisitor {
protected:
/// Visitor implementation
//@{
void Visit(const NodeRef& n) final {
void Visit(const ObjectRef& n) final {
++level_;
ir::IRVisitor::Visit(n);
--level_;
......
......@@ -57,15 +57,15 @@ Array<Expr> ExternOpNode::output_shape(size_t i) const {
Operation ExternOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
attrs = Map<std::string, ObjectRef>();
}
auto n = make_node<ExternOpNode>();
auto n = make_object<ExternOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
......@@ -93,7 +93,7 @@ Operation ExternOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<ExternOpNode>(*this);
auto n = make_object<ExternOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
......@@ -161,7 +161,7 @@ Stmt ExternOpNode::BuildProvide(
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<ObjectRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
......
......@@ -63,14 +63,14 @@ Array<Expr> HybridOpNode::output_shape(size_t i) const {
Operation HybridOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Map<std::string, ObjectRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
attrs = Map<std::string, ObjectRef>();
}
auto n = make_node<HybridOpNode>();
auto n = make_object<HybridOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
......@@ -91,7 +91,7 @@ Array<Tensor> HybridOpNode::InputTensors() const {
}
std::unordered_set<Tensor> visited;
Array<Tensor> curr_inputs;
ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
......@@ -108,7 +108,7 @@ Operation HybridOpNode::ReplaceInputs(
const Operation &self,
const std::unordered_map<Tensor, Tensor> &rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this);
auto n = make_object<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
......@@ -185,7 +185,7 @@ Stmt HybridOpNode::BuildProvide(
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
}
auto n = make_node<HybridOpNode>(*this);
auto n = make_object<HybridOpNode>(*this);
/* This is a story little bit complicated.
* The following two lines of codes replace output tensors' usage.
* This is the simplest way I (@were) can come up with to glue
......@@ -369,7 +369,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
expected = IterVarTypeToForType(attr->iter_type);
}
PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) {
PostOrderVisit(stmt,
[&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
if (const For *op = node.as<For>()) {
if (op->loop_var.get() == var) {
++found;
......@@ -390,7 +391,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
std::vector<const Variable*> current_order;
PostOrderVisit(stmt, [&current_order](const NodeRef &node) {
PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
if (const For *op = node.as<For>())
current_order.push_back(op->loop_var.get());
});
......@@ -466,7 +467,7 @@ Stmt ApplySchedule(const Stage &stage,
std::vector<IterVar> GatherLoopVars(Stmt stmt) {
// TODO(@were): Write a comprehensive pass to analyze iter var types
std::vector<IterVar> res_;
PostOrderVisit(stmt, [&res_](const NodeRef &node) {
PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
if (const For *op = node.as<For>()) {
Var loop_var(op->loop_var);
Range dom = Range::make_by_min_extent(op->min, op->extent);
......
......@@ -55,7 +55,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
DataType dtype) {
auto n = make_node<PlaceholderOpNode>();
auto n = make_object<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment