Unverified Commit ae5a28db by Tianqi Chen Committed by GitHub

[NODE] Node base system refactor (#1739)

parent 1e57ee6c
Subproject commit f519848d972c67971b4cbf8c34070d5a5e3ede0d Subproject commit cf6090aeaeb782d1daff54b0ca5c2c281d7008db
...@@ -57,7 +57,7 @@ class EnvFuncNode : public Node { ...@@ -57,7 +57,7 @@ class EnvFuncNode : public Node {
class EnvFunc : public NodeRef { class EnvFunc : public NodeRef {
public: public:
EnvFunc() {} EnvFunc() {}
explicit EnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {} explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
/*! \return The internal global function pointer */ /*! \return The internal global function pointer */
const EnvFuncNode* operator->() const { const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get()); return static_cast<EnvFuncNode*>(node_.get());
...@@ -105,7 +105,7 @@ class TypedEnvFunc<R(Args...)> : public NodeRef { ...@@ -105,7 +105,7 @@ class TypedEnvFunc<R(Args...)> : public NodeRef {
/*! \brief short hand for this function type */ /*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>; using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {} TypedEnvFunc() {}
explicit TypedEnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {} explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Assign global function to a TypedEnvFunc * \brief Assign global function to a TypedEnvFunc
* \param other Another global function. * \param other Another global function.
......
...@@ -38,7 +38,7 @@ class IntSet : public NodeRef { ...@@ -38,7 +38,7 @@ class IntSet : public NodeRef {
/*! \brief constructor */ /*! \brief constructor */
IntSet() {} IntSet() {}
// constructor from not container. // constructor from not container.
explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {} explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -136,7 +136,7 @@ class Attrs : public NodeRef { ...@@ -136,7 +136,7 @@ class Attrs : public NodeRef {
// normal constructor // normal constructor
Attrs() {} Attrs() {}
// construct from shared ptr. // construct from shared ptr.
explicit Attrs(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Attrs(NodePtr<Node> n) : NodeRef(n) {}
/*! \return The attribute node */ /*! \return The attribute node */
const BaseAttrsNode* operator->() const { const BaseAttrsNode* operator->() const {
...@@ -442,7 +442,7 @@ class AttrDocEntry { ...@@ -442,7 +442,7 @@ class AttrDocEntry {
public: public:
using TSelf = AttrDocEntry; using TSelf = AttrDocEntry;
explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info) explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info)
: info_(info) { : info_(info) {
} }
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
...@@ -466,15 +466,15 @@ class AttrDocEntry { ...@@ -466,15 +466,15 @@ class AttrDocEntry {
} }
private: private:
std::shared_ptr<AttrFieldInfoNode> info_; NodePtr<AttrFieldInfoNode> info_;
}; };
class AttrDocVisitor { class AttrDocVisitor {
public: public:
template<typename T> template<typename T>
AttrDocEntry operator()(const char* key, T* v) { AttrDocEntry operator()(const char* key, T* v) {
std::shared_ptr<AttrFieldInfoNode> info NodePtr<AttrFieldInfoNode> info
= std::make_shared<AttrFieldInfoNode>(); = make_node<AttrFieldInfoNode>();
info->name = key; info->name = key;
info->type_info = TypeName<T>::value; info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info)); fields_.push_back(AttrFieldInfo(info));
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <tvm/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include <memory> #include <memory>
#include <functional> #include <functional>
...@@ -25,7 +25,7 @@ using ::tvm::AttrVisitor; ...@@ -25,7 +25,7 @@ using ::tvm::AttrVisitor;
class TypeName : public ::tvm::NodeRef { \ class TypeName : public ::tvm::NodeRef { \
public: \ public: \
TypeName() {} \ TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} \ explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \
const NodeName* operator->() const { \ const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \ return static_cast<const NodeName*>(node_.get()); \
} \ } \
...@@ -48,7 +48,7 @@ std::string SaveJSON(const NodeRef& node); ...@@ -48,7 +48,7 @@ std::string SaveJSON(const NodeRef& node);
* *
* \return The shared_ptr of the Node. * \return The shared_ptr of the Node.
*/ */
std::shared_ptr<Node> LoadJSON_(std::string json_str); NodePtr<Node> LoadJSON_(std::string json_str);
/*! /*!
* \brief Load the node from json string. * \brief Load the node from json string.
...@@ -85,7 +85,7 @@ struct NodeFactoryReg { ...@@ -85,7 +85,7 @@ struct NodeFactoryReg {
* If this is not empty then FGlobalKey * If this is not empty then FGlobalKey
* \return The created function. * \return The created function.
*/ */
using FCreate = std::function<std::shared_ptr<Node>(const std::string& global_key)>; using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
/*! /*!
* \brief Global key function, only needed by global objects. * \brief Global key function, only needed by global objects.
* \param node The node pointer. * \param node The node pointer.
...@@ -123,7 +123,7 @@ struct NodeFactoryReg { ...@@ -123,7 +123,7 @@ struct NodeFactoryReg {
#define TVM_REGISTER_NODE_TYPE(TypeName) \ #define TVM_REGISTER_NODE_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return std::make_shared<TypeName>(); }) .set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
#define TVM_STRINGIZE_DETAIL(x) #x #define TVM_STRINGIZE_DETAIL(x) #x
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
#ifndef TVM_BUFFER_H_ #ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_ #define TVM_BUFFER_H_
#include <tvm/container.h>
#include <string> #include <string>
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "node/container.h"
namespace tvm { namespace tvm {
...@@ -31,7 +31,7 @@ enum class AccessMask : int { ...@@ -31,7 +31,7 @@ enum class AccessMask : int {
class Buffer : public NodeRef { class Buffer : public NodeRef {
public: public:
Buffer() {} Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Buffer(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Return a new buffer that is equivalent with current one * \brief Return a new buffer that is equivalent with current one
* but always add stride field. * but always add stride field.
......
...@@ -69,7 +69,7 @@ class TargetNode : public Node { ...@@ -69,7 +69,7 @@ class TargetNode : public Node {
class Target : public NodeRef { class Target : public NodeRef {
public: public:
Target() {} Target() {}
explicit Target(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Target(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Create a Target given a string * \brief Create a Target given a string
...@@ -241,7 +241,7 @@ class BuildConfigNode : public Node { ...@@ -241,7 +241,7 @@ class BuildConfigNode : public Node {
class BuildConfig : public ::tvm::NodeRef { class BuildConfig : public ::tvm::NodeRef {
public: public:
BuildConfig() {} BuildConfig() {}
explicit BuildConfig(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const { const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get()); return static_cast<const BuildConfigNode*>(node_.get());
...@@ -335,7 +335,7 @@ class GenericFuncNode; ...@@ -335,7 +335,7 @@ class GenericFuncNode;
class GenericFunc : public NodeRef { class GenericFunc : public NodeRef {
public: public:
GenericFunc() {} GenericFunc() {}
explicit GenericFunc(std::shared_ptr<Node> n) : NodeRef(n) {} explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Set the default function implementaiton. * \brief Set the default function implementaiton.
......
...@@ -17,7 +17,7 @@ class Channel : public NodeRef { ...@@ -17,7 +17,7 @@ class Channel : public NodeRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Channel() {} Channel() {}
explicit Channel(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Channel(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -76,7 +76,7 @@ class Var : public HalideIR::VarExpr { ...@@ -76,7 +76,7 @@ class Var : public HalideIR::VarExpr {
public: public:
EXPORT explicit Var(const std::string& name_hint = "v", EXPORT explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {} Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} explicit Var(NodePtr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {} explicit Var(VarExpr v) : VarExpr(v) {}
/*! /*!
* \brief Make a new copy of var with same type, append suffix * \brief Make a new copy of var with same type, append suffix
...@@ -107,7 +107,7 @@ class Range : public HalideIR::IR::Range { ...@@ -107,7 +107,7 @@ class Range : public HalideIR::IR::Range {
public: public:
/*! \brief constructor */ /*! \brief constructor */
Range() {} Range() {}
explicit Range(std::shared_ptr<Node> n) : HalideIR::IR::Range(n) {} explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
/*! /*!
* \brief constructor by begin and end * \brief constructor by begin and end
* \param begin The begin of the range. * \param begin The begin of the range.
...@@ -197,7 +197,7 @@ class IterVar : public NodeRef { ...@@ -197,7 +197,7 @@ class IterVar : public NodeRef {
// construct a new iter var without a domain // construct a new iter var without a domain
IterVar() {} IterVar() {}
// construct from shared ptr. // construct from shared ptr.
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {} explicit IterVar(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -28,7 +28,7 @@ struct CommReducerNode; ...@@ -28,7 +28,7 @@ struct CommReducerNode;
struct CommReducer : public NodeRef { struct CommReducer : public NodeRef {
CommReducer() {} CommReducer() {}
explicit CommReducer(std::shared_ptr<Node> n) : NodeRef(n) {} explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef TVM_IR_FUNCTOR_EXT_H_ #ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_ #define TVM_IR_FUNCTOR_EXT_H_
#include <tvm/ir_functor.h> #include "node/ir_functor.h"
#include "ir.h" #include "ir.h"
namespace tvm { namespace tvm {
......
...@@ -6,10 +6,10 @@ ...@@ -6,10 +6,10 @@
#ifndef TVM_IR_MUTATOR_H_ #ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_ #define TVM_IR_MUTATOR_H_
#include <tvm/ir_functor.h>
#include <unordered_map> #include <unordered_map>
#include "expr.h" #include "expr.h"
#include "ir.h" #include "ir.h"
#include "node/ir_functor.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#ifndef TVM_IR_PASS_H_ #ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_ #define TVM_IR_PASS_H_
#include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h> #include <arithmetic/Simplify.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#ifndef TVM_IR_VISITOR_H_ #ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_
#include <tvm/ir_functor.h>
#include "ir.h" #include "ir.h"
#include "node/ir_functor.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -7,13 +7,13 @@ ...@@ -7,13 +7,13 @@
#ifndef TVM_LOWERED_FUNC_H_ #ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_ #define TVM_LOWERED_FUNC_H_
#include <tvm/container.h>
#include <ir/FunctionBase.h> #include <ir/FunctionBase.h>
#include <string> #include <string>
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "tensor.h" #include "tensor.h"
#include "node/container.h"
namespace tvm { namespace tvm {
...@@ -27,7 +27,7 @@ class LoweredFuncNode; ...@@ -27,7 +27,7 @@ class LoweredFuncNode;
class LoweredFunc : public FunctionRef { class LoweredFunc : public FunctionRef {
public: public:
LoweredFunc() {} LoweredFunc() {}
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {} explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/ir_functor.h
* \brief Defines the IRFunctor data structures.
*/
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include "node.h"
#include "../runtime/registry.h"
namespace tvm {
/*!
* \brief A dynamical dispatched functor on NodeRef in the first argument.
*
* \code
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
* return prefix + "Add";
* });
* tostr.set_dispatch<IntImm>([](const IntImm* op) {
* return prefix + "IntImm"
* });
*
* Expr x = make_const(1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
* // dispatch to IntImm, outputs "MyAdd"
* LOG(INFO) << tostr(y, "My");
* \endcode
*
* \tparam FType function signiture
* This type if only defined for FType with function signiture
*/
template<typename FType>
class IRFunctor;
template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
private:
using Function = std::function<R (const NodeRef&n, Args...)>;
using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
/*! \brief internal function table */
std::vector<Function> func_;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief Whether the functor can dispatch the corresponding Node
* \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type.
*/
inline bool can_dispatch(const NodeRef& n) const {
uint32_t type_index = n.type_index();
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
* \brief invoke the functor , dispatch on type of n
* \param n The Node argument
* \param args The additional arguments
* \return The result.
*/
inline R operator()(const NodeRef& n, Args... args) const {
uint32_t type_index = n.type_index();
CHECK(type_index < func_.size() &&
func_[type_index] != nullptr)
<< "IRFunctor calls un-registered function on type "
<< Node::TypeIndex2Key(type_index);
return func_[type_index](n, std::forward<Args>(args)...);
}
/*!
* \brief set the dispacher for type TNode
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(Function f) { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
CHECK(func_[tindex] == nullptr)
<< "Dispatch for " << Node::TypeIndex2Key(tindex)
<< " is already set";
func_[tindex] = f;
return *this;
}
/*!
* \brief set the dispacher for type TNode
* This allows f to used detailed const Node pointer to replace NodeRef
*
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
Function fun = [f](const NodeRef& n, Args... args) {
return f(static_cast<const TNode*>(n.node_.get()),
std::forward<Args>(args)...);
};
return this->set_dispatch<TNode>(fun);
}
/*!
* \brief unset the dispacher for type TNode
*
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*!
* \brief Useful macro to set IRFunctor dispatch in a global static field.
*
* \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows easy patch in of new Node types, without changing
* // interface of IRPrinter.
*
* class IRPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* const static FType& f = *vtable();
* f(e, this);
* }
*
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*0
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
* p->print(n->a);
* p->stream << '+'
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField()
/*!
* \brief A container for a list of callbacks. All callbacks are invoked when
* the object is destructed.
*/
class IRFunctorCleanList {
public:
~IRFunctorCleanList() {
for (auto &f : clean_items) {
f();
}
}
void append(std::function<void()> func) {
clean_items.push_back(func);
}
private:
std::vector< std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
private:
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
IRFunctor<R(const NodeRef& n, Args...)> *irf) {
return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
} // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/memory.h
* \brief Node memory management.
*/
#ifndef TVM_NODE_MEMORY_H_
#define TVM_NODE_MEMORY_H_
#include "node.h"
namespace tvm {
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
//
template<typename T>
class SimpleNodeAllocator {
public:
template<typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static NodeBase::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(NodeBase* ptr) {
delete static_cast<T*>(ptr);
}
};
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
using Allocator = SimpleNodeAllocator<T>;
static_assert(std::is_base_of<NodeBase, T>::value,
"make_node can only be used to create NodeBase");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return NodePtr<T>(node);
}
} // namespace tvm
#endif // TVM_NODE_MEMORY_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/node.h
* \brief Node system data structure.
*/
#ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_
#include <string>
#include <vector>
#include <type_traits>
#include "base/Type.h"
#include "../runtime/node_base.h"
#include "../runtime/c_runtime_api.h"
namespace tvm {
using HalideIR::Type;
// forward declaration
class Node;
class NodeRef;
namespace runtime {
// forward declaration
class NDArray;
} // namespace runtime
/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, Type* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief base class of node container in DSL AST.
* All object's internal is stored as std::shared_ptr<Node>
*/
class TVM_DLL Node : public NodeBase {
public:
/*! \brief virtual destructor */
virtual ~Node() {}
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*! \return the type index of the node */
virtual const uint32_t type_index() const = 0;
/*!
* \brief Whether this node derives from node with type_index=tid.
* Implemented by TVM_DECLARE_NODE_TYPE_INFO
*
* \param tid The type index.
* \return the check result.
*/
virtual const bool _DerivedFrom(uint32_t tid) const;
/*!
* \brief get a runtime unique type index given a type key
* \param type_key Type key of a type.
* \return the corresponding type index.
*/
static uint32_t TypeKey2Index(const char* type_key);
/*!
* \brief get type key from type index.
* \param index The type index
* \return the corresponding type key.
*/
static const char* TypeIndex2Key(uint32_t index);
/*!
* \return whether the type is derived from
*/
template<typename T>
inline bool derived_from() const;
/*!
* \return whether the node is of type T
* \tparam The type to be checked.
*/
template<typename T>
inline bool is_type() const;
/*!
* \brief Get a NodeRef that holds reference to this Node.
* \return the NodeRef
*/
inline NodeRef GetNodeRef() const;
// node ref can see this
friend class NodeRef;
static constexpr const char* _type_key = "Node";
};
/*! \brief Base class of all node reference object */
class NodeRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Node;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool same_as(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator<(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return whether the expression is null */
inline bool defined() const;
/*! \return the internal type index of IRNode */
inline uint32_t type_index() const;
/*! \return the internal node pointer */
inline const Node* get() const;
/*! \return the internal node pointer */
inline const Node* operator->() const;
/*!
* \brief Downcast this ir node to its actual type (e.g. Add, or
* Select). This returns nullptr if the node is not of the requested
* type. Example usage:
*
* if (const Add *add = node->as<Add>()) {
* // This is an add node
* }
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as() const;
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as_derived() const;
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(NodePtr<Node> node) : node_(node) {}
/*! \brief the internal node object, do not touch */
NodePtr<Node> node_;
};
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
* \brief helper macro to declare type information in a terminal node
*/
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
const uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
const bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
// implementations of inline functions after this
template<typename T>
inline bool Node::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
template<typename T>
inline bool Node::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
inline NodeRef Node::GetNodeRef() const {
return NodeRef(NodePtr<Node>(const_cast<Node*>(this)));
}
inline const Node* NodeRef::get() const {
return node_.get();
}
inline const Node* NodeRef::operator->() const {
return node_.get();
}
inline bool NodeRef::defined() const {
return node_.get() != nullptr;
}
inline bool NodeRef::operator==(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::same_as(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::operator<(const NodeRef& other) const {
return node_.get() < other.node_.get();
}
inline bool NodeRef::operator!=(const NodeRef& other) const {
return node_.get() != other.node_.get();
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
}
inline uint32_t NodeRef::type_index() const {
CHECK(node_.get() != nullptr)
<< "null type";
return get()->type_index();
}
template<typename T>
inline const T* NodeRef::as() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && ptr->is_type<T>()) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
template<typename T>
inline const T* NodeRef::as_derived() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
/*! \brief The hash function for nodes */
struct NodeHash {
size_t operator()(const NodeRef& a) const {
return a.hash();
}
};
/*! \brief The equal comparator for nodes */
struct NodeEqual {
bool operator()(const NodeRef& a, const NodeRef& b) const {
return a.get() == b.get();
}
};
} // namespace tvm
#endif // TVM_NODE_NODE_H_
...@@ -116,7 +116,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { ...@@ -116,7 +116,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
"Conversion only works for NodeRef"); "Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef(); if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>() << "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key(); << " but get " << sptr->type_key();
...@@ -132,7 +132,7 @@ inline TVMArgValue::operator HalideIR::Expr() const { ...@@ -132,7 +132,7 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(static_cast<float>(value_.v_float64)); return Expr(static_cast<float>(value_.v_float64));
} }
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
if (sptr->is_type<IterVarNode>()) { if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var; return IterVar(sptr)->var;
} }
...@@ -145,27 +145,27 @@ inline TVMArgValue::operator HalideIR::Expr() const { ...@@ -145,27 +145,27 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(sptr); return Expr(sptr);
} }
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() { inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<std::shared_ptr<Node> >(); return *ptr<NodePtr<Node> >();
} }
template<typename TNodeRef, typename> template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const { inline bool TVMArgValue::IsNodeType() const {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = NodePtr<Node>& sptr =
*ptr<std::shared_ptr<Node> >(); *ptr<NodePtr<Node> >();
return NodeTypeChecker<TNodeRef>::Check(sptr.get()); return NodeTypeChecker<TNodeRef>::Check(sptr.get());
} }
// extensions for TVMRetValue // extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=( inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) { const NodePtr<Node>& other) {
if (other.get() == nullptr) { if (other.get() == nullptr) {
SwitchToPOD(kNull); SwitchToPOD(kNull);
} else { } else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other); SwitchToClass<NodePtr<Node> >(kNodeHandle, other);
} }
return *this; return *this;
} }
...@@ -174,7 +174,7 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { ...@@ -174,7 +174,7 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
if (!other.defined()) { if (!other.defined()) {
SwitchToPOD(kNull); SwitchToPOD(kNull);
} else { } else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_); SwitchToClass<NodePtr<Node> >(kNodeHandle, other.node_);
} }
return *this; return *this;
} }
...@@ -186,7 +186,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { ...@@ -186,7 +186,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
"Conversion only works for NodeRef"); "Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef(); if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>() << "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key(); << " but get " << sptr->type_key();
...@@ -195,7 +195,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { ...@@ -195,7 +195,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
if (other.defined()) { if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_)); values_[i].v_handle = const_cast<NodePtr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle; type_codes_[i] = kNodeHandle;
} else { } else {
type_codes_[i] = kNull; type_codes_[i] = kNull;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -55,16 +55,16 @@ using NodeEqual = ::tvm::NodeEqual; ...@@ -55,16 +55,16 @@ using NodeEqual = ::tvm::NodeEqual;
* \param NodeName The internal container name. * \param NodeName The internal container name.
* \param NodeRefBase The base type. * \param NodeRefBase The base type.
*/ */
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \ class TypeName : public NodeRefBase { \
public: \ public: \
TypeName() {} \ TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \
const NodeName* operator->() const { \ const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \ return static_cast<const NodeName*>(node_.get()); \
} \ } \
operator bool() { return this->defined(); } \ operator bool() { return this->defined(); } \
using ContainerType = NodeName; \ using ContainerType = NodeName; \
}; };
/*! /*!
...@@ -82,8 +82,6 @@ class SourceNameNode : public Node { ...@@ -82,8 +82,6 @@ class SourceNameNode : public Node {
// override attr visitor // override attr visitor
void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
TVM_DLL static SourceName make(std::string name);
static constexpr const char* _type_key = "relay.SourceName"; static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
}; };
...@@ -98,7 +96,7 @@ class SourceName : public NodeRef { ...@@ -98,7 +96,7 @@ class SourceName : public NodeRef {
SourceName() {} SourceName() {}
/*! \brief constructor from node pointer */ /*! \brief constructor from node pointer */
explicit SourceName(std::shared_ptr<Node> n) : NodeRef(n) {} explicit SourceName(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -109,9 +107,9 @@ class SourceName : public NodeRef { ...@@ -109,9 +107,9 @@ class SourceName : public NodeRef {
* \brief Get an SourceName for a given operator name. * \brief Get an SourceName for a given operator name.
* Will raise an error if the source name has not been registered. * Will raise an error if the source name has not been registered.
* \param name Name of the operator. * \param name Name of the operator.
* \return Reference to a SourceName valid throughout program lifetime. * \return SourceName valid throughout program lifetime.
*/ */
TVM_DLL static const SourceName& Get(const std::string& name); TVM_DLL static SourceName Get(const std::string& name);
/*! \brief specify container node */ /*! \brief specify container node */
using ContainerType = SourceNameNode; using ContainerType = SourceNameNode;
...@@ -176,7 +174,7 @@ template <typename RefType, typename NodeType> ...@@ -176,7 +174,7 @@ template <typename RefType, typename NodeType>
RefType GetRef(const NodeType* ptr) { RefType GetRef(const NodeType* ptr) {
static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value, static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type"); "Can only cast to the ref of same container type");
return RefType(const_cast<NodeType*>(ptr)->shared_from_this()); return RefType(std::move(ptr->GetNodeRef().node_));
} }
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR // TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
......
...@@ -98,15 +98,15 @@ class EnvironmentNode : public RelayNode { ...@@ -98,15 +98,15 @@ class EnvironmentNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
private: private:
/*! \brief A map from string names to global variables that /*! \brief A map from string names to global variables that
* ensures global uniqueness. * ensures global uniqueness.
*/ */
tvm::Map<std::string, GlobalVar> global_map_; tvm::Map<std::string, GlobalVar> global_map_;
}; };
struct Environment : public NodeRef { struct Environment : public NodeRef {
Environment() {} Environment() {}
explicit Environment(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {}
inline EnvironmentNode* operator->() const { inline EnvironmentNode* operator->() const {
return static_cast<EnvironmentNode*>(node_.get()); return static_cast<EnvironmentNode*>(node_.get());
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_
#include <tvm/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
#include "./op.h" #include "./op.h"
...@@ -19,7 +19,7 @@ namespace relay { ...@@ -19,7 +19,7 @@ namespace relay {
* \brief A dynamical functor that dispatches on in the first Expr argument. * \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to * You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function. * define function signatures of Visit Function.
* *
* \sa tvm/ir_functor.h * \sa tvm/ir_functor.h
* *
* \tparam FType function signiture * \tparam FType function signiture
...@@ -30,7 +30,7 @@ template <typename FType> ...@@ -30,7 +30,7 @@ template <typename FType>
class ExprFunctor; class ExprFunctor;
// functions to be overriden. // functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT \ #define EXPR_FUNCTOR_DEFAULT \
{ return VisitExprDefault_(op, std::forward<Args>(args)...); } { return VisitExprDefault_(op, std::forward<Args>(args)...); }
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
...@@ -152,12 +152,12 @@ class ExprMutator ...@@ -152,12 +152,12 @@ class ExprMutator
Expr VisitExpr_(const CallNode* call_node, const Expr& e) override; Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
Expr VisitExpr_(const LetNode* op, const Expr& e) override; Expr VisitExpr_(const LetNode* op, const Expr& e) override;
Expr VisitExpr_(const IfNode* op, const Expr& e) override; Expr VisitExpr_(const IfNode* op, const Expr& e) override;
/*! \brief Used to visit the types inside of expressions. /*! \brief Used to visit the types inside of expressions.
* *
* Can be overloaded to transform the types in arbitrary * Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type * ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately. * visitor for types which transform them appropriately.
*/ */
virtual Type VisitType(const Type& t); virtual Type VisitType(const Type& t);
private: private:
......
...@@ -90,7 +90,7 @@ class Op : public relay::Expr { ...@@ -90,7 +90,7 @@ class Op : public relay::Expr {
/*! \brief default constructor */ /*! \brief default constructor */
Op() {} Op() {}
/*! \brief constructor from node pointer */ /*! \brief constructor from node pointer */
explicit Op(std::shared_ptr<Node> n) : Expr(n) {} explicit Op(NodePtr<Node> n) : Expr(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -149,9 +149,9 @@ class OpRegistry { ...@@ -149,9 +149,9 @@ class OpRegistry {
const std::string& description); const std::string& description);
/*! /*!
* \brief Attach the type function corresponding to the return type. * \brief Attach the type function corresponding to the return type.
* \param rel_name The type relation name to register. * \param rel_name The type relation name to register.
* \param type_rel_func The backing relation function which can solve an arbitrary * \param type_rel_func The backing relation function which can solve an arbitrary
* relation on variables. * relation on variables.
* \return reference to self. * \return reference to self.
*/ */
inline OpRegistry& add_type_rel( inline OpRegistry& add_type_rel(
...@@ -338,7 +338,7 @@ inline OpRegistry& OpRegistry::describe( ...@@ -338,7 +338,7 @@ inline OpRegistry& OpRegistry::describe(
inline OpRegistry& OpRegistry::add_argument(const std::string& name, inline OpRegistry& OpRegistry::add_argument(const std::string& name,
const std::string& type, const std::string& type,
const std::string& description) { const std::string& description) {
std::shared_ptr<AttrFieldInfoNode> n = std::make_shared<AttrFieldInfoNode>(); auto n = make_node<AttrFieldInfoNode>();
n->name = name; n->name = name;
n->type_info = type; n->type_info = type;
n->description = description; n->description = description;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include "./base.h" #include "./base.h"
...@@ -37,7 +37,7 @@ class TypeNode : public RelayNode { ...@@ -37,7 +37,7 @@ class TypeNode : public RelayNode {
class Type : public NodeRef { class Type : public NodeRef {
public: public:
Type() {} Type() {}
explicit Type(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} explicit Type(NodePtr<tvm::Node> p) : NodeRef(p) {}
using ContainerType = TypeNode; using ContainerType = TypeNode;
}; };
......
...@@ -263,12 +263,16 @@ struct NDArray::Container { ...@@ -263,12 +263,16 @@ struct NDArray::Container {
// the usages of functions are documented in place. // the usages of functions are documented in place.
inline NDArray::NDArray(Container* data) inline NDArray::NDArray(Container* data)
: data_(data) { : data_(data) {
data_->IncRef(); if (data != nullptr) {
data_->IncRef();
}
} }
inline NDArray::NDArray(const NDArray& other) inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) { : data_(other.data_) {
data_->IncRef(); if (data_ != nullptr) {
data_->IncRef();
}
} }
inline void NDArray::reset() { inline void NDArray::reset() {
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/runtime/node_base.h
* \brief Base data structure for Node.
*
* \note Node is not a runtime feature.
* This file only exposes the signature of NodePtr for PackedFunc.
*/
#ifndef TVM_RUNTIME_NODE_BASE_H_
#define TVM_RUNTIME_NODE_BASE_H_
#include <utility>
#include <atomic>
namespace tvm {
// forward declarations
template<typename T>
class NodePtr;
class Node;
class NodeRef;
/*!
* \brief Base class of Node for runtime destructor purposes.
*
* Node is a reference counted object which is used to construct AST.
* Each node is backed by a custom deleter, which deletes the object.
* Do not call create raw Node pointer, always use tvm::make_node.
*
* \note In most cases, please inheritate tvm::Node.
* \sa Node, NodePtr, make_node
*/
class NodeBase {
public:
/*!
* \brief type of NodeBase deleter
* \param self pointer to the NodeBase.
*/
typedef void (*FDeleter)(NodeBase* self);
protected:
// default constructor and copy constructor
NodeBase() {}
// override the copy and assign constructors to do nothing.
// This is to make sure only contents, but not deleter and ref_counter
// are copied when a child class copies itself.
NodeBase(const NodeBase& other) { // NOLINT(*)
}
NodeBase(NodeBase&& other) { // NOLINT(*)
}
NodeBase& operator=(const NodeBase& other) { //NOLINT(*)
return *this;
}
NodeBase& operator=(NodeBase&& other) { //NOLINT(*)
return *this;
}
private:
/*! \brief Internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
* \brief deleter of this object to enable customized allocation.
* If the deleter is nullptr, no deletion will be performed.
* The creator of the Node must always set the deleter field properly.
*/
FDeleter deleter_ = nullptr;
// reference counting functions
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
int use_count() const {
return ref_counter_.load(std::memory_order_relaxed);
}
// friend declaration
template<typename>
friend class NodePtr;
template<typename Y, typename... Args>
friend NodePtr<Y> make_node(Args&&...);
};
/*!
* \brief Smart pointer for Node containers,
* must be subclass of NodeBase
* \tparam T the content data type.
*/
template<typename T>
class NodePtr {
public:
/*! \brief default constructor */
NodePtr() {}
/*! \brief default constructor */
NodePtr(std::nullptr_t) {} // NOLINT(*)
/*!
* \brief copy constructor
* \param other The value to be moved
*/
NodePtr(const NodePtr<T>& other) // NOLINT(*)
: NodePtr(other.data_) {
}
/*!
* \brief copy constructor
* \param other The value to be moved
*/
template<typename Y>
NodePtr(const NodePtr<Y>& other) // NOLINT(*)
: NodePtr(other.data_) {
static_assert(std::is_base_of<T, Y>::value,
"can only assign of child class NodePtr to parent");
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
NodePtr(NodePtr<T>&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
template<typename Y>
NodePtr(NodePtr<Y>&& other) // NOLINT(*)
: data_(other.data_) {
static_assert(std::is_base_of<T, Y>::value,
"can only assign of child class NodePtr to parent");
other.data_ = nullptr;
}
/*! \brief destructor */
~NodePtr() {
this->reset();
}
/*!
* \brief Swap this array with another NDArray
* \param other The other NDArray
*/
void swap(NodePtr<T>& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
* \return Get the content of the pointer
*/
T* get() const {
return static_cast<T*>(data_);
}
/*!
* \return The pointer
*/
T* operator->() const {
return get();
}
/*!
* \return The reference
*/
T& operator*() const { // NOLINT(*)
return *get();
}
/*!
* \brief copy assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
NodePtr<T>& operator=(const NodePtr<T>& other) { // NOLINT(*)
// takes in plane operator to enable copy elison.
// copy-and-swap idiom
NodePtr(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief move assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
NodePtr<T>& operator=(NodePtr<T>&& other) { // NOLINT(*)
// copy-and-swap idiom
NodePtr(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \brief reset the content of ptr to be nullptr */
void reset() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
}
/*! \return The use count of the ptr, for debug purposes */
int use_count() const {
return data_ != nullptr ? data_->use_count() : 0;
}
/*! \return whether the reference is unique */
bool unique() const {
return data_ != nullptr && data_->use_count() == 1;
}
/*! \return Whether two NodePtr do not equals each other */
bool operator==(const NodePtr<T>& other) const {
return data_ == other.data_;
}
/*! \return Whether two NodePtr equals each other */
bool operator!=(const NodePtr<T>& other) const {
return data_ != other.data_;
}
/*! \return Whether the pointer is nullptr */
bool operator==(std::nullptr_t null) const {
return data_ == nullptr;
}
/*! \return Whether the pointer is not nullptr */
bool operator!=(std::nullptr_t null) const {
return data_ != nullptr;
}
private:
/*! \brief internal pointer field */
NodeBase* data_{nullptr};
/*!
* \brief constructor from NodeBase
* \param data The node base pointer
*/
explicit NodePtr(NodeBase* data)
: data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
// friend declaration
friend class Node;
template<typename>
friend class NodePtr;
template<typename Y, typename... Args>
friend NodePtr<Y> make_node(Args&&...);
};
} // namespace tvm
#endif // TVM_RUNTIME_NODE_BASE_H_
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "module.h" #include "module.h"
#include "ndarray.h" #include "ndarray.h"
#include "node_base.h"
namespace HalideIR { namespace HalideIR {
// Forward declare type for extensions // Forward declare type for extensions
...@@ -31,12 +32,6 @@ struct Expr; ...@@ -31,12 +32,6 @@ struct Expr;
#endif #endif
namespace tvm { namespace tvm {
// Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef
// as long as it is not used.
class Node;
class NodeRef;
namespace runtime { namespace runtime {
// forward declarations // forward declarations
class TVMArgs; class TVMArgs;
...@@ -549,7 +544,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -549,7 +544,7 @@ class TVMArgValue : public TVMPODValue_ {
inline operator HalideIR::Type() const; inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const; inline operator HalideIR::Expr() const;
// get internal node ptr, if it is node // get internal node ptr, if it is node
inline std::shared_ptr<Node>& node_sptr(); inline NodePtr<Node>& node_sptr();
}; };
/*! /*!
...@@ -745,7 +740,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -745,7 +740,7 @@ class TVMRetValue : public TVMPODValue_ {
template<typename TNodeRef> template<typename TNodeRef>
inline TNodeRef AsNodeRef() const; inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other); inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other); inline TVMRetValue& operator=(const NodePtr<Node>& other);
// type related // type related
inline operator HalideIR::Type() const; inline operator HalideIR::Type() const;
inline TVMRetValue& operator=(const HalideIR::Type& other); inline TVMRetValue& operator=(const HalideIR::Type& other);
...@@ -775,8 +770,8 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -775,8 +770,8 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
case kNodeHandle: { case kNodeHandle: {
SwitchToClass<std::shared_ptr<Node> >( SwitchToClass<NodePtr<Node> >(
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >()); kNodeHandle, *other.template ptr<NodePtr<Node> >());
break; break;
} }
default: { default: {
...@@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ {
case kStr: delete ptr<std::string>(); break; case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break; case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break; case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; case kNodeHandle: delete ptr<NodePtr<Node> >(); break;
case kNDArrayContainer: { case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break; break;
......
...@@ -36,7 +36,7 @@ enum AttachType : int { ...@@ -36,7 +36,7 @@ enum AttachType : int {
class Stage : public NodeRef { class Stage : public NodeRef {
public: public:
Stage() {} Stage() {}
explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Stage(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief create a new schedule for op. * \brief create a new schedule for op.
* \param op The operator in the schedule * \param op The operator in the schedule
...@@ -260,7 +260,7 @@ class Stage : public NodeRef { ...@@ -260,7 +260,7 @@ class Stage : public NodeRef {
class Schedule : public NodeRef { class Schedule : public NodeRef {
public: public:
Schedule() {} Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Schedule(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Get a copy of current schedule. * \brief Get a copy of current schedule.
* \return The copied schedule. * \return The copied schedule.
...@@ -383,7 +383,7 @@ class Schedule : public NodeRef { ...@@ -383,7 +383,7 @@ class Schedule : public NodeRef {
class IterVarRelation : public NodeRef { class IterVarRelation : public NodeRef {
public: public:
IterVarRelation() {} IterVarRelation() {}
explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {} explicit IterVarRelation(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -397,7 +397,7 @@ class IterVarRelation : public NodeRef { ...@@ -397,7 +397,7 @@ class IterVarRelation : public NodeRef {
class IterVarAttr : public NodeRef { class IterVarAttr : public NodeRef {
public: public:
IterVarAttr() {} IterVarAttr() {}
explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {} explicit IterVarAttr(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#ifndef TVM_TENSOR_H_ #ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_ #define TVM_TENSOR_H_
#include <tvm/container.h>
#include <ir/FunctionBase.h> #include <ir/FunctionBase.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -15,6 +14,7 @@ ...@@ -15,6 +14,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "arithmetic.h" #include "arithmetic.h"
#include "node/container.h"
namespace tvm { namespace tvm {
...@@ -33,7 +33,7 @@ class Tensor : public NodeRef { ...@@ -33,7 +33,7 @@ class Tensor : public NodeRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Tensor(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -118,7 +118,7 @@ class Operation : public FunctionRef { ...@@ -118,7 +118,7 @@ class Operation : public FunctionRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Operation() {} Operation() {}
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {} explicit Operation(NodePtr<Node> n) : FunctionRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -19,7 +19,7 @@ class TensorIntrinNode; ...@@ -19,7 +19,7 @@ class TensorIntrinNode;
class TensorIntrin : public NodeRef { class TensorIntrin : public NodeRef {
public: public:
TensorIntrin() {} TensorIntrin() {}
explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {} explicit TensorIntrin(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -94,7 +94,7 @@ class CompileEngine { ...@@ -94,7 +94,7 @@ class CompileEngine {
return it->second->graph_func; return it->second->graph_func;
} }
GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx); GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>(); auto n = tvm::make_node<GraphCacheEntryNode>();
n->graph_func = f; n->graph_func = f;
n->use_count = 1; n->use_count = 1;
n->master_idx = master_idx; n->master_idx = master_idx;
...@@ -107,8 +107,7 @@ class CompileEngine { ...@@ -107,8 +107,7 @@ class CompileEngine {
Array<NodeRef> items; Array<NodeRef> items;
for (auto& kv : cache_) { for (auto& kv : cache_) {
items.push_back(kv.first); items.push_back(kv.first);
std::shared_ptr<GraphCacheEntryNode> n = auto n = tvm::make_node<GraphCacheEntryNode>(*(kv.second.operator->()));
std::make_shared<GraphCacheEntryNode>(*(kv.second.operator->()));
items.push_back(GraphCacheEntry(n)); items.push_back(GraphCacheEntry(n));
} }
return items; return items;
...@@ -126,7 +125,7 @@ class CompileEngine { ...@@ -126,7 +125,7 @@ class CompileEngine {
// Set the given function on given graph key. // Set the given function on given graph key.
void Set(const GraphKey& key, GraphFunc func) { void Set(const GraphKey& key, GraphFunc func) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>(); auto n = tvm::make_node<GraphCacheEntryNode>();
n->graph_func = func; n->graph_func = func;
n->use_count = 1; n->use_count = 1;
cache_[key] = GraphCacheEntry(n); cache_[key] = GraphCacheEntry(n);
...@@ -265,7 +264,7 @@ class CompileEngine { ...@@ -265,7 +264,7 @@ class CompileEngine {
graph, inputs, target, master_idx, graph, inputs, target, master_idx,
&readable_name, &outputs); &readable_name, &outputs);
std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>(); auto gf = tvm::make_node<GraphFuncNode>();
gf->target = target; gf->target = target;
gf->func_name = GetUniqeName(readable_name); gf->func_name = GetUniqeName(readable_name);
gf->inputs = inputs; gf->inputs = inputs;
......
...@@ -71,7 +71,7 @@ struct GraphCacheEntryNode : public tvm::Node { ...@@ -71,7 +71,7 @@ struct GraphCacheEntryNode : public tvm::Node {
class GraphCacheEntry : public ::tvm::NodeRef { class GraphCacheEntry : public ::tvm::NodeRef {
public: public:
GraphCacheEntry() {} GraphCacheEntry() {}
explicit GraphCacheEntry(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {}
GraphCacheEntryNode* operator->() { GraphCacheEntryNode* operator->() {
return static_cast<GraphCacheEntryNode*>(node_.get()); return static_cast<GraphCacheEntryNode*>(node_.get());
} }
......
...@@ -74,8 +74,7 @@ bool GraphKeyEqual::Equal(const GraphKey& a, ...@@ -74,8 +74,7 @@ bool GraphKeyEqual::Equal(const GraphKey& a,
GraphKey GraphKeyNode::make(Graph graph, GraphKey GraphKeyNode::make(Graph graph,
tvm::Array<Tensor> inputs, tvm::Array<Tensor> inputs,
std::string target) { std::string target) {
std::shared_ptr<GraphKeyNode> n auto n = tvm::make_node<GraphKeyNode>();
= std::make_shared<GraphKeyNode>();
n->graph = std::move(graph); n->graph = std::move(graph);
n->inputs = inputs; n->inputs = inputs;
n->target = std::move(target); n->target = std::move(target);
......
...@@ -91,8 +91,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") ...@@ -91,8 +91,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
tvm::runtime::NDArray temp; tvm::runtime::NDArray temp;
temp.Load(strm); temp.Load(strm);
std::shared_ptr<NDArrayWrapperNode> n auto n = tvm::make_node<NDArrayWrapperNode>();
= std::make_shared<NDArrayWrapperNode>();
n->name = std::move(names[i]); n->name = std::move(names[i]);
n->array = temp; n->array = temp;
ret.push_back(NDArrayWrapper(n)); ret.push_back(NDArrayWrapper(n));
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/node/memory.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <vector> #include <vector>
......
...@@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") ...@@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
const Array<Tensor>& out_info) const Array<Tensor>& out_info)
-> Array<Tensor> { -> Array<Tensor> {
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info);
if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) { if ((*ret.ptr<::tvm::NodePtr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
return {ret.operator Tensor()}; return {ret.operator Tensor()};
} else { } else {
return ret; return ret;
......
...@@ -45,11 +45,11 @@ TVM_REGISTER_API("_str") ...@@ -45,11 +45,11 @@ TVM_REGISTER_API("_str")
TVM_REGISTER_API("_Array") TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<std::shared_ptr<Node> > data; std::vector<NodePtr<Node> > data;
for (int i = 0; i < args.size(); ++i) { for (int i = 0; i < args.size(); ++i) {
data.push_back(args[i].node_sptr()); data.push_back(args[i].node_sptr());
} }
auto node = std::make_shared<ArrayNode>(); auto node = make_node<ArrayNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = node;
}); });
...@@ -87,7 +87,7 @@ TVM_REGISTER_API("_Map") ...@@ -87,7 +87,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].operator std::string(), data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].node_sptr())); args[i + 1].node_sptr()));
} }
auto node = std::make_shared<StrMapNode>(); auto node = make_node<StrMapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = node;
} else { } else {
...@@ -101,7 +101,7 @@ TVM_REGISTER_API("_Map") ...@@ -101,7 +101,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].node_sptr(), data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr())); args[i + 1].node_sptr()));
} }
auto node = std::make_shared<MapNode>(); auto node = make_node<MapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = node;
} }
...@@ -163,7 +163,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -163,7 +163,7 @@ TVM_REGISTER_API("_MapItems")
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
if (sptr->is_type<MapNode>()) { if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>(); auto rkvs = make_node<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
...@@ -171,7 +171,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -171,7 +171,7 @@ TVM_REGISTER_API("_MapItems")
*ret = rkvs; *ret = rkvs;
} else { } else {
auto* n = static_cast<const StrMapNode*>(sptr.get()); auto* n = static_cast<const StrMapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>(); auto rkvs = make_node<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first).node_); rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
......
...@@ -28,7 +28,7 @@ struct TVMAPIThreadLocalEntry { ...@@ -28,7 +28,7 @@ struct TVMAPIThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
using TVMAPINode = std::shared_ptr<Node>; using TVMAPINode = NodePtr<Node>;
struct APIAttrGetter : public AttrVisitor { struct APIAttrGetter : public AttrVisitor {
std::string skey; std::string skey;
......
...@@ -48,7 +48,7 @@ struct ComExprEntry { ...@@ -48,7 +48,7 @@ struct ComExprEntry {
}; };
// canonical expression for communicative expression. // canonical expression for communicative expression.
struct ComExprNode { struct ComExprNode : public NodeBase {
// base constant value. // base constant value.
int64_t base{0}; int64_t base{0};
// The values to be sumed. // The values to be sumed.
...@@ -60,7 +60,7 @@ struct ComExpr { ...@@ -60,7 +60,7 @@ struct ComExpr {
public: public:
// constructor // constructor
ComExpr() {} ComExpr() {}
explicit ComExpr(std::shared_ptr<ComExprNode> ptr) : ptr_(ptr) {} explicit ComExpr(NodePtr<ComExprNode> ptr) : ptr_(ptr) {}
// get member // get member
ComExprNode* operator->() const { ComExprNode* operator->() const {
return ptr_.get(); return ptr_.get();
...@@ -106,7 +106,7 @@ struct ComExpr { ...@@ -106,7 +106,7 @@ struct ComExpr {
} }
private: private:
std::shared_ptr<ComExprNode> ptr_; NodePtr<ComExprNode> ptr_;
}; };
// binary comparison op. // binary comparison op.
...@@ -173,7 +173,7 @@ class Canonical::Internal : public IRMutator { ...@@ -173,7 +173,7 @@ class Canonical::Internal : public IRMutator {
if (sum.defined()) return sum; if (sum.defined()) return sum;
const int64_t *v1 = as_const_int(value); const int64_t *v1 = as_const_int(value);
const uint64_t *v2 = as_const_uint(value); const uint64_t *v2 = as_const_uint(value);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); auto n = make_node<ComExprNode>();
if (v1) { if (v1) {
n->base = *v1; n->base = *v1;
} else if (v2) { } else if (v2) {
...@@ -471,8 +471,8 @@ class Canonical::Internal : public IRMutator { ...@@ -471,8 +471,8 @@ class Canonical::Internal : public IRMutator {
Type type = coeff.type(); Type type = coeff.type();
int64_t value = GetConstIntValue(coeff); int64_t value = GetConstIntValue(coeff);
if (value < 0) return {}; if (value < 0) return {};
std::shared_ptr<ComExprNode> xnode = std::make_shared<ComExprNode>(); auto xnode = make_node<ComExprNode>();
std::shared_ptr<ComExprNode> ynode = std::make_shared<ComExprNode>(); auto ynode = make_node<ComExprNode>();
if (a->base % value == 0) { if (a->base % value == 0) {
xnode->base = a->base; xnode->base = a->base;
} else { } else {
...@@ -507,7 +507,7 @@ class Canonical::Internal : public IRMutator { ...@@ -507,7 +507,7 @@ class Canonical::Internal : public IRMutator {
std::vector<ComExpr> pair = TryLinearEquation(a, v); std::vector<ComExpr> pair = TryLinearEquation(a, v);
if (pair.size() == 0) { if (pair.size() == 0) {
int64_t value = GetConstIntValue(v); int64_t value = GetConstIntValue(v);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); auto n = make_node<ComExprNode>();
n->base = a->base % value; n->base = a->base % value;
for (auto e : a->elem) { for (auto e : a->elem) {
if (e.scale % value == 0) continue; if (e.scale % value == 0) continue;
...@@ -554,8 +554,7 @@ class Canonical::Internal : public IRMutator { ...@@ -554,8 +554,7 @@ class Canonical::Internal : public IRMutator {
if (value == 0) { if (value == 0) {
return make_zero(v.type()); return make_zero(v.type());
} }
std::shared_ptr<ComExprNode> vsum = auto vsum = make_node<ComExprNode>(*a.operator->());
std::make_shared<ComExprNode>(*a.operator->());
vsum->base *= value; vsum->base *= value;
for (auto& e : vsum->elem) { for (auto& e : vsum->elem) {
e.scale *= value; e.scale *= value;
...@@ -576,7 +575,7 @@ class Canonical::Internal : public IRMutator { ...@@ -576,7 +575,7 @@ class Canonical::Internal : public IRMutator {
ComExpr SumAdd_(const ComExpr& suma, ComExpr SumAdd_(const ComExpr& suma,
const ComExpr& sumb, const ComExpr& sumb,
int bscale) { int bscale) {
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); auto n = make_node<ComExprNode>();
n->base = suma->base + sumb->base * bscale; n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb; // merge of suma and sumb;
size_t i = 0, j = 0; size_t i = 0, j = 0;
......
...@@ -329,7 +329,7 @@ inline IntSet AsStrideSet(IntSet a) { ...@@ -329,7 +329,7 @@ inline IntSet AsStrideSet(IntSet a) {
if (a.as<StrideSet>()) return a; if (a.as<StrideSet>()) return a;
const IntervalSet* s = a.as<IntervalSet>(); const IntervalSet* s = a.as<IntervalSet>();
CHECK(s->i.is_bounded()); CHECK(s->i.is_bounded());
std::shared_ptr<StrideSet> n = std::make_shared<StrideSet>(); NodePtr<StrideSet> n = make_node<StrideSet>();
n->base = s->i; n->base = s->i;
return IntSet(n); return IntSet(n);
} }
...@@ -348,7 +348,7 @@ inline IntSet CombineSets<Add>(IntSet a, IntSet b) { ...@@ -348,7 +348,7 @@ inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
b = AsStrideSet(b); b = AsStrideSet(b);
const StrideSet* a_stride = a.as<StrideSet>(); const StrideSet* a_stride = a.as<StrideSet>();
const StrideSet* b_stride = b.as<StrideSet>(); const StrideSet* b_stride = b.as<StrideSet>();
auto n = std::make_shared<StrideSet>(*a_stride); auto n = make_node<StrideSet>(*a_stride);
for (size_t i = 0; i < b_stride->extents.size(); ++i) { for (size_t i = 0; i < b_stride->extents.size(); ++i) {
n->extents.push_back(b_stride->extents[i]); n->extents.push_back(b_stride->extents[i]);
n->strides.push_back(b_stride->strides[i]); n->strides.push_back(b_stride->strides[i]);
......
...@@ -21,14 +21,14 @@ struct IntervalSet : public IntSetNode { ...@@ -21,14 +21,14 @@ struct IntervalSet : public IntSetNode {
Interval i; Interval i;
static IntSet make(Interval i) { static IntSet make(Interval i) {
std::shared_ptr<IntervalSet> n = NodePtr<IntervalSet> n =
std::make_shared<IntervalSet>(); make_node<IntervalSet>();
n->i = i; n->i = i;
return IntSet(n); return IntSet(n);
} }
static IntSet make(Expr min, Expr max) { static IntSet make(Expr min, Expr max) {
std::shared_ptr<IntervalSet> n = NodePtr<IntervalSet> n =
std::make_shared<IntervalSet>(); make_node<IntervalSet>();
n->i.min = min; n->i.min = min;
n->i.max = max; n->i.max = max;
return IntSet(n); return IntSet(n);
......
...@@ -159,7 +159,7 @@ IntSet EvalModular(const Expr& e, ...@@ -159,7 +159,7 @@ IntSet EvalModular(const Expr& e,
CHECK(m) << "Need to pass ModularSet for Modular Analysis"; CHECK(m) << "Need to pass ModularSet for Modular Analysis";
mmap[kv.first.get()] = m->e; mmap[kv.first.get()] = m->e;
} }
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>(); NodePtr<ModularSet> n = make_node<ModularSet>();
n->e = ModularEvaluator(mmap)(e); n->e = ModularEvaluator(mmap)(e);
return IntSet(n); return IntSet(n);
} }
......
...@@ -32,7 +32,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -32,7 +32,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
*/ */
Target CreateTarget(const std::string& target_name, Target CreateTarget(const std::string& target_name,
const std::vector<std::string>& options) { const std::vector<std::string>& options) {
auto target = Target(std::make_shared<TargetNode>()); auto target = Target(make_node<TargetNode>());
auto t = static_cast<TargetNode*>(target.node_.get()); auto t = static_cast<TargetNode*>(target.node_.get());
t->target_name = target_name; t->target_name = target_name;
...@@ -475,7 +475,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -475,7 +475,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
} }
BuildConfig build_config() { BuildConfig build_config() {
return BuildConfig(std::make_shared<BuildConfigNode>()); return BuildConfig(make_node<BuildConfigNode>());
} }
/*! \brief Entry to hold the BuildConfig context stack. */ /*! \brief Entry to hold the BuildConfig context stack. */
...@@ -533,7 +533,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -533,7 +533,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
struct GenericFunc::Manager { struct GenericFunc::Manager {
std::unordered_map<std::string, std::shared_ptr<Node> > fmap; std::unordered_map<std::string, NodePtr<Node> > fmap;
// mutex // mutex
std::mutex mutex; std::mutex mutex;
...@@ -551,7 +551,7 @@ GenericFunc GenericFunc::Get(const std::string& name) { ...@@ -551,7 +551,7 @@ GenericFunc GenericFunc::Get(const std::string& name) {
std::lock_guard<std::mutex>(m->mutex); std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name); auto it = m->fmap.find(name);
if (it == m->fmap.end()) { if (it == m->fmap.end()) {
auto f = std::make_shared<GenericFuncNode>(); auto f = make_node<GenericFuncNode>();
f->name_ = name; f->name_ = name;
m->fmap[name] = f; m->fmap[name] = f;
return GenericFunc(f); return GenericFunc(f);
...@@ -669,7 +669,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") ...@@ -669,7 +669,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo")
TVM_REGISTER_API("_GenericFuncCreate") TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(std::make_shared<GenericFuncNode>()); *ret = GenericFunc(make_node<GenericFuncNode>());
}); });
TVM_REGISTER_API("_GenericFuncGetGlobal") TVM_REGISTER_API("_GenericFuncGetGlobal")
......
...@@ -17,14 +17,14 @@ using namespace ir; ...@@ -17,14 +17,14 @@ using namespace ir;
ControlSignal ControlSignalNode::make( ControlSignal ControlSignalNode::make(
ControlSignalType type, int advance_size) { ControlSignalType type, int advance_size) {
auto n = std::make_shared<ControlSignalNode>(); auto n = make_node<ControlSignalNode>();
n->ctrl_type = type; n->ctrl_type = type;
n->advance_size = advance_size; n->advance_size = advance_size;
return ControlSignal(n); return ControlSignal(n);
} }
StageInput StageInputNode::make(Var var, StageInputType input_type) { StageInput StageInputNode::make(Var var, StageInputType input_type) {
std::shared_ptr<StageInputNode> n = std::make_shared<StageInputNode>(); NodePtr<StageInputNode> n = make_node<StageInputNode>();
n->var = var; n->var = var;
n->input_type = input_type; n->input_type = input_type;
return StageInput(n); return StageInput(n);
...@@ -81,7 +81,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -81,7 +81,7 @@ class PipelineExtractor: public IRVisitor {
arg_handle_[arg.get()] = arg; arg_handle_[arg.get()] = arg;
} }
} }
pipeline_ = std::make_shared<PipelineNode>(); pipeline_ = make_node<PipelineNode>();
this->Visit(f->body); this->Visit(f->body);
// setup channels // setup channels
for (const auto &kv : cmap_) { for (const auto &kv : cmap_) {
...@@ -113,7 +113,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -113,7 +113,7 @@ class PipelineExtractor: public IRVisitor {
if (cb.node != nullptr) { if (cb.node != nullptr) {
CHECK(cb.node->channel.same_as(ch)); CHECK(cb.node->channel.same_as(ch));
} else { } else {
cb.node = std::make_shared<ChannelBlockNode>(); cb.node = make_node<ChannelBlockNode>();
cb.node->channel = ch; cb.node->channel = ch;
} }
if (op->attr_key == attr::channel_read_scope) { if (op->attr_key == attr::channel_read_scope) {
...@@ -167,8 +167,8 @@ class PipelineExtractor: public IRVisitor { ...@@ -167,8 +167,8 @@ class PipelineExtractor: public IRVisitor {
// The replace logic // The replace logic
StageInputReplacer repl(var_info_); StageInputReplacer repl(var_info_);
// Setup the compute block. // Setup the compute block.
std::shared_ptr<ComputeBlockNode> compute = NodePtr<ComputeBlockNode> compute =
std::make_shared<ComputeBlockNode>(); make_node<ComputeBlockNode>();
compute->loop = Array<Stmt>(loop_); compute->loop = Array<Stmt>(loop_);
// setup the advance triggers // setup the advance triggers
for (const auto& e : trigger_) { for (const auto& e : trigger_) {
...@@ -180,8 +180,8 @@ class PipelineExtractor: public IRVisitor { ...@@ -180,8 +180,8 @@ class PipelineExtractor: public IRVisitor {
} else { } else {
ch = Channel(attr->node.node_); ch = Channel(attr->node.node_);
} }
std::shared_ptr<SignalTriggerNode> trigger NodePtr<SignalTriggerNode> trigger
= std::make_shared<SignalTriggerNode>(); = make_node<SignalTriggerNode>();
trigger->channel_var = ch->handle_var; trigger->channel_var = ch->handle_var;
// predicate for the trigger // predicate for the trigger
Expr predicate = const_true(); Expr predicate = const_true();
...@@ -249,7 +249,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -249,7 +249,7 @@ class PipelineExtractor: public IRVisitor {
CHECK(!cmap_.count(var)) CHECK(!cmap_.count(var))
<< "Multiple access to the same handle"; << "Multiple access to the same handle";
ChannelEntry& cb = cmap_[var]; ChannelEntry& cb = cmap_[var];
cb.node = std::make_shared<ChannelBlockNode>(); cb.node = make_node<ChannelBlockNode>();
cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype); cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
return cb.node->channel; return cb.node->channel;
} }
...@@ -257,7 +257,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -257,7 +257,7 @@ class PipelineExtractor: public IRVisitor {
private: private:
// The channel information. // The channel information.
struct ChannelEntry { struct ChannelEntry {
std::shared_ptr<ChannelBlockNode> node; NodePtr<ChannelBlockNode> node;
int read_ref_count{0}; int read_ref_count{0};
int write_ref_count{0}; int write_ref_count{0};
}; };
...@@ -276,7 +276,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -276,7 +276,7 @@ class PipelineExtractor: public IRVisitor {
// The argument handle map // The argument handle map
std::unordered_map<const Variable*, Var> arg_handle_; std::unordered_map<const Variable*, Var> arg_handle_;
// The result block. // The result block.
std::shared_ptr<PipelineNode> pipeline_; NodePtr<PipelineNode> pipeline_;
}; };
Pipeline MakePipeline(LoweredFunc f) { Pipeline MakePipeline(LoweredFunc f) {
......
...@@ -50,7 +50,7 @@ inline VPIHandleNode* VPIHandle::get() const { ...@@ -50,7 +50,7 @@ inline VPIHandleNode* VPIHandle::get() const {
VPIHandle VPIHandleCreate( VPIHandle VPIHandleCreate(
const std::shared_ptr<VPISessionEntry>& sess, const std::shared_ptr<VPISessionEntry>& sess,
VPIRawHandle handle) { VPIRawHandle handle) {
std::shared_ptr<VPIHandleNode> n = std::make_shared<VPIHandleNode>(); auto n = make_node<VPIHandleNode>();
n->sess = sess; n->sess = sess;
n->handle = handle; n->handle = handle;
return VPIHandle(n); return VPIHandle(n);
...@@ -102,7 +102,7 @@ int VPIGetIntProp(VPIHandleNode* h, int code) { ...@@ -102,7 +102,7 @@ int VPIGetIntProp(VPIHandleNode* h, int code) {
} }
VPISession VPISession::make(int h_pipe_read, int h_pipe_write) { VPISession VPISession::make(int h_pipe_read, int h_pipe_write) {
std::shared_ptr<VPISessionNode> n = std::make_shared<VPISessionNode>(); auto n = make_node<VPISessionNode>();
n->sess = std::make_shared<VPISessionEntry>(h_pipe_read, h_pipe_write); n->sess = std::make_shared<VPISessionEntry>(h_pipe_read, h_pipe_write);
n->sess->in_control = true; n->sess->in_control = true;
VPISession sess(n); VPISession sess(n);
......
...@@ -27,7 +27,7 @@ using runtime::PackedFunc; ...@@ -27,7 +27,7 @@ using runtime::PackedFunc;
class VPISession : public NodeRef { class VPISession : public NodeRef {
public: public:
VPISession() {} VPISession() {}
explicit VPISession(std::shared_ptr<Node> n) : NodeRef(n) {} explicit VPISession(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Get handle by name. * \brief Get handle by name.
* \param name The name of the handle. * \param name The name of the handle.
...@@ -63,7 +63,7 @@ class VPISession : public NodeRef { ...@@ -63,7 +63,7 @@ class VPISession : public NodeRef {
class VPIHandle : public NodeRef { class VPIHandle : public NodeRef {
public: public:
VPIHandle() {} VPIHandle() {}
explicit VPIHandle(std::shared_ptr<Node> n) : NodeRef(n) {} explicit VPIHandle(NodePtr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief Get handle by name. * \brief Get handle by name.
* \param name The name of the handle. * \param name The name of the handle.
......
...@@ -11,10 +11,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -11,10 +11,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "EnvFunc(" << op->name << ")"; p->stream << "EnvFunc(" << op->name << ")";
}); });
std::shared_ptr<EnvFuncNode> CreateEnvNode(const std::string& name) { NodePtr<EnvFuncNode> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name); auto* f = runtime::Registry::Get(name);
CHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
std::shared_ptr<EnvFuncNode> n = std::make_shared<EnvFuncNode>(); NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>();
n->func = *f; n->func = *f;
n->name = name; n->name = name;
return n; return n;
......
...@@ -30,7 +30,7 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { ...@@ -30,7 +30,7 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
} }
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) { Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
std::shared_ptr<DictAttrsNode> n = std::make_shared<DictAttrsNode>(); NodePtr<DictAttrsNode> n = make_node<DictAttrsNode>();
n->dict = std::move(dict); n->dict = std::move(dict);
return Attrs(n); return Attrs(n);
} }
......
...@@ -289,7 +289,7 @@ Buffer Buffer::MakeStrideView() const { ...@@ -289,7 +289,7 @@ Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this; if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this; if ((*this)->shape.size() == 0) return *this;
std::vector<Expr> temp; std::vector<Expr> temp;
auto n = std::make_shared<BufferNode>(*operator->()); auto n = make_node<BufferNode>(*operator->());
Expr acc = make_const(n->DefaultIndexType(), 1); Expr acc = make_const(n->DefaultIndexType(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) { for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc); temp.push_back(acc);
...@@ -373,7 +373,7 @@ Buffer BufferNode::make(Var data, ...@@ -373,7 +373,7 @@ Buffer BufferNode::make(Var data,
std::string scope, std::string scope,
int data_alignment, int data_alignment,
int offset_factor) { int offset_factor) {
auto n = std::make_shared<BufferNode>(); auto n = make_node<BufferNode>();
n->data = std::move(data); n->data = std::move(data);
n->dtype = dtype; n->dtype = dtype;
n->shape = std::move(shape); n->shape = std::move(shape);
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace tvm { namespace tvm {
Channel ChannelNode::make(Var handle_var, Type dtype) { Channel ChannelNode::make(Var handle_var, Type dtype) {
auto n = std::make_shared<ChannelNode>(); auto n = make_node<ChannelNode>();
n->handle_var = handle_var; n->handle_var = handle_var;
n->dtype = dtype; n->dtype = dtype;
return Channel(n); return Channel(n);
......
...@@ -13,18 +13,18 @@ namespace tvm { ...@@ -13,18 +13,18 @@ namespace tvm {
using HalideIR::IR::RangeNode; using HalideIR::IR::RangeNode;
Range::Range(Expr begin, Expr end) Range::Range(Expr begin, Expr end)
: Range(std::make_shared<RangeNode>( : Range(make_node<RangeNode>(
begin, begin,
is_zero(begin) ? end : (end - begin))) { is_zero(begin) ? end : (end - begin))) {
} }
Range Range::make_by_min_extent(Expr min, Expr extent) { Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<HalideIR::IR::RangeNode>(min, extent)); return Range(make_node<HalideIR::IR::RangeNode>(min, extent));
} }
IterVar IterVarNode::make(Range dom, Var var, IterVar IterVarNode::make(Range dom, Var var,
IterVarType t, std::string thread_tag) { IterVarType t, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>(); NodePtr<IterVarNode> n = make_node<IterVarNode>();
n->dom = dom; n->dom = dom;
n->var = var; n->var = var;
n->iter_type = t; n->iter_type = t;
......
...@@ -52,7 +52,7 @@ CommReducer CommReducerNode::make(Array<Var> lhs, ...@@ -52,7 +52,7 @@ CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> result,
Array<Expr> identity_element) { Array<Expr> identity_element) {
auto node = std::make_shared<CommReducerNode>(); auto node = make_node<CommReducerNode>();
node->lhs = lhs; node->lhs = lhs;
node->rhs = rhs; node->rhs = rhs;
node->result = result; node->result = result;
...@@ -83,7 +83,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source, ...@@ -83,7 +83,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
if (!condition.defined()) { if (!condition.defined()) {
condition = const_true(); condition = const_true();
} }
auto n = std::make_shared<Reduce>(); auto n = make_node<Reduce>();
CHECK(source.defined()); CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) { for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined()); CHECK(axis[i].defined());
......
/*!
* Copyright (c) 2018 by Contributors
* Implementation of IR Node API
* \file node.cc
*/
#include <tvm/node/node.h>
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_map>
namespace tvm {
namespace {
// single manager of operator information.
struct TypeManager {
// mutex to avoid registration from multiple threads.
// recursive is needed for trigger(which calls UpdateAttrMap)
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> key2index;
std::vector<std::string> index2key;
// get singleton of the
static TypeManager* Global() {
static TypeManager inst;
return &inst;
}
};
} // namespace
const bool Node::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Node::_type_key);
return tid == tindex;
}
// this is slow, usually caller always hold the result in a static variable.
uint32_t Node::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
if (it != t->key2index.end()) {
return it->second;
}
uint32_t tid = ++(t->type_counter);
t->key2index[skey] = tid;
t->index2key.push_back(skey);
return tid;
}
const char* Node::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
internal_assert(index != 0);
return t->index2key.at(index - 1).c_str();
}
} // namespace tvm
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/container.h> #include <tvm/node/container.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <dmlc/json.h> #include <dmlc/json.h>
...@@ -248,7 +248,7 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -248,7 +248,7 @@ class JSONAttrGetter : public AttrVisitor {
class JSONAttrSetter : public AttrVisitor { class JSONAttrSetter : public AttrVisitor {
public: public:
const std::vector<std::shared_ptr<Node> >* node_list_; const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_; const std::vector<runtime::NDArray>* tensor_list_;
JSONNode* node_; JSONNode* node_;
...@@ -401,13 +401,13 @@ std::string SaveJSON(const NodeRef& n) { ...@@ -401,13 +401,13 @@ std::string SaveJSON(const NodeRef& n) {
return os.str(); return os.str();
} }
std::shared_ptr<Node> LoadJSON_(std::string json_str) { NodePtr<Node> LoadJSON_(std::string json_str) {
std::istringstream is(json_str); std::istringstream is(json_str);
dmlc::JSONReader reader(&is); dmlc::JSONReader reader(&is);
JSONGraph jgraph; JSONGraph jgraph;
// load in json graph. // load in json graph.
jgraph.Load(&reader); jgraph.Load(&reader);
std::vector<std::shared_ptr<Node> > nodes; std::vector<NodePtr<Node> > nodes;
std::vector<runtime::NDArray> tensors; std::vector<runtime::NDArray> tensors;
// load in tensors // load in tensors
for (const std::string& blob : jgraph.b64ndarrays) { for (const std::string& blob : jgraph.b64ndarrays) {
...@@ -427,7 +427,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) { ...@@ -427,7 +427,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
<< "Node type \'" << jnode.type_key << "\' is not registered in TVM"; << "Node type \'" << jnode.type_key << "\' is not registered in TVM";
nodes.emplace_back(f->fcreator(jnode.global_key)); nodes.emplace_back(f->fcreator(jnode.global_key));
} else { } else {
nodes.emplace_back(std::shared_ptr<Node>()); nodes.emplace_back(NodePtr<Node>());
} }
} }
CHECK_EQ(nodes.size(), jgraph.nodes.size()); CHECK_EQ(nodes.size(), jgraph.nodes.size());
...@@ -526,7 +526,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { ...@@ -526,7 +526,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
CHECK(f->fglobal_key == nullptr) CHECK(f->fglobal_key == nullptr)
<< "Cannot make node type \'" << type_key << "\' with global_key."; << "Cannot make node type \'" << type_key << "\' with global_key.";
std::shared_ptr<Node> n = f->fcreator(empty_str); NodePtr<Node> n = f->fcreator(empty_str);
if (n->derived_from<BaseAttrsNode>()) { if (n->derived_from<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs); static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else { } else {
......
...@@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index) { int value_index) {
auto n = std::make_shared<TensorNode>(); auto n = make_node<TensorNode>();
n->shape = std::move(shape); n->shape = std::move(shape);
n->dtype = dtype; n->dtype = dtype;
n->op = op; n->op = op;
...@@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
Tensor Operation::output(size_t i) const { Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>(); auto node = make_node<TensorNode>();
node->op = *this; node->op = *this;
node->value_index = i; node->value_index = i;
node->dtype = (*this)->output_dtype(i); node->dtype = (*this)->output_dtype(i);
...@@ -62,7 +62,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, ...@@ -62,7 +62,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Stmt body, Stmt body,
Stmt reduce_init, Stmt reduce_init,
Stmt reduce_update) { Stmt reduce_update) {
auto n = std::make_shared<TensorIntrinNode>(); auto n = make_node<TensorIntrinNode>();
n->name = std::move(name); n->name = std::move(name);
n->op = std::move(op); n->op = std::move(op);
n->inputs = std::move(inputs); n->inputs = std::move(inputs);
......
...@@ -69,7 +69,7 @@ Tensor compute(Array<Expr> shape, ...@@ -69,7 +69,7 @@ Tensor compute(Array<Expr> shape,
std::string name, std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs) { Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = make_node<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
std::vector<IterVar> axis; std::vector<IterVar> axis;
...@@ -91,7 +91,7 @@ Array<Tensor> compute(Array<Expr> shape, ...@@ -91,7 +91,7 @@ Array<Tensor> compute(Array<Expr> shape,
std::string name, std::string name,
std::string tag, std::string tag,
Map<std::string, NodeRef> attrs) { Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = make_node<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
std::vector<IterVar> axis; std::vector<IterVar> axis;
...@@ -117,7 +117,7 @@ Operation ComputeOpNode::make(std::string name, ...@@ -117,7 +117,7 @@ Operation ComputeOpNode::make(std::string name,
Map<std::string, NodeRef> attrs, Map<std::string, NodeRef> attrs,
Array<IterVar> axis, Array<IterVar> axis,
Array<Expr> body) { Array<Expr> body) {
auto n = std::make_shared<ComputeOpNode>(); auto n = make_node<ComputeOpNode>();
n->name = std::move(name); n->name = std::move(name);
n->tag = std::move(tag); n->tag = std::move(tag);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
...@@ -163,7 +163,7 @@ Operation ComputeOpNode::ReplaceInputs( ...@@ -163,7 +163,7 @@ Operation ComputeOpNode::ReplaceInputs(
if (!new_reduce.same_as(this->body[0])) { if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>(); const ir::Reduce* r = new_reduce.as<ir::Reduce>();
for (size_t k = 0; k < this->body.size(); ++k) { for (size_t k = 0; k < this->body.size(); ++k) {
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r); auto n = make_node<ir::Reduce>(*r);
n->value_index = static_cast<int>(k); n->value_index = static_cast<int>(k);
n->type = r->source[k].type(); n->type = r->source[k].type();
arr.push_back(Expr(n)); arr.push_back(Expr(n));
......
...@@ -43,7 +43,7 @@ Operation ExternOpNode::make(std::string name, ...@@ -43,7 +43,7 @@ Operation ExternOpNode::make(std::string name,
Array<Buffer> input_placeholders, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Array<Buffer> output_placeholders,
Stmt body) { Stmt body) {
auto n = std::make_shared<ExternOpNode>(); auto n = make_node<ExternOpNode>();
n->name = std::move(name); n->name = std::move(name);
n->tag = std::move(tag); n->tag = std::move(tag);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
...@@ -68,7 +68,7 @@ Operation ExternOpNode::ReplaceInputs( ...@@ -68,7 +68,7 @@ Operation ExternOpNode::ReplaceInputs(
const Operation& self, const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const { const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
auto n = std::make_shared<ExternOpNode>(*this); auto n = make_node<ExternOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap); n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) { for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i]; Tensor t = n->inputs[i];
......
...@@ -36,7 +36,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const { ...@@ -36,7 +36,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
Operation PlaceholderOpNode::make(std::string name, Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape, Array<Expr> shape,
Type dtype) { Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>(); auto n = make_node<PlaceholderOpNode>();
n->name = name; n->name = name;
n->shape = shape; n->shape = shape;
n->dtype = dtype; n->dtype = dtype;
......
...@@ -51,7 +51,7 @@ Operation ScanOpNode::make(std::string name, ...@@ -51,7 +51,7 @@ Operation ScanOpNode::make(std::string name,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
Array<Tensor> inputs) { Array<Tensor> inputs) {
auto n = std::make_shared<ScanOpNode>(); auto n = make_node<ScanOpNode>();
CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size()); CHECK_EQ(init.size(), state_placeholder.size());
...@@ -135,7 +135,7 @@ Operation ScanOpNode::ReplaceInputs( ...@@ -135,7 +135,7 @@ Operation ScanOpNode::ReplaceInputs(
const Operation& self, const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const { const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
std::shared_ptr<ScanOpNode> n = std::make_shared<ScanOpNode>(*this); auto n = make_node<ScanOpNode>(*this);
for (size_t i = 0; i < n->init.size(); ++i) { for (size_t i = 0; i < n->init.size(); ++i) {
if (rmap.count(n->init[i])) { if (rmap.count(n->init[i])) {
n->init.Set(i, rmap.at(n->init[i])); n->init.Set(i, rmap.at(n->init[i]));
......
...@@ -90,7 +90,7 @@ class ContextCallCombiner final : public IRMutator { ...@@ -90,7 +90,7 @@ class ContextCallCombiner final : public IRMutator {
}; };
LoweredFunc CombineContextCall(LoweredFunc f) { LoweredFunc CombineContextCall(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = ContextCallCombiner().Combine(n->body); n->body = ContextCallCombiner().Combine(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -13,38 +13,38 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { ...@@ -13,38 +13,38 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri; Stmt s = *ri;
if (s.as<For>()) { if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>()); auto n = make_node<For>(*s.as<For>());
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<LetStmt>()) { } else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>()); auto n = make_node<LetStmt>(*s.as<LetStmt>());
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<AttrStmt>()) { } else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>()); auto n = make_node<AttrStmt>(*s.as<AttrStmt>());
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<IfThenElse>()) { } else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>()); auto n = make_node<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case)); CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined()); CHECK(!n->else_case.defined());
n->then_case = body; n->then_case = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<Block>()) { } else if (s.as<Block>()) {
auto n = std::make_shared<Block>(*s.as<Block>()); auto n = make_node<Block>(*s.as<Block>());
CHECK(is_no_op(n->rest)); CHECK(is_no_op(n->rest));
n->rest = body; n->rest = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<AssertStmt>()) { } else if (s.as<AssertStmt>()) {
auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>()); auto n = make_node<AssertStmt>(*s.as<AssertStmt>());
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
} else if (s.as<Allocate>()) { } else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>()); auto n = make_node<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body)); CHECK(is_no_op(n->body));
n->body = body; n->body = body;
body = Stmt(n); body = Stmt(n);
......
...@@ -104,7 +104,7 @@ class IntrinInjecter : public IRMutator { ...@@ -104,7 +104,7 @@ class IntrinInjecter : public IRMutator {
LoweredFunc LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) { LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = IntrinInjecter(target).Mutate(n->body); n->body = IntrinInjecter(target).Mutate(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -317,7 +317,7 @@ class ThreadAllreduceBuilder final : public IRMutator { ...@@ -317,7 +317,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
LoweredFunc LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) { LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc); CHECK_NE(f->func_type, kHostFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body); n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -288,7 +288,7 @@ class BuiltinLower : public IRMutator { ...@@ -288,7 +288,7 @@ class BuiltinLower : public IRMutator {
}; };
LoweredFunc LowerTVMBuiltin(LoweredFunc f) { LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = BuiltinLower().Build(n->body); n->body = BuiltinLower().Build(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -93,7 +93,7 @@ class WarpStoreCoeffFinder : private IRVisitor { ...@@ -93,7 +93,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
arith::DetectLinearEquation(index, {warp_index_}); arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U) CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index; << "LowerWarpMemory failed due to store index=" << index;
int coeff; int coeff = 0;
Expr mcoeff = ir::Simplify(m[0]); Expr mcoeff = ir::Simplify(m[0]);
CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
...@@ -317,7 +317,7 @@ class WarpMemoryRewriter : private IRMutator { ...@@ -317,7 +317,7 @@ class WarpMemoryRewriter : private IRMutator {
LoweredFunc LoweredFunc
LowerWarpMemory(LoweredFunc f, int warp_size) { LowerWarpMemory(LoweredFunc f, int warp_size) {
CHECK_EQ(f->func_type, kDeviceFunc); CHECK_EQ(f->func_type, kDeviceFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body); n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -132,7 +132,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -132,7 +132,7 @@ LoweredFunc MakeAPI(Stmt body,
} }
} }
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>(); NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
n->name = name; n->name = name;
n->args = args; n->args = args;
n->handle_data_type = binder.def_handle_dtype(); n->handle_data_type = binder.def_handle_dtype();
...@@ -197,7 +197,7 @@ class DeviceTypeBinder: public IRMutator { ...@@ -197,7 +197,7 @@ class DeviceTypeBinder: public IRMutator {
LoweredFunc BindDeviceType(LoweredFunc f, LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) { int device_type) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type).Mutate(n->body); n->body = DeviceTypeBinder(device_type).Mutate(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -67,7 +67,7 @@ RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) { ...@@ -67,7 +67,7 @@ RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
} }
CHECK_EQ(f->func_type, kDeviceFunc); CHECK_EQ(f->func_type, kDeviceFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
// replace the thread axis // replace the thread axis
for (size_t i = 0; i < n->thread_axis.size(); ++i) { for (size_t i = 0; i < n->thread_axis.size(); ++i) {
auto it = tmap.find(n->thread_axis[i]->thread_tag); auto it = tmap.find(n->thread_axis[i]->thread_tag);
......
...@@ -165,8 +165,8 @@ class HostDeviceSplitter : public IRMutator { ...@@ -165,8 +165,8 @@ class HostDeviceSplitter : public IRMutator {
handle_data_type_[kv.first.get()] = kv.second; handle_data_type_[kv.first.get()] = kv.second;
} }
name_ = f->name; name_ = f->name;
std::shared_ptr<LoweredFuncNode> n = NodePtr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->()); make_node<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body); n->body = this->Mutate(f->body);
n->func_type = kHostFunc; n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)}; Array<LoweredFunc> ret{LoweredFunc(n)};
...@@ -180,7 +180,7 @@ class HostDeviceSplitter : public IRMutator { ...@@ -180,7 +180,7 @@ class HostDeviceSplitter : public IRMutator {
Stmt SplitDeviceFunc(Stmt body) { Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os; std::ostringstream os;
os << name_ << "_kernel" << device_funcs_.size(); os << name_ << "_kernel" << device_funcs_.size();
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>(); NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
// isolate the device function. // isolate the device function.
IRUseDefAnalysis m; IRUseDefAnalysis m;
m.visit_thread_extent_ = false; m.visit_thread_extent_ = false;
......
...@@ -950,8 +950,7 @@ class VectorAllocRewriter : public IRMutator { ...@@ -950,8 +950,7 @@ class VectorAllocRewriter : public IRMutator {
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
std::shared_ptr<LoweredFuncNode> n = auto n = make_node<LoweredFuncNode>(*f.operator->());
std::make_shared<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter; VectorAllocRewriter rewriter;
n->body = rewriter.Mutate(n->body); n->body = rewriter.Mutate(n->body);
for (Var arg : f->args) { for (Var arg : f->args) {
......
...@@ -329,7 +329,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { ...@@ -329,7 +329,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
CHECK_NE(f->func_type, kHostFunc); CHECK_NE(f->func_type, kHostFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->()); auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = ThreadSync(f->body, storage_scope); n->body = ThreadSync(f->body, storage_scope);
return LoweredFunc(n); return LoweredFunc(n);
} }
......
...@@ -12,50 +12,39 @@ namespace relay { ...@@ -12,50 +12,39 @@ namespace relay {
using tvm::IRPrinter; using tvm::IRPrinter;
using namespace tvm::runtime; using namespace tvm::runtime;
SourceName SourceNameNode::make(std::string name) { NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
std::shared_ptr<SourceNameNode> n = std::make_shared<SourceNameNode>(); // always return pointer as the reference can change as map re-allocate.
n->name = std::move(name); // or use another level of indirection by creating a unique_ptr
return SourceName(n); static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map;
}
std::shared_ptr<SourceNameNode> CreateSourceName(const std::string& name) {
SourceName sn = SourceName::Get(name);
CHECK(!sn.defined()) << "Cannot find source name \'" << name << '\'';
std::shared_ptr<Node> node = sn.node_;
return std::dynamic_pointer_cast<SourceNameNode>(node);
}
const SourceName& SourceName::Get(const std::string& name) {
static std::unordered_map<std::string, SourceName> source_map;
auto sn = source_map.find(name); auto sn = source_map.find(name);
if (sn == source_map.end()) { if (sn == source_map.end()) {
auto source_name = SourceNameNode::make(name); NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
source_map.insert({name, source_name}); n->name = std::move(name);
return source_map.at(name); source_map[name] = n;
return n;
} else { } else {
return sn->second; return sn->second;
} }
} }
TVM_REGISTER_API("relay._make.SourceName") SourceName SourceName::Get(const std::string& name) {
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { return SourceName(GetSourceNameNode(name));
*ret = SourceNameNode::make(args[0]); }
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) { .set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; p->stream << "SourceName(" << node->name << ", " << node << ")";
}); });
TVM_REGISTER_NODE_TYPE(SourceNameNode) TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(CreateSourceName) .set_creator(GetSourceNameNode)
.set_global_key([](const Node* n) { .set_global_key([](const Node* n) {
return static_cast<const SourceNameNode*>(n)->name; return static_cast<const SourceNameNode*>(n)->name;
}); });
Span SpanNode::make(SourceName source, int lineno, int col_offset) { Span SpanNode::make(SourceName source, int lineno, int col_offset) {
std::shared_ptr<SpanNode> n = std::make_shared<SpanNode>(); auto n = make_node<SpanNode>();
n->source = std::move(source); n->source = std::move(source);
n->lineno = lineno; n->lineno = lineno;
n->col_offset = col_offset; n->col_offset = col_offset;
......
...@@ -15,7 +15,7 @@ using tvm::IRPrinter; ...@@ -15,7 +15,7 @@ using tvm::IRPrinter;
using namespace runtime; using namespace runtime;
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) { Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
std::shared_ptr<EnvironmentNode> n = std::make_shared<EnvironmentNode>(); auto n = make_node<EnvironmentNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
return Environment(n); return Environment(n);
} }
...@@ -31,20 +31,22 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { ...@@ -31,20 +31,22 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) {
} }
} }
/*! \brief Add a new item to the global environment /*!
* \brief Add a new item to the global environment
* \note if the update flag is not set adding a duplicate * \note if the update flag is not set adding a duplicate
* definition will trigger an exception, otherwise we will * definition will trigger an exception, otherwise we will
* update the definition if and only if it is type compatible. * update the definition if and only if it is type compatible.
*/ */
void EnvironmentNode::Add(const GlobalVar &var, const Function &func, void EnvironmentNode::Add(const GlobalVar &var,
const Function &func,
bool update) { bool update) {
// Type check the item before we add it to the environment. // Type check the item before we add it to the environment.
auto env = GetRef<Environment>(this); auto env = relay::GetRef<Environment>(this);
Expr checked_expr = InferType(env, var, func); Expr checked_expr = InferType(env, var, func);
if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) { if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
auto checked_func = GetRef<Function>(func_node); auto checked_func = relay::GetRef<Function>(func_node);
auto type = checked_func->checked_type(); auto type = checked_func->checked_type();
CHECK(IsFullyResolved(type)); CHECK(IsFullyResolved(type));
...@@ -100,46 +102,46 @@ void EnvironmentNode::Merge(const Environment &env) { ...@@ -100,46 +102,46 @@ void EnvironmentNode::Merge(const Environment &env) {
} }
TVM_REGISTER_API("relay._make.Environment") TVM_REGISTER_API("relay._make.Environment")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]); *ret = EnvironmentNode::make(args[0]);
}); });
TVM_REGISTER_API("relay._env.Environment_Add") TVM_REGISTER_API("relay._env.Environment_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Environment env = args[0];
env->Add(args[1], args[2], false); env->Add(args[1], args[2], false);
}); });
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Environment env = args[0];
*ret = env->GetGlobalVar(args[1]); *ret = env->GetGlobalVar(args[1]);
}); });
TVM_REGISTER_API("relay._env.Environment_Lookup") TVM_REGISTER_API("relay._env.Environment_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Environment env = args[0];
GlobalVar var = args[1]; GlobalVar var = args[1];
*ret = env->Lookup(var); *ret = env->Lookup(var);
}); });
TVM_REGISTER_API("relay._env.Environment_Lookup_str") TVM_REGISTER_API("relay._env.Environment_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Environment env = args[0];
std::string var_name = args[1]; std::string var_name = args[1];
auto var = env->GetGlobalVar(var_name); auto var = env->GetGlobalVar(var_name);
*ret = env->Lookup(var); *ret = env->Lookup(var);
}); });
TVM_REGISTER_API("relay._env.Environment_Merge") TVM_REGISTER_API("relay._env.Environment_Merge")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0]; Environment env = args[0];
env->Merge(args[1]); env->Merge(args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<EnvironmentNode>([](const EnvironmentNode *node, .set_dispatch<EnvironmentNode>(
tvm::IRPrinter *p) { [](const EnvironmentNode *node, tvm::IRPrinter *p) {
p->stream << "EnvironmentNode( " << node->functions << ")"; p->stream << "EnvironmentNode( " << node->functions << ")";
}); });
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
* \file src/tvm/ir/expr.cc * \file src/tvm/ir/expr.cc
* \brief The expression AST nodes of Relay. * \brief The expression AST nodes of Relay.
*/ */
#include <tvm/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
namespace tvm { namespace tvm {
...@@ -13,21 +12,20 @@ using tvm::IRPrinter; ...@@ -13,21 +12,20 @@ using tvm::IRPrinter;
using namespace tvm::runtime; using namespace tvm::runtime;
Constant ConstantNode::make(runtime::NDArray data) { Constant ConstantNode::make(runtime::NDArray data) {
std::shared_ptr<ConstantNode> n = std::make_shared<ConstantNode>(); NodePtr<ConstantNode> n = make_node<ConstantNode>();
n->data = std::move(data); n->data = std::move(data);
return Constant(n); return Constant(n);
} }
TVM_REGISTER_API("relay._make.Constant") TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ConstantNode::make(args[0]); *ret = ConstantNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode *node, .set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) {
tvm::IRPrinter *p) { p->stream << "Constant(TODO)";
p->stream << "ConstantNode(TODO)"; });
});
TensorType ConstantNode::tensor_type() const { TensorType ConstantNode::tensor_type() const {
auto dtype = TVMType2Type(data->dtype); auto dtype = TVMType2Type(data->dtype);
...@@ -41,57 +39,55 @@ TensorType ConstantNode::tensor_type() const { ...@@ -41,57 +39,55 @@ TensorType ConstantNode::tensor_type() const {
} }
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) { Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
std::shared_ptr<TupleNode> n = std::make_shared<TupleNode>(); NodePtr<TupleNode> n = make_node<TupleNode>();
n->fields = std::move(fields); n->fields = std::move(fields);
return Tuple(n); return Tuple(n);
} }
TVM_REGISTER_API("relay._make.Tuple") TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleNode::make(args[0]); *ret = TupleNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) { .set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) {
p->stream << "TupleNode(" << node->fields << ")"; p->stream << "Tuple(" << node->fields << ")";
}); });
Var VarNode::make(std::string name_hint) { Var VarNode::make(std::string name_hint) {
std::shared_ptr<VarNode> n = std::make_shared<VarNode>(); NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint); n->name_hint = std::move(name_hint);
return Var(n); return Var(n);
} }
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0]); *ret = VarNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, .set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
tvm::IRPrinter *p) { p->stream << "Var(" << node->name_hint << ")";
p->stream << "VarNode(" << node->name_hint << ")"; });
});
GlobalVar GlobalVarNode::make(std::string name_hint) { GlobalVar GlobalVarNode::make(std::string name_hint) {
std::shared_ptr<GlobalVarNode> n = std::make_shared<GlobalVarNode>(); NodePtr<GlobalVarNode> n = make_node<GlobalVarNode>();
n->name_hint = std::move(name_hint); n->name_hint = std::move(name_hint);
return GlobalVar(n); return GlobalVar(n);
} }
TVM_REGISTER_API("relay._make.GlobalVar") TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GlobalVarNode::make(args[0]); *ret = GlobalVarNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, .set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) {
tvm::IRPrinter *p) { p->stream << "GlobalVar(" << node->name_hint << ")";
p->stream << "GlobalVarNode(" << node->name_hint << ")"; });
});
Param ParamNode::make(Var var, Type type) { Param ParamNode::make(Var var, Type type) {
std::shared_ptr<ParamNode> n = std::make_shared<ParamNode>(); NodePtr<ParamNode> n = make_node<ParamNode>();
n->var = std::move(var); n->var = std::move(var);
n->type = std::move(type); n->type = std::move(type);
return Param(n); return Param(n);
...@@ -104,12 +100,12 @@ TVM_REGISTER_API("relay._make.Param") ...@@ -104,12 +100,12 @@ TVM_REGISTER_API("relay._make.Param")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) { .set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) {
p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; p->stream << "Param(" << node->var << ", " << node->type << ")";
}); });
Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body, Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
tvm::Array<TypeParam> type_params) { tvm::Array<TypeParam> type_params) {
std::shared_ptr<FunctionNode> n = std::make_shared<FunctionNode>(); NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params); n->params = std::move(params);
n->ret_type = std::move(ret_type); n->ret_type = std::move(ret_type);
n->body = std::move(body); n->body = std::move(body);
...@@ -140,7 +136,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -140,7 +136,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs, Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
Array<Type> type_args) { Array<Type> type_args) {
std::shared_ptr<CallNode> n = std::make_shared<CallNode>(); NodePtr<CallNode> n = make_node<CallNode>();
n->op = std::move(op); n->op = std::move(op);
n->args = std::move(args); n->args = std::move(args);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
...@@ -160,7 +156,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -160,7 +156,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { Let LetNode::make(Var var, Expr value, Expr body, Type value_type) {
std::shared_ptr<LetNode> n = std::make_shared<LetNode>(); NodePtr<LetNode> n = make_node<LetNode>();
n->var = std::move(var); n->var = std::move(var);
n->value = std::move(value); n->value = std::move(value);
n->body = std::move(body); n->body = std::move(body);
...@@ -180,7 +176,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -180,7 +176,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
std::shared_ptr<IfNode> n = std::make_shared<IfNode>(); NodePtr<IfNode> n = make_node<IfNode>();
n->cond = std::move(cond); n->cond = std::move(cond);
n->true_branch = std::move(true_branch); n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch); n->false_branch = std::move(false_branch);
......
...@@ -51,7 +51,7 @@ const Op& Op::Get(const std::string& name) { ...@@ -51,7 +51,7 @@ const Op& Op::Get(const std::string& name) {
OpRegistry::OpRegistry() { OpRegistry::OpRegistry() {
OpManager* mgr = OpManager::Global(); OpManager* mgr = OpManager::Global();
std::shared_ptr<OpNode> n = std::make_shared<OpNode>(); NodePtr<OpNode> n = make_node<OpNode>();
n->index_ = mgr->op_counter++; n->index_ = mgr->op_counter++;
op_ = Op(n); op_ = Op(n);
} }
...@@ -90,14 +90,14 @@ void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, ...@@ -90,14 +90,14 @@ void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value,
// Frontend APIs // Frontend APIs
TVM_REGISTER_API("relay.op._ListOpNames") TVM_REGISTER_API("relay.op._ListOpNames")
.set_body_typed<Array<tvm::Expr>()>([]() { .set_body_typed<Array<tvm::Expr>()>([]() {
Array<tvm::Expr> ret; Array<tvm::Expr> ret;
for (const std::string& name : for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) { dmlc::Registry<OpRegistry>::ListAllNames()) {
ret.push_back(tvm::Expr(name)); ret.push_back(tvm::Expr(name));
} }
return ret; return ret;
}); });
TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get); TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
...@@ -138,11 +138,10 @@ TVM_REGISTER_API("relay.op._Register") ...@@ -138,11 +138,10 @@ TVM_REGISTER_API("relay.op._Register")
} }
}); });
std::shared_ptr<OpNode> CreateOp(const std::string& name) { NodePtr<Node> CreateOp(const std::string& name) {
auto op = Op::Get(name); auto op = Op::Get(name);
CHECK(!op.defined()) << "Cannot find op \'" << name << '\''; CHECK(!op.defined()) << "Cannot find op \'" << name << '\'';
std::shared_ptr<Node> node = op.node_; return op.node_;
return std::dynamic_pointer_cast<OpNode>(node);
} }
TVM_REGISTER_NODE_TYPE(OpNode) TVM_REGISTER_NODE_TYPE(OpNode)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
* \file src/tvm/ir/type.cc * \file src/tvm/ir/type.cc
* \brief The type system AST nodes of Relay. * \brief The type system AST nodes of Relay.
*/ */
#include <tvm/ir_functor.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
namespace tvm { namespace tvm {
...@@ -13,7 +12,7 @@ using tvm::IRPrinter; ...@@ -13,7 +12,7 @@ using tvm::IRPrinter;
using namespace tvm::runtime; using namespace tvm::runtime;
TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) { TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) {
std::shared_ptr<TensorTypeNode> n = std::make_shared<TensorTypeNode>(); NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
n->shape = std::move(shape); n->shape = std::move(shape);
n->dtype = std::move(dtype); n->dtype = std::move(dtype);
return TensorType(n); return TensorType(n);
...@@ -36,7 +35,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -36,7 +35,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
std::shared_ptr<TypeParamNode> n = std::make_shared<TypeParamNode>(); NodePtr<TypeParamNode> n = make_node<TypeParamNode>();
n->var = tvm::Var(name); n->var = tvm::Var(name);
n->kind = std::move(kind); n->kind = std::move(kind);
return TypeParam(n); return TypeParam(n);
...@@ -59,7 +58,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -59,7 +58,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type, FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type,
tvm::Array<TypeParam> type_params, tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints) { tvm::Array<TypeConstraint> type_constraints) {
std::shared_ptr<FuncTypeNode> n = std::make_shared<FuncTypeNode>(); NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
n->arg_types = std::move(arg_types); n->arg_types = std::move(arg_types);
n->ret_type = std::move(ret_type); n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params); n->type_params = std::move(type_params);
...@@ -81,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -81,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) { TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) {
std::shared_ptr<TypeRelationNode> n = std::make_shared<TypeRelationNode>(); NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>();
n->name = std::move(name); n->name = std::move(name);
n->func_ = std::move(func); n->func_ = std::move(func);
n->args = std::move(args); n->args = std::move(args);
...@@ -101,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -101,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
TupleType TupleTypeNode::make(Array<Type> fields) { TupleType TupleTypeNode::make(Array<Type> fields) {
std::shared_ptr<TupleTypeNode> n = std::make_shared<TupleTypeNode>(); NodePtr<TupleTypeNode> n = make_node<TupleTypeNode>();
n->fields = std::move(fields); n->fields = std::move(fields);
return TupleType(n); return TupleType(n);
} }
......
...@@ -10,10 +10,9 @@ ...@@ -10,10 +10,9 @@
* *
* For example tensors are not allowed to contain functions in Relay. * For example tensors are not allowed to contain functions in Relay.
* *
* We check this by ensuring the `dtype` field of a Tensor always * We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`. * contains a data type such as `int`, `float`, `uint`.
*/ */
#include <tvm/ir_functor.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include "./type_visitor.h" #include "./type_visitor.h"
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#include <tvm/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include "./incomplete_type.h" #include "./incomplete_type.h"
......
...@@ -137,7 +137,7 @@ class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr&)> { ...@@ -137,7 +137,7 @@ class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr&)> {
void Solve(TypeRelationData& ty_rel); void Solve(TypeRelationData& ty_rel);
/*! \brief Attempt to solve all pending relations. /*! \brief Attempt to solve all pending relations.
* *
* If the solver * If the solver
*/ */
SolverResult Solve(std::vector<TypeRelationData>& rels); SolverResult Solve(std::vector<TypeRelationData>& rels);
...@@ -607,8 +607,7 @@ TVM_REGISTER_API("relay._ir_pass._get_checked_type") ...@@ -607,8 +607,7 @@ TVM_REGISTER_API("relay._ir_pass._get_checked_type")
/* Incomplete Type */ /* Incomplete Type */
IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
std::shared_ptr<IncompleteTypeNode> n = auto n = make_node<IncompleteTypeNode>();
std::make_shared<IncompleteTypeNode>();
n->kind = std::move(kind); n->kind = std::move(kind);
return IncompleteType(n); return IncompleteType(n);
} }
......
...@@ -21,7 +21,7 @@ using tvm::IRPrinter; ...@@ -21,7 +21,7 @@ using tvm::IRPrinter;
using namespace tvm::runtime; using namespace tvm::runtime;
UnionFind UnionFindNode::make(tvm::Map<IncompleteType, Type> uf_map) { UnionFind UnionFindNode::make(tvm::Map<IncompleteType, Type> uf_map) {
std::shared_ptr<UnionFindNode> n = std::make_shared<UnionFindNode>(); auto n = make_node<UnionFindNode>();
n->uf_map = uf_map; n->uf_map = uf_map;
return UnionFind(n); return UnionFind(n);
} }
...@@ -130,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -130,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
TypeUnifier TypeUnifierNode::make(UnionFind union_find) { TypeUnifier TypeUnifierNode::make(UnionFind union_find) {
std::shared_ptr<TypeUnifierNode> n = std::make_shared<TypeUnifierNode>(); auto n = make_node<TypeUnifierNode>();
n->union_find = union_find; n->union_find = union_find;
return TypeUnifier(n); return TypeUnifier(n);
} }
......
...@@ -67,7 +67,7 @@ class UnionFindNode : public Node { ...@@ -67,7 +67,7 @@ class UnionFindNode : public Node {
class UnionFind : public NodeRef { class UnionFind : public NodeRef {
public: public:
UnionFind() {} UnionFind() {}
explicit UnionFind(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} explicit UnionFind(NodePtr<tvm::Node> p) : NodeRef(p) {}
// The union find structure is mutable so we do not use the standard macros // The union find structure is mutable so we do not use the standard macros
// and expose the pointer via `->`. // and expose the pointer via `->`.
...@@ -126,7 +126,7 @@ class TypeUnifierNode : public Node, ...@@ -126,7 +126,7 @@ class TypeUnifierNode : public Node,
class TypeUnifier : public NodeRef { class TypeUnifier : public NodeRef {
public: public:
TypeUnifier() {} TypeUnifier() {}
explicit TypeUnifier(std::shared_ptr<tvm::Node> p) : NodeRef(p) {} explicit TypeUnifier(NodePtr<tvm::Node> p) : NodeRef(p) {}
// no const so that unifier can be mutable as a member of typechecker // no const so that unifier can be mutable as a member of typechecker
inline TypeUnifierNode* operator->() const { inline TypeUnifierNode* operator->() const {
......
...@@ -46,7 +46,7 @@ Expr InjectPredicate(const Array<Expr>& predicates, ...@@ -46,7 +46,7 @@ Expr InjectPredicate(const Array<Expr>& predicates,
if (predicates.size() == 0) return body; if (predicates.size() == 0) return body;
const Reduce* reduce = body.as<Reduce>(); const Reduce* reduce = body.as<Reduce>();
if (reduce) { if (reduce) {
std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce); auto n = make_node<Reduce>(*reduce);
n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr()); n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
return Expr(n); return Expr(n);
} }
...@@ -400,7 +400,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -400,7 +400,7 @@ void InjectInline(ScheduleNode* sch) {
CHECK_EQ(new_body[j].size(), r->source.size()); CHECK_EQ(new_body[j].size(), r->source.size());
CHECK(r != nullptr); CHECK(r != nullptr);
for (size_t k = 0; k < new_body[j].size(); ++k) { for (size_t k = 0; k < new_body[j].size(); ++k) {
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r); auto n = make_node<ir::Reduce>(*r);
n->value_index = static_cast<int>(k); n->value_index = static_cast<int>(k);
n->type = r->source[k].type(); n->type = r->source[k].type();
new_body[j].Set(k, Expr(n)); new_body[j].Set(k, Expr(n));
...@@ -520,11 +520,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, ...@@ -520,11 +520,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const int factor_axis_pos = \ const int factor_axis_pos = \
factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis; factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
CHECK_LE(factor_axis_pos, compute_op->axis.size()); CHECK_LE(factor_axis_pos, compute_op->axis.size());
auto n = std::make_shared<ComputeOpNode>(); auto n = make_node<ComputeOpNode>();
n->name = compute_op->name + ".rf"; n->name = compute_op->name + ".rf";
{ {
// axis relacement. // axis relacement.
auto iv_node = std::make_shared<IterVarNode>(); auto iv_node = make_node<IterVarNode>();
iv_node->dom = dom_map.at(axis); iv_node->dom = dom_map.at(axis);
CHECK(is_zero(iv_node->dom->min)) CHECK(is_zero(iv_node->dom->min))
<< "Can only factor reduction domain starting from 0"; << "Can only factor reduction domain starting from 0";
...@@ -565,7 +565,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, ...@@ -565,7 +565,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
for (IterVar iv : reduce_stage->leaf_iter_vars) { for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) { if (touch_map.count(iv) && !iv.same_as(axis)) {
CHECK_EQ(iv->iter_type, kCommReduce); CHECK_EQ(iv->iter_type, kCommReduce);
auto ncpy = std::make_shared<IterVarNode>(*iv.operator->()); auto ncpy = make_node<IterVarNode>(*iv.operator->());
ncpy->dom = dom_map.at(iv); ncpy->dom = dom_map.at(iv);
n->reduce_axis.push_back(IterVar(ncpy)); n->reduce_axis.push_back(IterVar(ncpy));
} }
......
...@@ -70,7 +70,7 @@ void Split(StageNode* self, ...@@ -70,7 +70,7 @@ void Split(StageNode* self,
} // namespace } // namespace
Stage::Stage(Operation op) { Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>(); auto n = make_node<StageNode>();
n->op = op; n->op = op;
n->origin_op = op; n->origin_op = op;
n->all_iter_vars = op->root_iter_vars(); n->all_iter_vars = op->root_iter_vars();
...@@ -164,16 +164,16 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) ...@@ -164,16 +164,16 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
FindLeafVar(all_vars, leaf_vars, ivar); FindLeafVar(all_vars, leaf_vars, ivar);
auto it = self->iter_var_attrs.find(ivar); auto it = self->iter_var_attrs.find(ivar);
std::shared_ptr<IterVarAttrNode> n; NodePtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) { if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->()); n = make_node<IterVarAttrNode>(*(*it).second.operator->());
if (n->bind_thread.defined() && if (n->bind_thread.defined() &&
!n->bind_thread.same_as(thread_ivar)) { !n->bind_thread.same_as(thread_ivar)) {
LOG(WARNING) << "Axis " << ivar LOG(WARNING) << "Axis " << ivar
<< " is already bind to another thread " << n->bind_thread; << " is already bind to another thread " << n->bind_thread;
} }
} else { } else {
n = std::make_shared<IterVarAttrNode>(); n = make_node<IterVarAttrNode>();
} }
n->bind_thread = thread_ivar; n->bind_thread = thread_ivar;
self->iter_var_attrs.Set(ivar, IterVarAttr(n)); self->iter_var_attrs.Set(ivar, IterVarAttr(n));
...@@ -188,7 +188,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) { ...@@ -188,7 +188,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
<< "Already set env_threads"; << "Already set env_threads";
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
std::vector<std::shared_ptr<Node> > temp; std::vector<NodePtr<Node> > temp;
for (IterVar iv : threads) { for (IterVar iv : threads) {
temp.push_back(iv.node_); temp.push_back(iv.node_);
} }
...@@ -303,7 +303,7 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) ...@@ -303,7 +303,7 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
for (size_t i = 0; i < order.size(); ++i) { for (size_t i = 0; i < order.size(); ++i) {
pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i])); pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
} }
std::vector<std::shared_ptr<Node> > temp; std::vector<NodePtr<Node> > temp;
for (size_t i = 0; i < pos.size(); ++i) { for (size_t i = 0; i < pos.size(); ++i) {
temp.emplace_back(leaf_vars->data[pos[i]]); temp.emplace_back(leaf_vars->data[pos[i]]);
} }
...@@ -335,11 +335,11 @@ inline void UpdateIterVarAttr(StageNode* self, ...@@ -335,11 +335,11 @@ inline void UpdateIterVarAttr(StageNode* self,
FindLeafVar(all_vars, leaf_vars, var); FindLeafVar(all_vars, leaf_vars, var);
} }
auto it = self->iter_var_attrs.find(var); auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n; NodePtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) { if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->()); n = make_node<IterVarAttrNode>(*(*it).second.operator->());
} else { } else {
n = std::make_shared<IterVarAttrNode>(); n = make_node<IterVarAttrNode>();
} }
fupdate(n.get()); fupdate(n.get());
self->iter_var_attrs.Set(var, IterVarAttr(n)); self->iter_var_attrs.Set(var, IterVarAttr(n));
...@@ -397,11 +397,11 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) { ...@@ -397,11 +397,11 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var); FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var); auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n; NodePtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) { if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->()); n = make_node<IterVarAttrNode>(*(*it).second.operator->());
} else { } else {
n = std::make_shared<IterVarAttrNode>(); n = make_node<IterVarAttrNode>();
} }
n->prefetch_data.push_back(tensor); n->prefetch_data.push_back(tensor);
n->prefetch_offset.push_back(offset); n->prefetch_offset.push_back(offset);
...@@ -468,8 +468,8 @@ Stage& Stage::opengl() { ...@@ -468,8 +468,8 @@ Stage& Stage::opengl() {
} }
Stage CopyStage(const Stage& s) { Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n = NodePtr<StageNode> n =
std::make_shared<StageNode>(*s.operator->()); make_node<StageNode>(*s.operator->());
return Stage(n); return Stage(n);
} }
...@@ -477,7 +477,7 @@ Schedule Schedule::copy() const { ...@@ -477,7 +477,7 @@ Schedule Schedule::copy() const {
// map of stages. // map of stages.
const ScheduleNode* self = operator->(); const ScheduleNode* self = operator->();
std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap; std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>(); NodePtr<ScheduleNode> n = make_node<ScheduleNode>();
n->outputs = self->outputs; n->outputs = self->outputs;
// Copy the stages. // Copy the stages.
for (Stage s : self->stages) { for (Stage s : self->stages) {
...@@ -599,7 +599,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs, ...@@ -599,7 +599,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs,
} }
} }
// Create the new group stage. // Create the new group stage.
Stage gstage(std::make_shared<StageNode>()); Stage gstage(make_node<StageNode>());
gstage->group = parent_group; gstage->group = parent_group;
if (parent_group.defined()) { if (parent_group.defined()) {
++parent_group->num_child_stages; ++parent_group->num_child_stages;
...@@ -687,7 +687,7 @@ void ScheduleNode::InitCache() { ...@@ -687,7 +687,7 @@ void ScheduleNode::InitCache() {
} }
Schedule ScheduleNode::make(Array<Operation> ops) { Schedule ScheduleNode::make(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>(); auto n = make_node<ScheduleNode>();
Schedule sch(n); Schedule sch(n);
n->outputs = ops; n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs); auto g = schedule::CreateReadGraph(n->outputs);
...@@ -731,7 +731,7 @@ IterVarRelation SplitNode::make(IterVar parent, ...@@ -731,7 +731,7 @@ IterVarRelation SplitNode::make(IterVar parent,
IterVar inner, IterVar inner,
Expr factor, Expr factor,
Expr nparts) { Expr nparts) {
auto n = std::make_shared<SplitNode>(); auto n = make_node<SplitNode>();
n->parent = parent; n->parent = parent;
n->outer = outer; n->outer = outer;
n->inner = inner; n->inner = inner;
...@@ -742,7 +742,7 @@ IterVarRelation SplitNode::make(IterVar parent, ...@@ -742,7 +742,7 @@ IterVarRelation SplitNode::make(IterVar parent,
IterVarRelation FuseNode::make( IterVarRelation FuseNode::make(
IterVar outer, IterVar inner, IterVar fused) { IterVar outer, IterVar inner, IterVar fused) {
auto n = std::make_shared<FuseNode>(); auto n = make_node<FuseNode>();
n->outer = outer; n->outer = outer;
n->inner = inner; n->inner = inner;
n->fused = fused; n->fused = fused;
...@@ -750,14 +750,14 @@ IterVarRelation FuseNode::make( ...@@ -750,14 +750,14 @@ IterVarRelation FuseNode::make(
} }
IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
auto n = std::make_shared<RebaseNode>(); auto n = make_node<RebaseNode>();
n->parent = parent; n->parent = parent;
n->rebased = rebased; n->rebased = rebased;
return IterVarRelation(n); return IterVarRelation(n);
} }
IterVarRelation SingletonNode::make(IterVar iter) { IterVarRelation SingletonNode::make(IterVar iter) {
auto n = std::make_shared<SingletonNode>(); auto n = make_node<SingletonNode>();
n->iter = iter; n->iter = iter;
return IterVarRelation(n); return IterVarRelation(n);
} }
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <tvm/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) { TEST(IRF, Basic) {
......
...@@ -10,6 +10,7 @@ make cython3 || exit -1 ...@@ -10,6 +10,7 @@ make cython3 || exit -1
# Test extern package package # Test extern package package
cd apps/extension cd apps/extension
rm -rf lib
make || exit -1 make || exit -1
cd ../.. cd ../..
python -m nose -v apps/extension/tests || exit -1 python -m nose -v apps/extension/tests || exit -1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment