Unverified Commit ae5a28db by Tianqi Chen Committed by GitHub

[NODE] Node base system refactor (#1739)

parent 1e57ee6c
Subproject commit f519848d972c67971b4cbf8c34070d5a5e3ede0d
Subproject commit cf6090aeaeb782d1daff54b0ca5c2c281d7008db
......@@ -57,7 +57,7 @@ class EnvFuncNode : public Node {
class EnvFunc : public NodeRef {
public:
EnvFunc() {}
explicit EnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
......@@ -105,7 +105,7 @@ class TypedEnvFunc<R(Args...)> : public NodeRef {
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
......
......@@ -38,7 +38,7 @@ class IntSet : public NodeRef {
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -136,7 +136,7 @@ class Attrs : public NodeRef {
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Attrs(NodePtr<Node> n) : NodeRef(n) {}
/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
......@@ -442,7 +442,7 @@ class AttrDocEntry {
public:
using TSelf = AttrDocEntry;
explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info)
explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info)
: info_(info) {
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
......@@ -466,15 +466,15 @@ class AttrDocEntry {
}
private:
std::shared_ptr<AttrFieldInfoNode> info_;
NodePtr<AttrFieldInfoNode> info_;
};
class AttrDocVisitor {
public:
template<typename T>
AttrDocEntry operator()(const char* key, T* v) {
std::shared_ptr<AttrFieldInfoNode> info
= std::make_shared<AttrFieldInfoNode>();
NodePtr<AttrFieldInfoNode> info
= make_node<AttrFieldInfoNode>();
info->name = key;
info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info));
......
......@@ -8,7 +8,7 @@
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node.h>
#include <tvm/node/node.h>
#include <string>
#include <memory>
#include <functional>
......@@ -25,7 +25,7 @@ using ::tvm::AttrVisitor;
class TypeName : public ::tvm::NodeRef { \
public: \
TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
......@@ -48,7 +48,7 @@ std::string SaveJSON(const NodeRef& node);
*
* \return The shared_ptr of the Node.
*/
std::shared_ptr<Node> LoadJSON_(std::string json_str);
NodePtr<Node> LoadJSON_(std::string json_str);
/*!
* \brief Load the node from json string.
......@@ -85,7 +85,7 @@ struct NodeFactoryReg {
* If this is not empty then FGlobalKey
* \return The created function.
*/
using FCreate = std::function<std::shared_ptr<Node>(const std::string& global_key)>;
using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
......@@ -123,7 +123,7 @@ struct NodeFactoryReg {
#define TVM_REGISTER_NODE_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return std::make_shared<TypeName>(); })
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
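// Illustrative sketch (not part of this diff): a hypothetical terminal node
// type and its registration. TVM_REGISTER_NODE_TYPE now allocates through
// ::tvm::make_node, so objects reconstructed by LoadJSON_ are managed by
// NodePtr rather than std::shared_ptr.
class MyNode : public ::tvm::Node {
 public:
  int value{0};
  void VisitAttrs(::tvm::AttrVisitor* v) final { v->Visit("value", &value); }
  static constexpr const char* _type_key = "test.MyNode";
  TVM_DECLARE_NODE_TYPE_INFO(MyNode, ::tvm::Node);
};
TVM_REGISTER_NODE_TYPE(MyNode);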
#define TVM_STRINGIZE_DETAIL(x) #x
......
......@@ -6,11 +6,11 @@
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_
#include <tvm/container.h>
#include <string>
#include "base.h"
#include "expr.h"
#include "node/container.h"
namespace tvm {
......@@ -31,7 +31,7 @@ enum class AccessMask : int {
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Buffer(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
......
......@@ -69,7 +69,7 @@ class TargetNode : public Node {
class Target : public NodeRef {
public:
Target() {}
explicit Target(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Target(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Create a Target given a string
......@@ -241,7 +241,7 @@ class BuildConfigNode : public Node {
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
......@@ -335,7 +335,7 @@ class GenericFuncNode;
class GenericFunc : public NodeRef {
public:
GenericFunc() {}
explicit GenericFunc(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Set the default function implementation.
......
......@@ -17,7 +17,7 @@ class Channel : public NodeRef {
public:
/*! \brief default constructor */
Channel() {}
explicit Channel(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Channel(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -76,7 +76,7 @@ class Var : public HalideIR::VarExpr {
public:
EXPORT explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
explicit Var(NodePtr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}
/*!
* \brief Make a new copy of var with same type, append suffix
......@@ -107,7 +107,7 @@ class Range : public HalideIR::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : HalideIR::IR::Range(n) {}
explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
......@@ -197,7 +197,7 @@ class IterVar : public NodeRef {
// construct a new iter var without a domain
IterVar() {}
// construct from shared ptr.
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit IterVar(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -28,7 +28,7 @@ struct CommReducerNode;
struct CommReducer : public NodeRef {
CommReducer() {}
explicit CommReducer(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -6,7 +6,7 @@
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_
#include <tvm/ir_functor.h>
#include "node/ir_functor.h"
#include "ir.h"
namespace tvm {
......
......@@ -6,10 +6,10 @@
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "expr.h"
#include "ir.h"
#include "node/ir_functor.h"
namespace tvm {
namespace ir {
......
......@@ -9,7 +9,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_functor.h>
#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <vector>
......
......@@ -6,8 +6,8 @@
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_
#include <tvm/ir_functor.h>
#include "ir.h"
#include "node/ir_functor.h"
namespace tvm {
namespace ir {
......
......@@ -7,13 +7,13 @@
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "node/container.h"
namespace tvm {
......@@ -27,7 +27,7 @@ class LoweredFuncNode;
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/ir_functor.h
* \brief Defines the IRFunctor data structures.
*/
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include "node.h"
#include "../runtime/registry.h"
namespace tvm {
/*!
* \brief A dynamically dispatched functor on the NodeRef in the first argument.
*
* \code
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
* return prefix + "Add";
* });
* tostr.set_dispatch<IntImm>([](const IntImm* op, std::string prefix) {
* return prefix + "IntImm";
* });
*
* Expr x = make_const(Int(32), 1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
* // dispatch to IntImm, outputs "MyAdd"
* LOG(INFO) << tostr(y, "My");
* \endcode
*
* \tparam FType function signature
*  This type is only defined for FType with a function signature
*/
template<typename FType>
class IRFunctor;
template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
private:
using Function = std::function<R (const NodeRef& n, Args...)>;
using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
/*! \brief internal function table */
std::vector<Function> func_;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief Whether the functor can dispatch the corresponding Node
* \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type.
*/
inline bool can_dispatch(const NodeRef& n) const {
uint32_t type_index = n.type_index();
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
* \brief Invoke the functor, dispatching on the type of n.
* \param n The Node argument
* \param args The additional arguments
* \return The result.
*/
inline R operator()(const NodeRef& n, Args... args) const {
uint32_t type_index = n.type_index();
CHECK(type_index < func_.size() &&
func_[type_index] != nullptr)
<< "IRFunctor calls un-registered function on type "
<< Node::TypeIndex2Key(type_index);
return func_[type_index](n, std::forward<Args>(args)...);
}
/*!
* \brief Set the dispatcher for type TNode.
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(Function f) { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
CHECK(func_[tindex] == nullptr)
<< "Dispatch for " << Node::TypeIndex2Key(tindex)
<< " is already set";
func_[tindex] = f;
return *this;
}
/*!
* \brief Set the dispatcher for type TNode.
*  This allows f to take a typed const Node pointer instead of a NodeRef.
*
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
Function fun = [f](const NodeRef& n, Args... args) {
return f(static_cast<const TNode*>(n.node_.get()),
std::forward<Args>(args)...);
};
return this->set_dispatch<TNode>(fun);
}
/*!
* \brief Unset the dispatcher for type TNode.
*
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*!
* \brief Useful macro to set IRFunctor dispatch in a global static field.
*
* \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows new Node types to be patched in easily, without changing
* // the interface of IRPrinter.
*
* class IRPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* static const FType& f = vtable();
* f(e, this);
* }
*
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
* p->print(n->a);
* p->stream << '+';
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField()
/*!
* \brief A container for a list of callbacks. All callbacks are invoked when
* the object is destructed.
*/
class IRFunctorCleanList {
public:
~IRFunctorCleanList() {
for (auto &f : clean_items) {
f();
}
}
void append(std::function<void()> func) {
clean_items.push_back(func);
}
private:
std::vector< std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
private:
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
IRFunctor<R(const NodeRef& n, Args...)> *irf) {
return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
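// Illustrative sketch (not part of this diff), mirroring the
// TVM_STATIC_IR_FUNCTOR example above. When placed in a dynamically loaded
// library, the dispatch entry registered here is cleared again when the
// registry object is destroyed at library unload, so the underlying
// IRFunctor never holds dangling std::function instances.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<Add>([](const Add* n, IRPrinter* p) {
  p->print(n->a);
  p->stream << " + ";
  p->print(n->b);
});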
} // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/memory.h
* \brief Node memory management.
*/
#ifndef TVM_NODE_MEMORY_H_
#define TVM_NODE_MEMORY_H_
#include "node.h"
namespace tvm {
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
//
template<typename T>
class SimpleNodeAllocator {
public:
template<typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static NodeBase::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(NodeBase* ptr) {
delete static_cast<T*>(ptr);
}
};
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
using Allocator = SimpleNodeAllocator<T>;
static_assert(std::is_base_of<NodeBase, T>::value,
"make_node can only be used to create NodeBase");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return NodePtr<T>(node);
}
} // namespace tvm
#endif // TVM_NODE_MEMORY_H_
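// Minimal usage sketch (not part of this diff); MyNode stands for any
// hypothetical Node subclass with an `int value` field. make_node installs
// the allocator's deleter on the object, so it is released correctly when
// the last NodePtr/NodeRef referencing it goes away.
NodeRef MakeDemoNode() {
  NodePtr<MyNode> n = make_node<MyNode>();
  n->value = 42;       // fill in fields while the pointer is still unique
  return NodeRef(n);   // the returned reference shares the same ref count
}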
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/node.h
* \brief Node system data structure.
*/
#ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_
#include <string>
#include <vector>
#include <type_traits>
#include "base/Type.h"
#include "../runtime/node_base.h"
#include "../runtime/c_runtime_api.h"
namespace tvm {
using HalideIR::Type;
// forward declaration
class Node;
class NodeRef;
namespace runtime {
// forward declaration
class NDArray;
} // namespace runtime
/*!
* \brief Visitor class for the content of each node.
*  Visit is invoked once for each field of the node.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, Type* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
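// Illustrative sketch (not part of this diff): a trivial AttrVisitor that
// only records field names. Concrete visitors (e.g. the ones behind SaveJSON)
// implement the same set of overloads; LOG comes from dmlc/logging.
class FieldNamePrinter : public AttrVisitor {
 public:
  void Visit(const char* key, double* value) final { Show(key); }
  void Visit(const char* key, int64_t* value) final { Show(key); }
  void Visit(const char* key, uint64_t* value) final { Show(key); }
  void Visit(const char* key, int* value) final { Show(key); }
  void Visit(const char* key, bool* value) final { Show(key); }
  void Visit(const char* key, std::string* value) final { Show(key); }
  void Visit(const char* key, void** value) final { Show(key); }
  void Visit(const char* key, Type* value) final { Show(key); }
  void Visit(const char* key, NodeRef* value) final { Show(key); }
  void Visit(const char* key, runtime::NDArray* value) final { Show(key); }
 private:
  void Show(const char* key) { LOG(INFO) << "field: " << key; }
};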
/*!
* \brief base class of node containers in the DSL AST.
*  All objects are held internally through NodePtr<Node>.
*/
class TVM_DLL Node : public NodeBase {
public:
/*! \brief virtual destructor */
virtual ~Node() {}
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*!
* \brief Apply visitor to each field of the Node
*  The visitor may mutate the content of the node.
*  Override this if the Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*! \return the type index of the node */
virtual const uint32_t type_index() const = 0;
/*!
* \brief Whether this node derives from node with type_index=tid.
* Implemented by TVM_DECLARE_NODE_TYPE_INFO
*
* \param tid The type index.
* \return the check result.
*/
virtual const bool _DerivedFrom(uint32_t tid) const;
/*!
* \brief get a runtime unique type index given a type key
* \param type_key Type key of a type.
* \return the corresponding type index.
*/
static uint32_t TypeKey2Index(const char* type_key);
/*!
* \brief get type key from type index.
* \param index The type index
* \return the corresponding type key.
*/
static const char* TypeIndex2Key(uint32_t index);
/*!
* \return whether this node is derived from type T
* \tparam T The type to check against.
*/
template<typename T>
inline bool derived_from() const;
/*!
* \return whether the node is of type T
* \tparam T The type to be checked.
*/
template<typename T>
inline bool is_type() const;
/*!
* \brief Get a NodeRef that holds reference to this Node.
* \return the NodeRef
*/
inline NodeRef GetNodeRef() const;
// node ref can see this
friend class NodeRef;
static constexpr const char* _type_key = "Node";
};
/*! \brief Base class of all node reference object */
class NodeRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Node;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool same_as(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator<(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return whether the expression is null */
inline bool defined() const;
/*! \return the internal type index of IRNode */
inline uint32_t type_index() const;
/*! \return the internal node pointer */
inline const Node* get() const;
/*! \return the internal node pointer */
inline const Node* operator->() const;
/*!
* \brief Downcast this ir node to its actual type (e.g. Add, or
* Select). This returns nullptr if the node is not of the requested
* type. Example usage:
*
* if (const Add *add = node->as<Add>()) {
* // This is an add node
* }
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as() const;
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as_derived() const;
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(NodePtr<Node> node) : node_(node) {}
/*! \brief the internal node object, do not touch */
NodePtr<Node> node_;
};
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
* \brief helper macro to declare type information in a terminal node
*/
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
const uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
const bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
// implementations of inline functions after this
template<typename T>
inline bool Node::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
template<typename T>
inline bool Node::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
inline NodeRef Node::GetNodeRef() const {
return NodeRef(NodePtr<Node>(const_cast<Node*>(this)));
}
inline const Node* NodeRef::get() const {
return node_.get();
}
inline const Node* NodeRef::operator->() const {
return node_.get();
}
inline bool NodeRef::defined() const {
return node_.get() != nullptr;
}
inline bool NodeRef::operator==(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::same_as(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::operator<(const NodeRef& other) const {
return node_.get() < other.node_.get();
}
inline bool NodeRef::operator!=(const NodeRef& other) const {
return node_.get() != other.node_.get();
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
}
inline uint32_t NodeRef::type_index() const {
CHECK(node_.get() != nullptr)
<< "null type";
return get()->type_index();
}
template<typename T>
inline const T* NodeRef::as() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && ptr->is_type<T>()) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
template<typename T>
inline const T* NodeRef::as_derived() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
/*! \brief The hash function for nodes */
struct NodeHash {
size_t operator()(const NodeRef& a) const {
return a.hash();
}
};
/*! \brief The equal comparator for nodes */
struct NodeEqual {
bool operator()(const NodeRef& a, const NodeRef& b) const {
return a.get() == b.get();
}
};
} // namespace tvm
#endif // TVM_NODE_NODE_H_
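// Usage sketch (not part of this diff): NodeHash/NodeEqual let NodeRef be
// used as a key in standard containers, hashed and compared by the identity
// of the underlying Node pointer.
#include <unordered_map>
std::unordered_map<NodeRef, int, NodeHash, NodeEqual> visit_count;
void Record(const NodeRef& ref) {
  ++visit_count[ref];
}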
......@@ -116,7 +116,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
......@@ -132,7 +132,7 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
}
......@@ -145,27 +145,27 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(sptr);
}
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<std::shared_ptr<Node> >();
return *ptr<NodePtr<Node> >();
}
template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr =
*ptr<std::shared_ptr<Node> >();
NodePtr<Node>& sptr =
*ptr<NodePtr<Node> >();
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}
// extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) {
const NodePtr<Node>& other) {
if (other.get() == nullptr) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
SwitchToClass<NodePtr<Node> >(kNodeHandle, other);
}
return *this;
}
......@@ -174,7 +174,7 @@ inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
if (!other.defined()) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
SwitchToClass<NodePtr<Node> >(kNodeHandle, other.node_);
}
return *this;
}
......@@ -186,7 +186,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
......@@ -195,7 +195,7 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
values_[i].v_handle = const_cast<NodePtr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle;
} else {
type_codes_[i] = kNull;
......
......@@ -8,7 +8,7 @@
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <tvm/node/node.h>
#include <string>
#include <vector>
......@@ -55,16 +55,16 @@ using NodeEqual = ::tvm::NodeEqual;
* \param NodeName The internal container name.
* \param NodeRefBase The base type.
*/
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \
public: \
TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() { return this->defined(); } \
using ContainerType = NodeName; \
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() { return this->defined(); } \
using ContainerType = NodeName; \
};
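// Illustrative expansion (names are placeholders, not introduced by this
// diff): declares a reference class MyTypeRef wrapping MyTypeNode which,
// after this change, is constructed from ::tvm::NodePtr<::tvm::Node>.
RELAY_DEFINE_NODE_REF(MyTypeRef, MyTypeNode, NodeRef);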
/*!
......@@ -82,8 +82,6 @@ class SourceNameNode : public Node {
// override attr visitor
void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
TVM_DLL static SourceName make(std::string name);
static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
};
......@@ -98,7 +96,7 @@ class SourceName : public NodeRef {
SourceName() {}
/*! \brief constructor from node pointer */
explicit SourceName(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit SourceName(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -109,9 +107,9 @@ class SourceName : public NodeRef {
* \brief Get a SourceName for a given name.
*  Will raise an error if the source name has not been registered.
* \param name Name of the source.
* \return Reference to a SourceName valid throughout program lifetime.
* \return SourceName valid throughout program lifetime.
*/
TVM_DLL static const SourceName& Get(const std::string& name);
TVM_DLL static SourceName Get(const std::string& name);
/*! \brief specify container node */
using ContainerType = SourceNameNode;
......@@ -176,7 +174,7 @@ template <typename RefType, typename NodeType>
RefType GetRef(const NodeType* ptr) {
static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(const_cast<NodeType*>(ptr)->shared_from_this());
return RefType(std::move(ptr->GetNodeRef().node_));
}
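// Usage sketch (not part of this diff): recover the reference wrapper from a
// raw node pointer. With this change GetRef goes through Node::GetNodeRef()
// instead of enable_shared_from_this.
void DumpSourceName(const SourceNameNode* op) {
  SourceName ref = GetRef<SourceName>(op);  // bumps the node's ref count
  LOG(INFO) << ref->name;
}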
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
......
......@@ -98,15 +98,15 @@ class EnvironmentNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
private:
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
};
struct Environment : public NodeRef {
Environment() {}
explicit Environment(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {}
inline EnvironmentNode* operator->() const {
return static_cast<EnvironmentNode*>(node_.get());
......
......@@ -7,7 +7,7 @@
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_
#include <tvm/ir_functor.h>
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./op.h"
......@@ -19,7 +19,7 @@ namespace relay {
* \brief A dynamically dispatched functor on the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define the signatures of the Visit functions.
*
*
* \sa tvm/ir_functor.h
*
* \tparam FType function signature
......@@ -30,7 +30,7 @@ template <typename FType>
class ExprFunctor;
// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT \
#define EXPR_FUNCTOR_DEFAULT \
{ return VisitExprDefault_(op, std::forward<Args>(args)...); }
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
......@@ -152,12 +152,12 @@ class ExprMutator
Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
Expr VisitExpr_(const LetNode* op, const Expr& e) override;
Expr VisitExpr_(const IfNode* op, const Expr& e) override;
/*! \brief Used to visit the types inside of expressions.
*
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
*/
virtual Type VisitType(const Type& t);
private:
......
......@@ -90,7 +90,7 @@ class Op : public relay::Expr {
/*! \brief default constructor */
Op() {}
/*! \brief constructor from node pointer */
explicit Op(std::shared_ptr<Node> n) : Expr(n) {}
explicit Op(NodePtr<Node> n) : Expr(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -149,9 +149,9 @@ class OpRegistry {
const std::string& description);
/*!
* \brief Attach the type function corresponding to the return type.
* \param rel_name The type relation name to register.
* \param rel_name The type relation name to register.
* \param type_rel_func The backing relation function which can solve an arbitrary
* relation on variables.
* relation on variables.
* \return reference to self.
*/
inline OpRegistry& add_type_rel(
......@@ -338,7 +338,7 @@ inline OpRegistry& OpRegistry::describe(
inline OpRegistry& OpRegistry::add_argument(const std::string& name,
const std::string& type,
const std::string& description) {
std::shared_ptr<AttrFieldInfoNode> n = std::make_shared<AttrFieldInfoNode>();
auto n = make_node<AttrFieldInfoNode>();
n->name = name;
n->type_info = type;
n->description = description;
......
......@@ -8,7 +8,7 @@
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <tvm/node/node.h>
#include <string>
#include "./base.h"
......@@ -37,7 +37,7 @@ class TypeNode : public RelayNode {
class Type : public NodeRef {
public:
Type() {}
explicit Type(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
explicit Type(NodePtr<tvm::Node> p) : NodeRef(p) {}
using ContainerType = TypeNode;
};
......
......@@ -263,12 +263,16 @@ struct NDArray::Container {
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data)
: data_(data) {
data_->IncRef();
if (data != nullptr) {
data_->IncRef();
}
}
inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
data_->IncRef();
if (data_ != nullptr) {
data_->IncRef();
}
}
inline void NDArray::reset() {
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/runtime/node_base.h
* \brief Base data structure for Node.
*
* \note Node is not a runtime feature.
* This file only exposes the signature of NodePtr for PackedFunc.
*/
#ifndef TVM_RUNTIME_NODE_BASE_H_
#define TVM_RUNTIME_NODE_BASE_H_
#include <utility>
#include <atomic>
namespace tvm {
// forward declarations
template<typename T>
class NodePtr;
class Node;
class NodeRef;
/*!
* \brief Base class of Node for runtime destructor purposes.
*
* Node is a reference-counted object used to construct the AST.
* Each node is backed by a custom deleter, which deletes the object.
* Do not create raw Node pointers; always use tvm::make_node.
*
* \note In most cases, please inherit from tvm::Node.
* \sa Node, NodePtr, make_node
*/
class NodeBase {
public:
/*!
* \brief type of NodeBase deleter
* \param self pointer to the NodeBase.
*/
typedef void (*FDeleter)(NodeBase* self);
protected:
// default constructor and copy constructor
NodeBase() {}
// Make the copy/move constructors and assignment operators no-ops.
// This ensures that only contents, but not deleter_ and ref_counter_,
// are copied when a child class copies itself.
NodeBase(const NodeBase& other) { // NOLINT(*)
}
NodeBase(NodeBase&& other) { // NOLINT(*)
}
NodeBase& operator=(const NodeBase& other) { //NOLINT(*)
return *this;
}
NodeBase& operator=(NodeBase&& other) { //NOLINT(*)
return *this;
}
private:
/*! \brief Internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
* \brief deleter of this object to enable customized allocation.
* If the deleter is nullptr, no deletion will be performed.
* The creator of the Node must always set the deleter field properly.
*/
FDeleter deleter_ = nullptr;
// reference counting functions
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
int use_count() const {
return ref_counter_.load(std::memory_order_relaxed);
}
// friend declaration
template<typename>
friend class NodePtr;
template<typename Y, typename... Args>
friend NodePtr<Y> make_node(Args&&...);
};
/*!
* \brief Smart pointer for Node containers;
*  the content type must be a subclass of NodeBase.
* \tparam T the content data type.
*/
template<typename T>
class NodePtr {
public:
/*! \brief default constructor */
NodePtr() {}
/*! \brief constructor from nullptr */
NodePtr(std::nullptr_t) {} // NOLINT(*)
/*!
* \brief copy constructor
* \param other The value to be moved
*/
NodePtr(const NodePtr<T>& other) // NOLINT(*)
: NodePtr(other.data_) {
}
/*!
* \brief copy constructor
* \param other The value to be copied
*/
template<typename Y>
NodePtr(const NodePtr<Y>& other) // NOLINT(*)
: NodePtr(other.data_) {
static_assert(std::is_base_of<T, Y>::value,
"can only assign of child class NodePtr to parent");
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
NodePtr(NodePtr<T>&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
template<typename Y>
NodePtr(NodePtr<Y>&& other) // NOLINT(*)
: data_(other.data_) {
static_assert(std::is_base_of<T, Y>::value,
"can only assign of child class NodePtr to parent");
other.data_ = nullptr;
}
/*! \brief destructor */
~NodePtr() {
this->reset();
}
/*!
* \brief Swap this NodePtr with another NodePtr
* \param other The other NodePtr
*/
void swap(NodePtr<T>& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
* \return Get the content of the pointer
*/
T* get() const {
return static_cast<T*>(data_);
}
/*!
* \return The pointer
*/
T* operator->() const {
return get();
}
/*!
* \return The reference
*/
T& operator*() const { // NOLINT(*)
return *get();
}
/*!
* \brief copy assignment
* \param other The value to be assigned.
* \return reference to self.
*/
NodePtr<T>& operator=(const NodePtr<T>& other) { // NOLINT(*)
// copy-and-swap idiom
NodePtr(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief move assignment
* \param other The value to be assigned.
* \return reference to self.
*/
NodePtr<T>& operator=(NodePtr<T>&& other) { // NOLINT(*)
// copy-and-swap idiom
NodePtr(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \brief reset the content of ptr to be nullptr */
void reset() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
}
/*! \return The use count of the ptr, for debug purposes */
int use_count() const {
return data_ != nullptr ? data_->use_count() : 0;
}
/*! \return whether the reference is unique */
bool unique() const {
return data_ != nullptr && data_->use_count() == 1;
}
/*! \return Whether the two NodePtr point to the same object */
bool operator==(const NodePtr<T>& other) const {
return data_ == other.data_;
}
/*! \return Whether the two NodePtr point to different objects */
bool operator!=(const NodePtr<T>& other) const {
return data_ != other.data_;
}
/*! \return Whether the pointer is nullptr */
bool operator==(std::nullptr_t null) const {
return data_ == nullptr;
}
/*! \return Whether the pointer is not nullptr */
bool operator!=(std::nullptr_t null) const {
return data_ != nullptr;
}
private:
/*! \brief internal pointer field */
NodeBase* data_{nullptr};
/*!
* \brief constructor from NodeBase
* \param data The node base pointer
*/
explicit NodePtr(NodeBase* data)
: data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
// friend declaration
friend class Node;
template<typename>
friend class NodePtr;
template<typename Y, typename... Args>
friend NodePtr<Y> make_node(Args&&...);
};
} // namespace tvm
#endif // TVM_RUNTIME_NODE_BASE_H_
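// Behavioural sketch (not part of this diff); MyNode stands for a
// hypothetical NodeBase subclass and CHECK/CHECK_EQ come from dmlc/logging.
void RefCountDemo() {
  NodePtr<MyNode> a = make_node<MyNode>();
  CHECK_EQ(a.use_count(), 1);
  {
    NodePtr<MyNode> b = a;      // copy: atomically increments the counter
    CHECK_EQ(a.use_count(), 2);
  }                             // b destroyed: counter drops back to 1
  CHECK(a.unique());
  a.reset();                    // count reaches zero, the custom deleter_ runs
}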
......@@ -17,6 +17,7 @@
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
#include "node_base.h"
namespace HalideIR {
// Forward declare type for extensions
......@@ -31,12 +32,6 @@ struct Expr;
#endif
namespace tvm {
// Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef
// as long as it is not used.
class Node;
class NodeRef;
namespace runtime {
// forward declarations
class TVMArgs;
......@@ -549,7 +544,7 @@ class TVMArgValue : public TVMPODValue_ {
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
// get internal node ptr, if it is node
inline std::shared_ptr<Node>& node_sptr();
inline NodePtr<Node>& node_sptr();
};
/*!
......@@ -745,7 +740,7 @@ class TVMRetValue : public TVMPODValue_ {
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
inline TVMRetValue& operator=(const NodePtr<Node>& other);
// type related
inline operator HalideIR::Type() const;
inline TVMRetValue& operator=(const HalideIR::Type& other);
......@@ -775,8 +770,8 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kNodeHandle: {
SwitchToClass<std::shared_ptr<Node> >(
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >());
SwitchToClass<NodePtr<Node> >(
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
default: {
......@@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
case kNodeHandle: delete ptr<NodePtr<Node> >(); break;
case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
......
......@@ -36,7 +36,7 @@ enum AttachType : int {
class Stage : public NodeRef {
public:
Stage() {}
explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Stage(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
......@@ -260,7 +260,7 @@ class Stage : public NodeRef {
class Schedule : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Schedule(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
......@@ -383,7 +383,7 @@ class Schedule : public NodeRef {
class IterVarRelation : public NodeRef {
public:
IterVarRelation() {}
explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit IterVarRelation(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -397,7 +397,7 @@ class IterVarRelation : public NodeRef {
class IterVarAttr : public NodeRef {
public:
IterVarAttr() {}
explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit IterVarAttr(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -6,7 +6,6 @@
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>
#include <vector>
......@@ -15,6 +14,7 @@
#include "base.h"
#include "expr.h"
#include "arithmetic.h"
#include "node/container.h"
namespace tvm {
......@@ -33,7 +33,7 @@ class Tensor : public NodeRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Tensor(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -118,7 +118,7 @@ class Operation : public FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
explicit Operation(NodePtr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -19,7 +19,7 @@ class TensorIntrinNode;
class TensorIntrin : public NodeRef {
public:
TensorIntrin() {}
explicit TensorIntrin(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit TensorIntrin(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -94,7 +94,7 @@ class CompileEngine {
return it->second->graph_func;
}
GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
auto n = tvm::make_node<GraphCacheEntryNode>();
n->graph_func = f;
n->use_count = 1;
n->master_idx = master_idx;
......@@ -107,8 +107,7 @@ class CompileEngine {
Array<NodeRef> items;
for (auto& kv : cache_) {
items.push_back(kv.first);
std::shared_ptr<GraphCacheEntryNode> n =
std::make_shared<GraphCacheEntryNode>(*(kv.second.operator->()));
auto n = tvm::make_node<GraphCacheEntryNode>(*(kv.second.operator->()));
items.push_back(GraphCacheEntry(n));
}
return items;
......@@ -126,7 +125,7 @@ class CompileEngine {
// Set the given function on given graph key.
void Set(const GraphKey& key, GraphFunc func) {
std::lock_guard<std::mutex> lock(mutex_);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
auto n = tvm::make_node<GraphCacheEntryNode>();
n->graph_func = func;
n->use_count = 1;
cache_[key] = GraphCacheEntry(n);
......@@ -265,7 +264,7 @@ class CompileEngine {
graph, inputs, target, master_idx,
&readable_name, &outputs);
std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>();
auto gf = tvm::make_node<GraphFuncNode>();
gf->target = target;
gf->func_name = GetUniqeName(readable_name);
gf->inputs = inputs;
......
......@@ -71,7 +71,7 @@ struct GraphCacheEntryNode : public tvm::Node {
class GraphCacheEntry : public ::tvm::NodeRef {
public:
GraphCacheEntry() {}
explicit GraphCacheEntry(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {}
GraphCacheEntryNode* operator->() {
return static_cast<GraphCacheEntryNode*>(node_.get());
}
......
......@@ -74,8 +74,7 @@ bool GraphKeyEqual::Equal(const GraphKey& a,
GraphKey GraphKeyNode::make(Graph graph,
tvm::Array<Tensor> inputs,
std::string target) {
std::shared_ptr<GraphKeyNode> n
= std::make_shared<GraphKeyNode>();
auto n = tvm::make_node<GraphKeyNode>();
n->graph = std::move(graph);
n->inputs = inputs;
n->target = std::move(target);
......
......@@ -91,8 +91,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
for (size_t i = 0; i < size; ++i) {
tvm::runtime::NDArray temp;
temp.Load(strm);
std::shared_ptr<NDArrayWrapperNode> n
= std::make_shared<NDArrayWrapperNode>();
auto n = tvm::make_node<NDArrayWrapperNode>();
n->name = std::move(names[i]);
n->array = temp;
ret.push_back(NDArrayWrapper(n));
......
......@@ -9,6 +9,7 @@
#include <nnvm/graph.h>
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/node/memory.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
......
......@@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
const Array<Tensor>& out_info)
-> Array<Tensor> {
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info);
if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
if ((*ret.ptr<::tvm::NodePtr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
return {ret.operator Tensor()};
} else {
return ret;
......
......@@ -45,11 +45,11 @@ TVM_REGISTER_API("_str")
TVM_REGISTER_API("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<std::shared_ptr<Node> > data;
std::vector<NodePtr<Node> > data;
for (int i = 0; i < args.size(); ++i) {
data.push_back(args[i].node_sptr());
}
auto node = std::make_shared<ArrayNode>();
auto node = make_node<ArrayNode>();
node->data = std::move(data);
*ret = node;
});
......@@ -87,7 +87,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<StrMapNode>();
auto node = make_node<StrMapNode>();
node->data = std::move(data);
*ret = node;
} else {
......@@ -101,7 +101,7 @@ TVM_REGISTER_API("_Map")
data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<MapNode>();
auto node = make_node<MapNode>();
node->data = std::move(data);
*ret = node;
}
......@@ -163,7 +163,7 @@ TVM_REGISTER_API("_MapItems")
auto& sptr = args[0].node_sptr();
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
auto rkvs = make_node<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
......@@ -171,7 +171,7 @@ TVM_REGISTER_API("_MapItems")
*ret = rkvs;
} else {
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
auto rkvs = make_node<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
rkvs->data.push_back(kv.second);
......
......@@ -28,7 +28,7 @@ struct TVMAPIThreadLocalEntry {
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
using TVMAPINode = std::shared_ptr<Node>;
using TVMAPINode = NodePtr<Node>;
struct APIAttrGetter : public AttrVisitor {
std::string skey;
......
......@@ -48,7 +48,7 @@ struct ComExprEntry {
};
// canonical expression for communicative expression.
struct ComExprNode {
struct ComExprNode : public NodeBase {
// base constant value.
int64_t base{0};
// The values to be summed.
......@@ -60,7 +60,7 @@ struct ComExpr {
public:
// constructor
ComExpr() {}
explicit ComExpr(std::shared_ptr<ComExprNode> ptr) : ptr_(ptr) {}
explicit ComExpr(NodePtr<ComExprNode> ptr) : ptr_(ptr) {}
// get member
ComExprNode* operator->() const {
return ptr_.get();
......@@ -106,7 +106,7 @@ struct ComExpr {
}
private:
std::shared_ptr<ComExprNode> ptr_;
NodePtr<ComExprNode> ptr_;
};
// binary comparison op.
......@@ -173,7 +173,7 @@ class Canonical::Internal : public IRMutator {
if (sum.defined()) return sum;
const int64_t *v1 = as_const_int(value);
const uint64_t *v2 = as_const_uint(value);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
auto n = make_node<ComExprNode>();
if (v1) {
n->base = *v1;
} else if (v2) {
......@@ -471,8 +471,8 @@ class Canonical::Internal : public IRMutator {
Type type = coeff.type();
int64_t value = GetConstIntValue(coeff);
if (value < 0) return {};
std::shared_ptr<ComExprNode> xnode = std::make_shared<ComExprNode>();
std::shared_ptr<ComExprNode> ynode = std::make_shared<ComExprNode>();
auto xnode = make_node<ComExprNode>();
auto ynode = make_node<ComExprNode>();
if (a->base % value == 0) {
xnode->base = a->base;
} else {
......@@ -507,7 +507,7 @@ class Canonical::Internal : public IRMutator {
std::vector<ComExpr> pair = TryLinearEquation(a, v);
if (pair.size() == 0) {
int64_t value = GetConstIntValue(v);
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
auto n = make_node<ComExprNode>();
n->base = a->base % value;
for (auto e : a->elem) {
if (e.scale % value == 0) continue;
......@@ -554,8 +554,7 @@ class Canonical::Internal : public IRMutator {
if (value == 0) {
return make_zero(v.type());
}
std::shared_ptr<ComExprNode> vsum =
std::make_shared<ComExprNode>(*a.operator->());
auto vsum = make_node<ComExprNode>(*a.operator->());
vsum->base *= value;
for (auto& e : vsum->elem) {
e.scale *= value;
......@@ -576,7 +575,7 @@ class Canonical::Internal : public IRMutator {
ComExpr SumAdd_(const ComExpr& suma,
const ComExpr& sumb,
int bscale) {
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
auto n = make_node<ComExprNode>();
n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb;
size_t i = 0, j = 0;
......
......@@ -329,7 +329,7 @@ inline IntSet AsStrideSet(IntSet a) {
if (a.as<StrideSet>()) return a;
const IntervalSet* s = a.as<IntervalSet>();
CHECK(s->i.is_bounded());
std::shared_ptr<StrideSet> n = std::make_shared<StrideSet>();
NodePtr<StrideSet> n = make_node<StrideSet>();
n->base = s->i;
return IntSet(n);
}
......@@ -348,7 +348,7 @@ inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
b = AsStrideSet(b);
const StrideSet* a_stride = a.as<StrideSet>();
const StrideSet* b_stride = b.as<StrideSet>();
auto n = std::make_shared<StrideSet>(*a_stride);
auto n = make_node<StrideSet>(*a_stride);
for (size_t i = 0; i < b_stride->extents.size(); ++i) {
n->extents.push_back(b_stride->extents[i]);
n->strides.push_back(b_stride->strides[i]);
......
......@@ -21,14 +21,14 @@ struct IntervalSet : public IntSetNode {
Interval i;
static IntSet make(Interval i) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
NodePtr<IntervalSet> n =
make_node<IntervalSet>();
n->i = i;
return IntSet(n);
}
static IntSet make(Expr min, Expr max) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
NodePtr<IntervalSet> n =
make_node<IntervalSet>();
n->i.min = min;
n->i.max = max;
return IntSet(n);
......
......@@ -159,7 +159,7 @@ IntSet EvalModular(const Expr& e,
CHECK(m) << "Need to pass ModularSet for Modular Analysis";
mmap[kv.first.get()] = m->e;
}
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>();
NodePtr<ModularSet> n = make_node<ModularSet>();
n->e = ModularEvaluator(mmap)(e);
return IntSet(n);
}
......
......@@ -32,7 +32,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
*/
Target CreateTarget(const std::string& target_name,
const std::vector<std::string>& options) {
auto target = Target(std::make_shared<TargetNode>());
auto target = Target(make_node<TargetNode>());
auto t = static_cast<TargetNode*>(target.node_.get());
t->target_name = target_name;
......@@ -475,7 +475,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
BuildConfig build_config() {
return BuildConfig(std::make_shared<BuildConfigNode>());
return BuildConfig(make_node<BuildConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
......@@ -533,7 +533,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
struct GenericFunc::Manager {
std::unordered_map<std::string, std::shared_ptr<Node> > fmap;
std::unordered_map<std::string, NodePtr<Node> > fmap;
// mutex
std::mutex mutex;
......@@ -551,7 +551,7 @@ GenericFunc GenericFunc::Get(const std::string& name) {
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) {
auto f = std::make_shared<GenericFuncNode>();
auto f = make_node<GenericFuncNode>();
f->name_ = name;
m->fmap[name] = f;
return GenericFunc(f);
......@@ -669,7 +669,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo")
TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(std::make_shared<GenericFuncNode>());
*ret = GenericFunc(make_node<GenericFuncNode>());
});
TVM_REGISTER_API("_GenericFuncGetGlobal")
......
......@@ -17,14 +17,14 @@ using namespace ir;
ControlSignal ControlSignalNode::make(
ControlSignalType type, int advance_size) {
auto n = std::make_shared<ControlSignalNode>();
auto n = make_node<ControlSignalNode>();
n->ctrl_type = type;
n->advance_size = advance_size;
return ControlSignal(n);
}
StageInput StageInputNode::make(Var var, StageInputType input_type) {
std::shared_ptr<StageInputNode> n = std::make_shared<StageInputNode>();
NodePtr<StageInputNode> n = make_node<StageInputNode>();
n->var = var;
n->input_type = input_type;
return StageInput(n);
......@@ -81,7 +81,7 @@ class PipelineExtractor: public IRVisitor {
arg_handle_[arg.get()] = arg;
}
}
pipeline_ = std::make_shared<PipelineNode>();
pipeline_ = make_node<PipelineNode>();
this->Visit(f->body);
// setup channels
for (const auto &kv : cmap_) {
......@@ -113,7 +113,7 @@ class PipelineExtractor: public IRVisitor {
if (cb.node != nullptr) {
CHECK(cb.node->channel.same_as(ch));
} else {
cb.node = std::make_shared<ChannelBlockNode>();
cb.node = make_node<ChannelBlockNode>();
cb.node->channel = ch;
}
if (op->attr_key == attr::channel_read_scope) {
......@@ -167,8 +167,8 @@ class PipelineExtractor: public IRVisitor {
// The replace logic
StageInputReplacer repl(var_info_);
// Setup the compute block.
std::shared_ptr<ComputeBlockNode> compute =
std::make_shared<ComputeBlockNode>();
NodePtr<ComputeBlockNode> compute =
make_node<ComputeBlockNode>();
compute->loop = Array<Stmt>(loop_);
// setup the advance triggers
for (const auto& e : trigger_) {
......@@ -180,8 +180,8 @@ class PipelineExtractor: public IRVisitor {
} else {
ch = Channel(attr->node.node_);
}
std::shared_ptr<SignalTriggerNode> trigger
= std::make_shared<SignalTriggerNode>();
NodePtr<SignalTriggerNode> trigger
= make_node<SignalTriggerNode>();
trigger->channel_var = ch->handle_var;
// predicate for the trigger
Expr predicate = const_true();
......@@ -249,7 +249,7 @@ class PipelineExtractor: public IRVisitor {
CHECK(!cmap_.count(var))
<< "Multiple access to the same handle";
ChannelEntry& cb = cmap_[var];
cb.node = std::make_shared<ChannelBlockNode>();
cb.node = make_node<ChannelBlockNode>();
cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
return cb.node->channel;
}
......@@ -257,7 +257,7 @@ class PipelineExtractor: public IRVisitor {
private:
// The channel information.
struct ChannelEntry {
std::shared_ptr<ChannelBlockNode> node;
NodePtr<ChannelBlockNode> node;
int read_ref_count{0};
int write_ref_count{0};
};
......@@ -276,7 +276,7 @@ class PipelineExtractor: public IRVisitor {
// The argument handle map
std::unordered_map<const Variable*, Var> arg_handle_;
// The result block.
std::shared_ptr<PipelineNode> pipeline_;
NodePtr<PipelineNode> pipeline_;
};
Pipeline MakePipeline(LoweredFunc f) {
......
......@@ -50,7 +50,7 @@ inline VPIHandleNode* VPIHandle::get() const {
VPIHandle VPIHandleCreate(
const std::shared_ptr<VPISessionEntry>& sess,
VPIRawHandle handle) {
std::shared_ptr<VPIHandleNode> n = std::make_shared<VPIHandleNode>();
auto n = make_node<VPIHandleNode>();
n->sess = sess;
n->handle = handle;
return VPIHandle(n);
......@@ -102,7 +102,7 @@ int VPIGetIntProp(VPIHandleNode* h, int code) {
}
VPISession VPISession::make(int h_pipe_read, int h_pipe_write) {
std::shared_ptr<VPISessionNode> n = std::make_shared<VPISessionNode>();
auto n = make_node<VPISessionNode>();
n->sess = std::make_shared<VPISessionEntry>(h_pipe_read, h_pipe_write);
n->sess->in_control = true;
VPISession sess(n);
......
......@@ -27,7 +27,7 @@ using runtime::PackedFunc;
class VPISession : public NodeRef {
public:
VPISession() {}
explicit VPISession(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit VPISession(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Get handle by name.
* \param name The name of the handle.
......@@ -63,7 +63,7 @@ class VPISession : public NodeRef {
class VPIHandle : public NodeRef {
public:
VPIHandle() {}
explicit VPIHandle(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit VPIHandle(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Get handle by name.
* \param name The name of the handle.
......
......@@ -11,10 +11,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "EnvFunc(" << op->name << ")";
});
std::shared_ptr<EnvFuncNode> CreateEnvNode(const std::string& name) {
NodePtr<EnvFuncNode> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name);
CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
std::shared_ptr<EnvFuncNode> n = std::make_shared<EnvFuncNode>();
NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>();
n->func = *f;
n->name = name;
return n;
......
......@@ -30,7 +30,7 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
}
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
std::shared_ptr<DictAttrsNode> n = std::make_shared<DictAttrsNode>();
NodePtr<DictAttrsNode> n = make_node<DictAttrsNode>();
n->dict = std::move(dict);
return Attrs(n);
}
......
......@@ -289,7 +289,7 @@ Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this;
std::vector<Expr> temp;
auto n = std::make_shared<BufferNode>(*operator->());
auto n = make_node<BufferNode>(*operator->());
Expr acc = make_const(n->DefaultIndexType(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc);
......@@ -373,7 +373,7 @@ Buffer BufferNode::make(Var data,
std::string scope,
int data_alignment,
int offset_factor) {
auto n = std::make_shared<BufferNode>();
auto n = make_node<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
n->shape = std::move(shape);
......
......@@ -7,7 +7,7 @@
namespace tvm {
Channel ChannelNode::make(Var handle_var, Type dtype) {
auto n = std::make_shared<ChannelNode>();
auto n = make_node<ChannelNode>();
n->handle_var = handle_var;
n->dtype = dtype;
return Channel(n);
......
......@@ -13,18 +13,18 @@ namespace tvm {
using HalideIR::IR::RangeNode;
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<RangeNode>(
: Range(make_node<RangeNode>(
begin,
is_zero(begin) ? end : (end - begin))) {
}
Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<HalideIR::IR::RangeNode>(min, extent));
return Range(make_node<HalideIR::IR::RangeNode>(min, extent));
}
IterVar IterVarNode::make(Range dom, Var var,
IterVarType t, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
NodePtr<IterVarNode> n = make_node<IterVarNode>();
n->dom = dom;
n->var = var;
n->iter_type = t;
......
......@@ -52,7 +52,7 @@ CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
Array<Expr> result,
Array<Expr> identity_element) {
auto node = std::make_shared<CommReducerNode>();
auto node = make_node<CommReducerNode>();
node->lhs = lhs;
node->rhs = rhs;
node->result = result;
......@@ -83,7 +83,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
if (!condition.defined()) {
condition = const_true();
}
auto n = std::make_shared<Reduce>();
auto n = make_node<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
......
/*!
* Copyright (c) 2018 by Contributors
* Implementation of IR Node API
* \file node.cc
*/
#include <tvm/node/node.h>
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_map>
namespace tvm {
namespace {
// single manager of operator information.
struct TypeManager {
// mutex to avoid registration from multiple threads.
// recursive is needed for trigger(which calls UpdateAttrMap)
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> key2index;
std::vector<std::string> index2key;
// get the singleton of the manager.
static TypeManager* Global() {
static TypeManager inst;
return &inst;
}
};
} // namespace
const bool Node::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Node::_type_key);
return tid == tindex;
}
// this is slow; callers usually hold the result in a static variable.
uint32_t Node::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
if (it != t->key2index.end()) {
return it->second;
}
uint32_t tid = ++(t->type_counter);
t->key2index[skey] = tid;
t->index2key.push_back(skey);
return tid;
}
const char* Node::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex> lock(t->mutex);
internal_assert(index != 0);
return t->index2key.at(index - 1).c_str();
}
} // namespace tvm
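TypeKey2Index hands out 1-based indices (type_counter is pre-incremented), which is why TypeIndex2Key rejects index 0 and looks up index - 1. Because the lookup takes a lock, callers are expected to cache the result in a function-local static, as _DerivedFrom above does. A small hypothetical illustration, assuming TypeKey2Index is publicly callable; the "test.CustomNode" key is made up:
// Hypothetical caller, not part of this commit: cache the type index once so
// the lock and map lookup in TypeKey2Index happen only on the first call.
uint32_t CustomTypeIndex() {
  static uint32_t tindex = tvm::Node::TypeKey2Index("test.CustomNode");
  return tindex;
}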
......@@ -6,7 +6,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/attrs.h>
#include <tvm/container.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <dmlc/json.h>
......@@ -248,7 +248,7 @@ class JSONAttrGetter : public AttrVisitor {
class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<std::shared_ptr<Node> >* node_list_;
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
JSONNode* node_;
......@@ -401,13 +401,13 @@ std::string SaveJSON(const NodeRef& n) {
return os.str();
}
std::shared_ptr<Node> LoadJSON_(std::string json_str) {
NodePtr<Node> LoadJSON_(std::string json_str) {
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
// load in json graph.
jgraph.Load(&reader);
std::vector<std::shared_ptr<Node> > nodes;
std::vector<NodePtr<Node> > nodes;
std::vector<runtime::NDArray> tensors;
// load in tensors
for (const std::string& blob : jgraph.b64ndarrays) {
......@@ -427,7 +427,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
<< "Node type \'" << jnode.type_key << "\' is not registered in TVM";
nodes.emplace_back(f->fcreator(jnode.global_key));
} else {
nodes.emplace_back(std::shared_ptr<Node>());
nodes.emplace_back(NodePtr<Node>());
}
}
CHECK_EQ(nodes.size(), jgraph.nodes.size());
......@@ -526,7 +526,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
CHECK(f->fglobal_key == nullptr)
<< "Cannot make node type \'" << type_key << "\' with global_key.";
std::shared_ptr<Node> n = f->fcreator(empty_str);
NodePtr<Node> n = f->fcreator(empty_str);
if (n->derived_from<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
......
......@@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape,
Type dtype,
Operation op,
int value_index) {
auto n = std::make_shared<TensorNode>();
auto n = make_node<TensorNode>();
n->shape = std::move(shape);
n->dtype = dtype;
n->op = op;
......@@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorNode);
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
auto node = make_node<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
......@@ -62,7 +62,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
auto n = std::make_shared<TensorIntrinNode>();
auto n = make_node<TensorIntrinNode>();
n->name = std::move(name);
n->op = std::move(op);
n->inputs = std::move(inputs);
......
......@@ -69,7 +69,7 @@ Tensor compute(Array<Expr> shape,
std::string name,
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>();
auto op_node = make_node<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
......@@ -91,7 +91,7 @@ Array<Tensor> compute(Array<Expr> shape,
std::string name,
std::string tag,
Map<std::string, NodeRef> attrs) {
auto op_node = std::make_shared<ComputeOpNode>();
auto op_node = make_node<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
......@@ -117,7 +117,7 @@ Operation ComputeOpNode::make(std::string name,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis,
Array<Expr> body) {
auto n = std::make_shared<ComputeOpNode>();
auto n = make_node<ComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
......@@ -163,7 +163,7 @@ Operation ComputeOpNode::ReplaceInputs(
if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>();
for (size_t k = 0; k < this->body.size(); ++k) {
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
auto n = make_node<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
n->type = r->source[k].type();
arr.push_back(Expr(n));
......
......@@ -43,7 +43,7 @@ Operation ExternOpNode::make(std::string name,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body) {
auto n = std::make_shared<ExternOpNode>();
auto n = make_node<ExternOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
......@@ -68,7 +68,7 @@ Operation ExternOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = std::make_shared<ExternOpNode>(*this);
auto n = make_node<ExternOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
......
......@@ -36,7 +36,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>();
auto n = make_node<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
......
......@@ -51,7 +51,7 @@ Operation ScanOpNode::make(std::string name,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs) {
auto n = std::make_shared<ScanOpNode>();
auto n = make_node<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
......@@ -135,7 +135,7 @@ Operation ScanOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
std::shared_ptr<ScanOpNode> n = std::make_shared<ScanOpNode>(*this);
auto n = make_node<ScanOpNode>(*this);
for (size_t i = 0; i < n->init.size(); ++i) {
if (rmap.count(n->init[i])) {
n->init.Set(i, rmap.at(n->init[i]));
......
......@@ -90,7 +90,7 @@ class ContextCallCombiner final : public IRMutator {
};
LoweredFunc CombineContextCall(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = ContextCallCombiner().Combine(n->body);
return LoweredFunc(n);
}
......
......@@ -13,38 +13,38 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri;
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
auto n = make_node<For>(*s.as<For>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
auto n = make_node<LetStmt>(*s.as<LetStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
auto n = make_node<AttrStmt>(*s.as<AttrStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
auto n = make_node<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (s.as<Block>()) {
auto n = std::make_shared<Block>(*s.as<Block>());
auto n = make_node<Block>(*s.as<Block>());
CHECK(is_no_op(n->rest));
n->rest = body;
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
auto n = std::make_shared<AssertStmt>(*s.as<AssertStmt>());
auto n = make_node<AssertStmt>(*s.as<AssertStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>());
auto n = make_node<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
......
......@@ -104,7 +104,7 @@ class IntrinInjecter : public IRMutator {
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = IntrinInjecter(target).Mutate(n->body);
return LoweredFunc(n);
}
......
......@@ -317,7 +317,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
return LoweredFunc(n);
}
......
......@@ -288,7 +288,7 @@ class BuiltinLower : public IRMutator {
};
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = BuiltinLower().Build(n->body);
return LoweredFunc(n);
}
......
......@@ -93,7 +93,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
int coeff;
int coeff = 0;
Expr mcoeff = ir::Simplify(m[0]);
CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
......@@ -317,7 +317,7 @@ class WarpMemoryRewriter : private IRMutator {
LoweredFunc
LowerWarpMemory(LoweredFunc f, int warp_size) {
CHECK_EQ(f->func_type, kDeviceFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body);
return LoweredFunc(n);
}
......
......@@ -132,7 +132,7 @@ LoweredFunc MakeAPI(Stmt body,
}
}
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
n->name = name;
n->args = args;
n->handle_data_type = binder.def_handle_dtype();
......@@ -197,7 +197,7 @@ class DeviceTypeBinder: public IRMutator {
LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type).Mutate(n->body);
return LoweredFunc(n);
}
......
......@@ -67,7 +67,7 @@ RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
}
CHECK_EQ(f->func_type, kDeviceFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
// replace the thread axis
for (size_t i = 0; i < n->thread_axis.size(); ++i) {
auto it = tmap.find(n->thread_axis[i]->thread_tag);
......
......@@ -165,8 +165,8 @@ class HostDeviceSplitter : public IRMutator {
handle_data_type_[kv.first.get()] = kv.second;
}
name_ = f->name;
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
NodePtr<LoweredFuncNode> n =
make_node<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
......@@ -180,7 +180,7 @@ class HostDeviceSplitter : public IRMutator {
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
os << name_ << "_kernel" << device_funcs_.size();
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
// isolate the device function.
IRUseDefAnalysis m;
m.visit_thread_extent_ = false;
......
......@@ -950,8 +950,7 @@ class VectorAllocRewriter : public IRMutator {
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter;
n->body = rewriter.Mutate(n->body);
for (Var arg : f->args) {
......
......@@ -329,7 +329,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
CHECK_NE(f->func_type, kHostFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = ThreadSync(f->body, storage_scope);
return LoweredFunc(n);
}
......
......@@ -12,50 +12,39 @@ namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
SourceName SourceNameNode::make(std::string name) {
std::shared_ptr<SourceNameNode> n = std::make_shared<SourceNameNode>();
n->name = std::move(name);
return SourceName(n);
}
std::shared_ptr<SourceNameNode> CreateSourceName(const std::string& name) {
SourceName sn = SourceName::Get(name);
CHECK(!sn.defined()) << "Cannot find source name \'" << name << '\'';
std::shared_ptr<Node> node = sn.node_;
return std::dynamic_pointer_cast<SourceNameNode>(node);
}
const SourceName& SourceName::Get(const std::string& name) {
static std::unordered_map<std::string, SourceName> source_map;
NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
// always return the pointer, as the reference can change when the map re-allocates,
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map;
auto sn = source_map.find(name);
if (sn == source_map.end()) {
auto source_name = SourceNameNode::make(name);
source_map.insert({name, source_name});
return source_map.at(name);
NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
n->name = std::move(name);
source_map[name] = n;
return n;
} else {
return sn->second;
}
}
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) {
*ret = SourceNameNode::make(args[0]);
});
SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
p->stream << "SourceNameNode(" << node->name << ", " << node << ")";
});
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(CreateSourceName)
.set_creator(GetSourceNameNode)
.set_global_key([](const Node* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
std::shared_ptr<SpanNode> n = std::make_shared<SpanNode>();
auto n = make_node<SpanNode>();
n->source = std::move(source);
n->lineno = lineno;
n->col_offset = col_offset;
......
......@@ -15,7 +15,7 @@ using tvm::IRPrinter;
using namespace runtime;
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
std::shared_ptr<EnvironmentNode> n = std::make_shared<EnvironmentNode>();
auto n = make_node<EnvironmentNode>();
n->functions = std::move(global_funcs);
return Environment(n);
}
......@@ -31,20 +31,22 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) {
}
}
/*! \brief Add a new item to the global environment
/*!
* \brief Add a new item to the global environment
* \note if the update flag is not set adding a duplicate
* definition will trigger an exception, otherwise we will
* update the definition if and only if it is type compatible.
*/
void EnvironmentNode::Add(const GlobalVar &var, const Function &func,
void EnvironmentNode::Add(const GlobalVar &var,
const Function &func,
bool update) {
// Type check the item before we add it to the environment.
auto env = GetRef<Environment>(this);
auto env = relay::GetRef<Environment>(this);
Expr checked_expr = InferType(env, var, func);
if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
auto checked_func = GetRef<Function>(func_node);
auto checked_func = relay::GetRef<Function>(func_node);
auto type = checked_func->checked_type();
CHECK(IsFullyResolved(type));
......@@ -100,46 +102,46 @@ void EnvironmentNode::Merge(const Environment &env) {
}
TVM_REGISTER_API("relay._make.Environment")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]);
});
TVM_REGISTER_API("relay._env.Environment_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Add(args[1], args[2], false);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Add(args[1], args[2], false);
});
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
*ret = env->GetGlobalVar(args[1]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
*ret = env->GetGlobalVar(args[1]);
});
TVM_REGISTER_API("relay._env.Environment_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
GlobalVar var = args[1];
*ret = env->Lookup(var);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
GlobalVar var = args[1];
*ret = env->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
std::string var_name = args[1];
auto var = env->GetGlobalVar(var_name);
*ret = env->Lookup(var);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
std::string var_name = args[1];
auto var = env->GetGlobalVar(var_name);
*ret = env->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Merge")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Merge(args[1]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Merge(args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<EnvironmentNode>([](const EnvironmentNode *node,
tvm::IRPrinter *p) {
.set_dispatch<EnvironmentNode>(
[](const EnvironmentNode *node, tvm::IRPrinter *p) {
p->stream << "EnvironmentNode( " << node->functions << ")";
});
......
......@@ -3,7 +3,6 @@
* \file src/tvm/ir/expr.cc
* \brief The expression AST nodes of Relay.
*/
#include <tvm/ir_functor.h>
#include <tvm/relay/expr.h>
namespace tvm {
......@@ -13,21 +12,20 @@ using tvm::IRPrinter;
using namespace tvm::runtime;
Constant ConstantNode::make(runtime::NDArray data) {
std::shared_ptr<ConstantNode> n = std::make_shared<ConstantNode>();
NodePtr<ConstantNode> n = make_node<ConstantNode>();
n->data = std::move(data);
return Constant(n);
}
TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ConstantNode::make(args[0]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ConstantNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode *node,
tvm::IRPrinter *p) {
p->stream << "ConstantNode(TODO)";
});
.set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) {
p->stream << "Constant(TODO)";
});
TensorType ConstantNode::tensor_type() const {
auto dtype = TVMType2Type(data->dtype);
......@@ -41,57 +39,55 @@ TensorType ConstantNode::tensor_type() const {
}
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
std::shared_ptr<TupleNode> n = std::make_shared<TupleNode>();
NodePtr<TupleNode> n = make_node<TupleNode>();
n->fields = std::move(fields);
return Tuple(n);
}
TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleNode::make(args[0]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) {
p->stream << "TupleNode(" << node->fields << ")";
});
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) {
p->stream << "Tuple(" << node->fields << ")";
});
Var VarNode::make(std::string name_hint) {
std::shared_ptr<VarNode> n = std::make_shared<VarNode>();
NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint);
return Var(n);
}
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node,
tvm::IRPrinter *p) {
p->stream << "VarNode(" << node->name_hint << ")";
});
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
p->stream << "Var(" << node->name_hint << ")";
});
GlobalVar GlobalVarNode::make(std::string name_hint) {
std::shared_ptr<GlobalVarNode> n = std::make_shared<GlobalVarNode>();
NodePtr<GlobalVarNode> n = make_node<GlobalVarNode>();
n->name_hint = std::move(name_hint);
return GlobalVar(n);
}
TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GlobalVarNode::make(args[0]);
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GlobalVarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node,
tvm::IRPrinter *p) {
p->stream << "GlobalVarNode(" << node->name_hint << ")";
});
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) {
p->stream << "GlobalVar(" << node->name_hint << ")";
});
Param ParamNode::make(Var var, Type type) {
std::shared_ptr<ParamNode> n = std::make_shared<ParamNode>();
NodePtr<ParamNode> n = make_node<ParamNode>();
n->var = std::move(var);
n->type = std::move(type);
return Param(n);
......@@ -104,12 +100,12 @@ TVM_REGISTER_API("relay._make.Param")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) {
p->stream << "ParamNode(" << node->var << ", " << node->type << ")";
p->stream << "Param(" << node->var << ", " << node->type << ")";
});
Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
tvm::Array<TypeParam> type_params) {
std::shared_ptr<FunctionNode> n = std::make_shared<FunctionNode>();
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
n->ret_type = std::move(ret_type);
n->body = std::move(body);
......@@ -140,7 +136,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
Array<Type> type_args) {
std::shared_ptr<CallNode> n = std::make_shared<CallNode>();
NodePtr<CallNode> n = make_node<CallNode>();
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
......@@ -160,7 +156,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
Let LetNode::make(Var var, Expr value, Expr body, Type value_type) {
std::shared_ptr<LetNode> n = std::make_shared<LetNode>();
NodePtr<LetNode> n = make_node<LetNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
......@@ -180,7 +176,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
std::shared_ptr<IfNode> n = std::make_shared<IfNode>();
NodePtr<IfNode> n = make_node<IfNode>();
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
......
......@@ -51,7 +51,7 @@ const Op& Op::Get(const std::string& name) {
OpRegistry::OpRegistry() {
OpManager* mgr = OpManager::Global();
std::shared_ptr<OpNode> n = std::make_shared<OpNode>();
NodePtr<OpNode> n = make_node<OpNode>();
n->index_ = mgr->op_counter++;
op_ = Op(n);
}
......@@ -90,14 +90,14 @@ void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value,
// Frontend APIs
TVM_REGISTER_API("relay.op._ListOpNames")
.set_body_typed<Array<tvm::Expr>()>([]() {
Array<tvm::Expr> ret;
for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) {
ret.push_back(tvm::Expr(name));
}
return ret;
});
.set_body_typed<Array<tvm::Expr>()>([]() {
Array<tvm::Expr> ret;
for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) {
ret.push_back(tvm::Expr(name));
}
return ret;
});
TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
......@@ -138,11 +138,10 @@ TVM_REGISTER_API("relay.op._Register")
}
});
std::shared_ptr<OpNode> CreateOp(const std::string& name) {
NodePtr<Node> CreateOp(const std::string& name) {
auto op = Op::Get(name);
CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
std::shared_ptr<Node> node = op.node_;
return std::dynamic_pointer_cast<OpNode>(node);
return op.node_;
}
TVM_REGISTER_NODE_TYPE(OpNode)
......
......@@ -3,7 +3,6 @@
* \file src/tvm/ir/type.cc
* \brief The type system AST nodes of Relay.
*/
#include <tvm/ir_functor.h>
#include <tvm/relay/type.h>
namespace tvm {
......@@ -13,7 +12,7 @@ using tvm::IRPrinter;
using namespace tvm::runtime;
TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) {
std::shared_ptr<TensorTypeNode> n = std::make_shared<TensorTypeNode>();
NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
n->shape = std::move(shape);
n->dtype = std::move(dtype);
return TensorType(n);
......@@ -36,7 +35,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
std::shared_ptr<TypeParamNode> n = std::make_shared<TypeParamNode>();
NodePtr<TypeParamNode> n = make_node<TypeParamNode>();
n->var = tvm::Var(name);
n->kind = std::move(kind);
return TypeParam(n);
......@@ -59,7 +58,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type,
tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints) {
std::shared_ptr<FuncTypeNode> n = std::make_shared<FuncTypeNode>();
NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
n->arg_types = std::move(arg_types);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
......@@ -81,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) {
std::shared_ptr<TypeRelationNode> n = std::make_shared<TypeRelationNode>();
NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>();
n->name = std::move(name);
n->func_ = std::move(func);
n->args = std::move(args);
......@@ -101,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
TupleType TupleTypeNode::make(Array<Type> fields) {
std::shared_ptr<TupleTypeNode> n = std::make_shared<TupleTypeNode>();
NodePtr<TupleTypeNode> n = make_node<TupleTypeNode>();
n->fields = std::move(fields);
return TupleType(n);
}
......
......@@ -10,10 +10,9 @@
*
* For example tensors are not allowed to contain functions in Relay.
*
* We check this by ensuring the `dtype` field of a Tensor always
* We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`.
*/
#include <tvm/ir_functor.h>
#include <tvm/relay/pass.h>
#include "./type_visitor.h"
......
......@@ -6,7 +6,7 @@
#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#include <tvm/ir_functor.h>
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include "./incomplete_type.h"
......
......@@ -137,7 +137,7 @@ class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr&)> {
void Solve(TypeRelationData& ty_rel);
/*! \brief Attempt to solve all pending relations.
*
*
* If the solver
*/
SolverResult Solve(std::vector<TypeRelationData>& rels);
......@@ -607,8 +607,7 @@ TVM_REGISTER_API("relay._ir_pass._get_checked_type")
/* Incomplete Type */
IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
std::shared_ptr<IncompleteTypeNode> n =
std::make_shared<IncompleteTypeNode>();
auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
}
......
......@@ -21,7 +21,7 @@ using tvm::IRPrinter;
using namespace tvm::runtime;
UnionFind UnionFindNode::make(tvm::Map<IncompleteType, Type> uf_map) {
std::shared_ptr<UnionFindNode> n = std::make_shared<UnionFindNode>();
auto n = make_node<UnionFindNode>();
n->uf_map = uf_map;
return UnionFind(n);
}
......@@ -130,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
TypeUnifier TypeUnifierNode::make(UnionFind union_find) {
std::shared_ptr<TypeUnifierNode> n = std::make_shared<TypeUnifierNode>();
auto n = make_node<TypeUnifierNode>();
n->union_find = union_find;
return TypeUnifier(n);
}
......
......@@ -67,7 +67,7 @@ class UnionFindNode : public Node {
class UnionFind : public NodeRef {
public:
UnionFind() {}
explicit UnionFind(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
explicit UnionFind(NodePtr<tvm::Node> p) : NodeRef(p) {}
// The union find structure is mutable so we do not use the standard macros
// and expose the pointer via `->`.
......@@ -126,7 +126,7 @@ class TypeUnifierNode : public Node,
class TypeUnifier : public NodeRef {
public:
TypeUnifier() {}
explicit TypeUnifier(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
explicit TypeUnifier(NodePtr<tvm::Node> p) : NodeRef(p) {}
// no const so that unifier can be mutable as a member of typechecker
inline TypeUnifierNode* operator->() const {
......
......@@ -46,7 +46,7 @@ Expr InjectPredicate(const Array<Expr>& predicates,
if (predicates.size() == 0) return body;
const Reduce* reduce = body.as<Reduce>();
if (reduce) {
std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce);
auto n = make_node<Reduce>(*reduce);
n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
return Expr(n);
}
......@@ -400,7 +400,7 @@ void InjectInline(ScheduleNode* sch) {
CHECK(r != nullptr);
CHECK_EQ(new_body[j].size(), r->source.size());
for (size_t k = 0; k < new_body[j].size(); ++k) {
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
auto n = make_node<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
n->type = r->source[k].type();
new_body[j].Set(k, Expr(n));
......@@ -520,11 +520,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const int factor_axis_pos = \
factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
CHECK_LE(factor_axis_pos, compute_op->axis.size());
auto n = std::make_shared<ComputeOpNode>();
auto n = make_node<ComputeOpNode>();
n->name = compute_op->name + ".rf";
{
// axis replacement.
auto iv_node = std::make_shared<IterVarNode>();
auto iv_node = make_node<IterVarNode>();
iv_node->dom = dom_map.at(axis);
CHECK(is_zero(iv_node->dom->min))
<< "Can only factor reduction domain starting from 0";
......@@ -565,7 +565,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) {
CHECK_EQ(iv->iter_type, kCommReduce);
auto ncpy = std::make_shared<IterVarNode>(*iv.operator->());
auto ncpy = make_node<IterVarNode>(*iv.operator->());
ncpy->dom = dom_map.at(iv);
n->reduce_axis.push_back(IterVar(ncpy));
}
......
......@@ -70,7 +70,7 @@ void Split(StageNode* self,
} // namespace
Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>();
auto n = make_node<StageNode>();
n->op = op;
n->origin_op = op;
n->all_iter_vars = op->root_iter_vars();
......@@ -164,16 +164,16 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
FindLeafVar(all_vars, leaf_vars, ivar);
auto it = self->iter_var_attrs.find(ivar);
std::shared_ptr<IterVarAttrNode> n;
NodePtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
n = make_node<IterVarAttrNode>(*(*it).second.operator->());
if (n->bind_thread.defined() &&
!n->bind_thread.same_as(thread_ivar)) {
LOG(WARNING) << "Axis " << ivar
<< " is already bind to another thread " << n->bind_thread;
}
} else {
n = std::make_shared<IterVarAttrNode>();
n = make_node<IterVarAttrNode>();
}
n->bind_thread = thread_ivar;
self->iter_var_attrs.Set(ivar, IterVarAttr(n));
......@@ -188,7 +188,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
<< "Already set env_threads";
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
std::vector<std::shared_ptr<Node> > temp;
std::vector<NodePtr<Node> > temp;
for (IterVar iv : threads) {
temp.push_back(iv.node_);
}
......@@ -303,7 +303,7 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
for (size_t i = 0; i < order.size(); ++i) {
pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
}
std::vector<std::shared_ptr<Node> > temp;
std::vector<NodePtr<Node> > temp;
for (size_t i = 0; i < pos.size(); ++i) {
temp.emplace_back(leaf_vars->data[pos[i]]);
}
......@@ -335,11 +335,11 @@ inline void UpdateIterVarAttr(StageNode* self,
FindLeafVar(all_vars, leaf_vars, var);
}
auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n;
NodePtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
n = make_node<IterVarAttrNode>(*(*it).second.operator->());
} else {
n = std::make_shared<IterVarAttrNode>();
n = make_node<IterVarAttrNode>();
}
fupdate(n.get());
self->iter_var_attrs.Set(var, IterVarAttr(n));
......@@ -397,11 +397,11 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n;
NodePtr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
n = make_node<IterVarAttrNode>(*(*it).second.operator->());
} else {
n = std::make_shared<IterVarAttrNode>();
n = make_node<IterVarAttrNode>();
}
n->prefetch_data.push_back(tensor);
n->prefetch_offset.push_back(offset);
......@@ -468,8 +468,8 @@ Stage& Stage::opengl() {
}
Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->());
NodePtr<StageNode> n =
make_node<StageNode>(*s.operator->());
return Stage(n);
}
......@@ -477,7 +477,7 @@ Schedule Schedule::copy() const {
// map of stages.
const ScheduleNode* self = operator->();
std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>();
NodePtr<ScheduleNode> n = make_node<ScheduleNode>();
n->outputs = self->outputs;
// Copy the stages.
for (Stage s : self->stages) {
......@@ -599,7 +599,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs,
}
}
// Create the new group stage.
Stage gstage(std::make_shared<StageNode>());
Stage gstage(make_node<StageNode>());
gstage->group = parent_group;
if (parent_group.defined()) {
++parent_group->num_child_stages;
......@@ -687,7 +687,7 @@ void ScheduleNode::InitCache() {
}
Schedule ScheduleNode::make(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
auto n = make_node<ScheduleNode>();
Schedule sch(n);
n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs);
......@@ -731,7 +731,7 @@ IterVarRelation SplitNode::make(IterVar parent,
IterVar inner,
Expr factor,
Expr nparts) {
auto n = std::make_shared<SplitNode>();
auto n = make_node<SplitNode>();
n->parent = parent;
n->outer = outer;
n->inner = inner;
......@@ -742,7 +742,7 @@ IterVarRelation SplitNode::make(IterVar parent,
IterVarRelation FuseNode::make(
IterVar outer, IterVar inner, IterVar fused) {
auto n = std::make_shared<FuseNode>();
auto n = make_node<FuseNode>();
n->outer = outer;
n->inner = inner;
n->fused = fused;
......@@ -750,14 +750,14 @@ IterVarRelation FuseNode::make(
}
IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
auto n = std::make_shared<RebaseNode>();
auto n = make_node<RebaseNode>();
n->parent = parent;
n->rebased = rebased;
return IterVarRelation(n);
}
IterVarRelation SingletonNode::make(IterVar iter) {
auto n = std::make_shared<SingletonNode>();
auto n = make_node<SingletonNode>();
n->iter = iter;
return IterVarRelation(n);
}
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_functor.h>
#include <tvm/node/ir_functor.h>
#include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) {
......
......@@ -10,6 +10,7 @@ make cython3 || exit -1
# Test extern package
cd apps/extension
rm -rf lib
make || exit -1
cd ../..
python -m nose -v apps/extension/tests || exit -1
......