Unverified Commit 46363d0a by Tianqi Chen Committed by GitHub

[NODE] Keep base node system in HalideIR (#1793)

parent 06108bed
Subproject commit cf6090aeaeb782d1daff54b0ca5c2c281d7008db
Subproject commit 2f3ecdfdedf3efa7e45a3945dca63a25856c4674
......@@ -11,7 +11,7 @@
#include "base.h"
#include "expr.h"
#include "ir_operator.h"
#include "node/container.h"
#include "tvm/node/container.h"
namespace tvm {
......
......@@ -6,7 +6,7 @@
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_
#include "node/ir_functor.h"
#include "tvm/node/ir_functor.h"
#include "ir.h"
namespace tvm {
......
......@@ -9,7 +9,7 @@
#include <unordered_map>
#include "expr.h"
#include "ir.h"
#include "node/ir_functor.h"
#include "tvm/node/ir_functor.h"
namespace tvm {
namespace ir {
......
......@@ -7,7 +7,7 @@
#define TVM_IR_VISITOR_H_
#include "ir.h"
#include "node/ir_functor.h"
#include "tvm/node/ir_functor.h"
namespace tvm {
namespace ir {
......
......@@ -13,7 +13,7 @@
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "node/container.h"
#include "tvm/node/container.h"
namespace tvm {
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/ir_functor.h
* \brief Defines the IRFunctor data structures.
*/
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include "node.h"
#include "../runtime/registry.h"
namespace tvm {
/*!
* \brief A dynamical dispatched functor on NodeRef in the first argument.
*
* \code
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
* return prefix + "Add";
* });
* tostr.set_dispatch<IntImm>([](const IntImm* op) {
* return prefix + "IntImm"
* });
*
* Expr x = make_const(1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
* // dispatch to IntImm, outputs "MyAdd"
* LOG(INFO) << tostr(y, "My");
* \endcode
*
* \tparam FType function signiture
* This type if only defined for FType with function signiture
*/
template<typename FType>
class IRFunctor;
template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
private:
using Function = std::function<R (const NodeRef&n, Args...)>;
using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
/*! \brief internal function table */
std::vector<Function> func_;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief Whether the functor can dispatch the corresponding Node
* \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type.
*/
inline bool can_dispatch(const NodeRef& n) const {
uint32_t type_index = n.type_index();
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
* \brief invoke the functor , dispatch on type of n
* \param n The Node argument
* \param args The additional arguments
* \return The result.
*/
inline R operator()(const NodeRef& n, Args... args) const {
uint32_t type_index = n.type_index();
CHECK(type_index < func_.size() &&
func_[type_index] != nullptr)
<< "IRFunctor calls un-registered function on type "
<< Node::TypeIndex2Key(type_index);
return func_[type_index](n, std::forward<Args>(args)...);
}
/*!
* \brief set the dispacher for type TNode
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(Function f) { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
CHECK(func_[tindex] == nullptr)
<< "Dispatch for " << Node::TypeIndex2Key(tindex)
<< " is already set";
func_[tindex] = f;
return *this;
}
/*!
* \brief set the dispacher for type TNode
* This allows f to used detailed const Node pointer to replace NodeRef
*
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
Function fun = [f](const NodeRef& n, Args... args) {
return f(static_cast<const TNode*>(n.node_.get()),
std::forward<Args>(args)...);
};
return this->set_dispatch<TNode>(fun);
}
/*!
* \brief unset the dispacher for type TNode
*
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*!
* \brief Useful macro to set IRFunctor dispatch in a global static field.
*
* \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows easy patch in of new Node types, without changing
* // interface of IRPrinter.
*
* class IRPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* const static FType& f = *vtable();
* f(e, this);
* }
*
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*0
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
* p->print(n->a);
* p->stream << '+'
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField()
/*!
* \brief A container for a list of callbacks. All callbacks are invoked when
* the object is destructed.
*/
class IRFunctorCleanList {
public:
~IRFunctorCleanList() {
for (auto &f : clean_items) {
f();
}
}
void append(std::function<void()> func) {
clean_items.push_back(func);
}
private:
std::vector< std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
private:
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
IRFunctor<R(const NodeRef& n, Args...)> *irf) {
return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
} // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/memory.h
* \brief Node memory management.
*/
#ifndef TVM_NODE_MEMORY_H_
#define TVM_NODE_MEMORY_H_
#include "node.h"
namespace tvm {
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
//
template<typename T>
class SimpleNodeAllocator {
public:
template<typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static NodeBase::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(NodeBase* ptr) {
delete static_cast<T*>(ptr);
}
};
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
using Allocator = SimpleNodeAllocator<T>;
static_assert(std::is_base_of<NodeBase, T>::value,
"make_node can only be used to create NodeBase");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return NodePtr<T>(node);
}
} // namespace tvm
#endif // TVM_NODE_MEMORY_H_
......@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_
#include <ir/FunctionBase.h>
#include <tvm/node/container.h>
#include <string>
#include <vector>
#include <type_traits>
......@@ -15,7 +16,6 @@
#include "expr.h"
#include "ir_operator.h"
#include "arithmetic.h"
#include "node/container.h"
namespace tvm {
......
/*!
* Copyright (c) 2018 by Contributors
* Implementation of IR Node API
* \file node.cc
*/
#include <tvm/node/node.h>
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_map>
namespace tvm {
namespace {
// single manager of operator information.
struct TypeManager {
// mutex to avoid registration from multiple threads.
// recursive is needed for trigger(which calls UpdateAttrMap)
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> key2index;
std::vector<std::string> index2key;
// get singleton of the
static TypeManager* Global() {
static TypeManager inst;
return &inst;
}
};
} // namespace
const bool Node::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Node::_type_key);
return tid == tindex;
}
// this is slow, usually caller always hold the result in a static variable.
uint32_t Node::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
if (it != t->key2index.end()) {
return it->second;
}
uint32_t tid = ++(t->type_counter);
t->key2index[skey] = tid;
t->index2key.push_back(skey);
return tid;
}
const char* Node::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
internal_assert(index != 0);
return t->index2key.at(index - 1).c_str();
}
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment