Unverified Commit 46363d0a by Tianqi Chen Committed by GitHub

[NODE] Keep base node system in HalideIR (#1793)

parent 06108bed
Subproject commit cf6090aeaeb782d1daff54b0ca5c2c281d7008db Subproject commit 2f3ecdfdedf3efa7e45a3945dca63a25856c4674
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "ir_operator.h" #include "ir_operator.h"
#include "node/container.h" #include "tvm/node/container.h"
namespace tvm { namespace tvm {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef TVM_IR_FUNCTOR_EXT_H_ #ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_ #define TVM_IR_FUNCTOR_EXT_H_
#include "node/ir_functor.h" #include "tvm/node/ir_functor.h"
#include "ir.h" #include "ir.h"
namespace tvm { namespace tvm {
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include <unordered_map> #include <unordered_map>
#include "expr.h" #include "expr.h"
#include "ir.h" #include "ir.h"
#include "node/ir_functor.h" #include "tvm/node/ir_functor.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#define TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_
#include "ir.h" #include "ir.h"
#include "node/ir_functor.h" #include "tvm/node/ir_functor.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "tensor.h" #include "tensor.h"
#include "node/container.h" #include "tvm/node/container.h"
namespace tvm { namespace tvm {
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/container.h
* \brief Array/Map container in the DSL graph.
*/
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
#include "node.h"
#include "memory.h"
namespace tvm {
/*! \brief array node content in array */
class ArrayNode : public Node {
public:
/*! \brief the data content */
std::vector<NodePtr<Node> > data;
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to array have no effect.
}
static constexpr const char* _type_key = "Array";
TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node);
};
/*! \brief map node content */
class MapNode : public Node {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
// hash function
struct Hash {
size_t operator()(const NodePtr<Node>& n) const {
return std::hash<Node*>()(n.get());
}
};
// comparator
struct Equal {
bool operator()(
const NodePtr<Node>& a,
const NodePtr<Node>& b) const {
return a.get() == b.get();
}
};
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
NodePtr<Node>,
NodePtr<Node>,
Hash, Equal>;
/*! \brief the data content */
ContainerType data;
static constexpr const char* _type_key = "Map";
TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node);
};
/*! \brief specialized map node with string as key */
class StrMapNode : public Node {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::string,
NodePtr<Node> >;
/*! \brief the data content */
ContainerType data;
static constexpr const char* _type_key = "StrMap";
TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node);
};
/*!
* \brief iterator adapter that adapts TIter to return another type.
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template<typename Converter,
typename TIter>
class IterAdapter {
public:
explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter& operator++(int) { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter operator+(int offset) const { // NOLINT(*)
return IterAdapter(iter_ + offset);
}
inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_;
}
inline bool operator!=(IterAdapter other) const {
return !(*this == other);
}
inline const typename Converter::ResultType operator*() const {
return Converter::convert(*iter_);
}
private:
TIter iter_;
};
/*!
* \brief Array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam T The content NodeRef type.
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
class Array : public NodeRef {
public:
/*!
* \brief default constructor
*/
Array() {
node_ = make_node<ArrayNode>();
}
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T> &other) { // NOLINT(*)
node_ = other.node_;
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Array(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T> && other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T> & other) {
node_ = other.node_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_node<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back((*it).node_);
}
node_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
return T(static_cast<const ArrayNode*>(node_.get())->data[i]);
}
/*! \return The size of the array */
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline ArrayNode* CopyOnWrite() {
if (node_.get() == nullptr || !node_.unique()) {
NodePtr<ArrayNode> n = make_node<ArrayNode>();
n->data = static_cast<ArrayNode*>(node_.get())->data;
NodePtr<Node>(std::move(n)).swap(node_);
}
return static_cast<ArrayNode*>(node_.get());
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item.node_);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
ArrayNode* n = this->CopyOnWrite();
n->data[i] = value.node_;
}
/*! \return whether array is empty */
inline bool empty() const {
return size() == 0;
}
/*! \brief specify container node */
using ContainerType = ArrayNode;
struct Ptr2NodeRef {
using ResultType = T;
static inline T convert(const NodePtr<Node>& n) {
return T(n);
}
};
using iterator = IterAdapter<Ptr2NodeRef,
std::vector<NodePtr<Node> >::const_iterator>;
using reverse_iterator = IterAdapter<
Ptr2NodeRef,
std::vector<NodePtr<Node> >::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const ArrayNode*>(node_.get())->data.end());
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend());
}
};
/*!
* \brief Map container of NodeRef->NodeRef in DSL graph.
* Map implements copy on write semantics, which means map is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam K The key NodeRef type.
* \tparam V The value NodeRef type.
*/
template<typename K,
typename V,
typename = typename std::enable_if<
std::is_base_of<NodeRef, K>::value ||
std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type>
class Map : public NodeRef {
public:
/*!
* \brief default constructor
*/
Map() {
node_ = make_node<MapNode>();
}
/*!
* \brief move constructor
* \param other source
*/
Map(Map<K, V> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
/*!
* \brief copy constructor
* \param other source
*/
Map(const Map<K, V> &other) { // NOLINT(*)
node_ = other.node_;
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Map(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
template<typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(Map<K, V> && other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(const Map<K, V> & other) {
node_ = other.node_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
NodePtr<MapNode> n = make_node<MapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first.node_,
i->second.node_));
}
node_ = std::move(n);
}
/*!
* \brief Read element from map.
* \param key The key
* \return the corresonding element.
*/
inline const V operator[](const K& key) const {
return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
}
/*!
* \brief Read element from map.
* \param key The key
* \return the corresonding element.
*/
inline const V at(const K& key) const {
return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
}
/*! \return The size of the array */
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const MapNode*>(node_.get())->data.size();
}
/*! \return The size of the array */
inline size_t count(const K& key) const {
if (node_.get() == nullptr) return 0;
return static_cast<const MapNode*>(node_.get())->data.count(key.node_);
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline MapNode* CopyOnWrite() {
if (node_.get() == nullptr || !node_.unique()) {
NodePtr<MapNode> n = make_node<MapNode>();
n->data = static_cast<const MapNode*>(node_.get())->data;
NodePtr<Node>(std::move(n)).swap(node_);
}
return static_cast<MapNode*>(node_.get());
}
/*!
* \brief set the Map.
* \param key The index key.
* \param value The value to be setted.
*/
inline void Set(const K& key, const V& value) {
MapNode* n = this->CopyOnWrite();
n->data[key.node_] = value.node_;
}
/*! \return whether array is empty */
inline bool empty() const {
return size() == 0;
}
/*! \brief specify container node */
using ContainerType = MapNode;
struct Ptr2NodeRef {
using ResultType = std::pair<K, V>;
static inline ResultType convert(const std::pair<
NodePtr<Node>,
NodePtr<Node> >& n) {
return std::make_pair(K(n.first), V(n.second));
}
};
using iterator = IterAdapter<
Ptr2NodeRef, MapNode::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const MapNode*>(node_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const MapNode*>(node_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const K& key) const {
return iterator(static_cast<const MapNode*>(node_.get())->data.find(key.node_));
}
};
// specialize of string map
template<typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public NodeRef {
public:
// for code reuse
Map() {
node_ = make_node<StrMapNode>();
}
Map(Map<std::string, V> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
Map(const Map<std::string, V> &other) { // NOLINT(*)
node_ = other.node_;
}
explicit Map(NodePtr<Node> n) : NodeRef(n) {}
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
template<typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V> && other) {
node_ = std::move(other.node_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V> & other) {
node_ = other.node_;
return *this;
}
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_node<StrMapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first,
i->second.node_));
}
node_ = std::move(n);
}
inline const V operator[](const std::string& key) const {
return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
}
inline const V at(const std::string& key) const {
return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
}
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(node_.get())->data.size();
}
inline size_t count(const std::string& key) const {
if (node_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(node_.get())->data.count(key);
}
inline StrMapNode* CopyOnWrite() {
if (node_.get() == nullptr || !node_.unique()) {
NodePtr<StrMapNode> n = make_node<StrMapNode>();
n->data = static_cast<const StrMapNode*>(node_.get())->data;
NodePtr<Node>(std::move(n)).swap(node_);
}
return static_cast<StrMapNode*>(node_.get());
}
inline void Set(const std::string& key, const V& value) {
StrMapNode* n = this->CopyOnWrite();
n->data[key] = value.node_;
}
inline bool empty() const {
return size() == 0;
}
using ContainerType = StrMapNode;
struct Ptr2NodeRef {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<
std::string,
NodePtr<Node> >& n) {
return std::make_pair(n.first, V(n.second));
}
};
using iterator = IterAdapter<
Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const StrMapNode*>(node_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key));
}
};
} // namespace tvm
#endif // TVM_NODE_CONTAINER_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/ir_functor.h
* \brief Defines the IRFunctor data structures.
*/
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <type_traits>
#include <functional>
#include "node.h"
#include "../runtime/registry.h"
namespace tvm {
/*!
* \brief A dynamical dispatched functor on NodeRef in the first argument.
*
* \code
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
* return prefix + "Add";
* });
* tostr.set_dispatch<IntImm>([](const IntImm* op) {
* return prefix + "IntImm"
* });
*
* Expr x = make_const(1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
* // dispatch to IntImm, outputs "MyAdd"
* LOG(INFO) << tostr(y, "My");
* \endcode
*
* \tparam FType function signiture
* This type if only defined for FType with function signiture
*/
template<typename FType>
class IRFunctor;
template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
private:
using Function = std::function<R (const NodeRef&n, Args...)>;
using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
/*! \brief internal function table */
std::vector<Function> func_;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief Whether the functor can dispatch the corresponding Node
* \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type.
*/
inline bool can_dispatch(const NodeRef& n) const {
uint32_t type_index = n.type_index();
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
* \brief invoke the functor , dispatch on type of n
* \param n The Node argument
* \param args The additional arguments
* \return The result.
*/
inline R operator()(const NodeRef& n, Args... args) const {
uint32_t type_index = n.type_index();
CHECK(type_index < func_.size() &&
func_[type_index] != nullptr)
<< "IRFunctor calls un-registered function on type "
<< Node::TypeIndex2Key(type_index);
return func_[type_index](n, std::forward<Args>(args)...);
}
/*!
* \brief set the dispacher for type TNode
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(Function f) { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
CHECK(func_[tindex] == nullptr)
<< "Dispatch for " << Node::TypeIndex2Key(tindex)
<< " is already set";
func_[tindex] = f;
return *this;
}
/*!
* \brief set the dispacher for type TNode
* This allows f to used detailed const Node pointer to replace NodeRef
*
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
Function fun = [f](const NodeRef& n, Args... args) {
return f(static_cast<const TNode*>(n.node_.get()),
std::forward<Args>(args)...);
};
return this->set_dispatch<TNode>(fun);
}
/*!
* \brief unset the dispacher for type TNode
*
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*!
* \brief Useful macro to set IRFunctor dispatch in a global static field.
*
* \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows easy patch in of new Node types, without changing
* // interface of IRPrinter.
*
* class IRPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* const static FType& f = *vtable();
* f(e, this);
* }
*
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*0
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
* p->print(n->a);
* p->stream << '+'
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField()
/*!
* \brief A container for a list of callbacks. All callbacks are invoked when
* the object is destructed.
*/
class IRFunctorCleanList {
public:
~IRFunctorCleanList() {
for (auto &f : clean_items) {
f();
}
}
void append(std::function<void()> func) {
clean_items.push_back(func);
}
private:
std::vector< std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
private:
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
IRFunctor<R(const NodeRef& n, Args...)> *irf) {
return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
} // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/memory.h
* \brief Node memory management.
*/
#ifndef TVM_NODE_MEMORY_H_
#define TVM_NODE_MEMORY_H_
#include "node.h"
namespace tvm {
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
//
template<typename T>
class SimpleNodeAllocator {
public:
template<typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static NodeBase::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(NodeBase* ptr) {
delete static_cast<T*>(ptr);
}
};
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
using Allocator = SimpleNodeAllocator<T>;
static_assert(std::is_base_of<NodeBase, T>::value,
"make_node can only be used to create NodeBase");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return NodePtr<T>(node);
}
} // namespace tvm
#endif // TVM_NODE_MEMORY_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/node/node.h
* \brief Node system data structure.
*/
#ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_
#include <string>
#include <vector>
#include <type_traits>
#include "base/Type.h"
#include "../runtime/node_base.h"
#include "../runtime/c_runtime_api.h"
namespace tvm {
using HalideIR::Type;
// forward declaration
class Node;
class NodeRef;
namespace runtime {
// forward declaration
class NDArray;
} // namespace runtime
/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, Type* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief base class of node container in DSL AST.
* All object's internal is stored as std::shared_ptr<Node>
*/
class TVM_DLL Node : public NodeBase {
public:
/*! \brief virtual destructor */
virtual ~Node() {}
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*! \return the type index of the node */
virtual const uint32_t type_index() const = 0;
/*!
* \brief Whether this node derives from node with type_index=tid.
* Implemented by TVM_DECLARE_NODE_TYPE_INFO
*
* \param tid The type index.
* \return the check result.
*/
virtual const bool _DerivedFrom(uint32_t tid) const;
/*!
* \brief get a runtime unique type index given a type key
* \param type_key Type key of a type.
* \return the corresponding type index.
*/
static uint32_t TypeKey2Index(const char* type_key);
/*!
* \brief get type key from type index.
* \param index The type index
* \return the corresponding type key.
*/
static const char* TypeIndex2Key(uint32_t index);
/*!
* \return whether the type is derived from
*/
template<typename T>
inline bool derived_from() const;
/*!
* \return whether the node is of type T
* \tparam The type to be checked.
*/
template<typename T>
inline bool is_type() const;
/*!
* \brief Get a NodePtr that holds reference to this Node.
* \return the NodePtr
*/
inline NodePtr<Node> GetNodePtr() const;
// node ref can see this
friend class NodeRef;
static constexpr const char* _type_key = "Node";
};
/*! \brief Base class of all node reference object */
class NodeRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Node;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool same_as(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator<(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return whether the expression is null */
inline bool defined() const;
/*! \return the internal type index of IRNode */
inline uint32_t type_index() const;
/*! \return the internal node pointer */
inline const Node* get() const;
/*! \return the internal node pointer */
inline const Node* operator->() const;
/*!
* \brief Downcast this ir node to its actual type (e.g. Add, or
* Select). This returns nullptr if the node is not of the requested
* type. Example usage:
*
* if (const Add *add = node->as<Add>()) {
* // This is an add node
* }
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as() const;
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as_derived() const;
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(NodePtr<Node> node) : node_(node) {}
/*! \brief the internal node object, do not touch */
NodePtr<Node> node_;
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template <typename RefType, typename NodeType>
inline RefType GetRef(const NodeType* ptr);
/*!
* \brief Downcast a base reference type to a more specific type.
*
* \param ref The inptut reference
* \return The corresponding SubRef.
* \tparam SubRef The target specific reference type.
* \tparam BaseRef the current reference type.
*/
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
* \brief helper macro to declare type information in a terminal node
*/
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
const uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
const bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
// implementations of inline functions after this
template<typename T>
inline bool Node::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
template<typename T>
inline bool Node::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
inline NodePtr<Node> Node::GetNodePtr() const {
return NodePtr<Node>(const_cast<Node*>(this));
}
template <typename RefType, typename NodeType>
inline RefType GetRef(const NodeType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(ptr->GetNodePtr());
}
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
ref->template derived_from<typename SubRef::ContainerType>())
<< "Downcast from " << ref->type_key() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.node_));
}
inline const Node* NodeRef::get() const {
return node_.get();
}
inline const Node* NodeRef::operator->() const {
return node_.get();
}
inline bool NodeRef::defined() const {
return node_.get() != nullptr;
}
inline bool NodeRef::operator==(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::same_as(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::operator<(const NodeRef& other) const {
return node_.get() < other.node_.get();
}
inline bool NodeRef::operator!=(const NodeRef& other) const {
return node_.get() != other.node_.get();
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
}
inline uint32_t NodeRef::type_index() const {
CHECK(node_.get() != nullptr)
<< "null type";
return get()->type_index();
}
template<typename T>
inline const T* NodeRef::as() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && ptr->is_type<T>()) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
template<typename T>
inline const T* NodeRef::as_derived() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
/*! \brief The hash function for nodes */
struct NodeHash {
size_t operator()(const NodeRef& a) const {
return a.hash();
}
};
/*! \brief The equal comparator for nodes */
struct NodeEqual {
bool operator()(const NodeRef& a, const NodeRef& b) const {
return a.get() == b.get();
}
};
} // namespace tvm
#endif // TVM_NODE_NODE_H_
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_ #define TVM_TENSOR_H_
#include <ir/FunctionBase.h> #include <ir/FunctionBase.h>
#include <tvm/node/container.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <type_traits> #include <type_traits>
...@@ -15,7 +16,6 @@ ...@@ -15,7 +16,6 @@
#include "expr.h" #include "expr.h"
#include "ir_operator.h" #include "ir_operator.h"
#include "arithmetic.h" #include "arithmetic.h"
#include "node/container.h"
namespace tvm { namespace tvm {
......
/*!
* Copyright (c) 2018 by Contributors
* Implementation of IR Node API
* \file node.cc
*/
#include <tvm/node/node.h>
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_map>
namespace tvm {
namespace {
// single manager of operator information.
struct TypeManager {
// mutex to avoid registration from multiple threads.
// recursive is needed for trigger(which calls UpdateAttrMap)
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> key2index;
std::vector<std::string> index2key;
// get singleton of the
static TypeManager* Global() {
static TypeManager inst;
return &inst;
}
};
} // namespace
const bool Node::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Node::_type_key);
return tid == tindex;
}
// this is slow, usually caller always hold the result in a static variable.
uint32_t Node::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
if (it != t->key2index.end()) {
return it->second;
}
uint32_t tid = ++(t->type_counter);
t->key2index[skey] = tid;
t->index2key.push_back(skey);
return tid;
}
const char* Node::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
internal_assert(index != 0);
return t->index2key.at(index - 1).c_str();
}
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment