Unverified Commit 9a3d2ec9 by Tianqi Chen Committed by GitHub

[NODE][REFACTOR] Rename IRFunctor->NodeFunctor, use func pointer (#4247)

* [NODE][REFACTOR] Rename IRFunctor->NodeFunctor, use function pointer for dispatching.

Previously we used std::function for the functor dispatching.
It introduces additional overhead and problems during dll destruction(of std::function).

This PR changes the std::function to function pointers.
This change a bit restrictions around the set_dispatch that we can get around,
but will improve the general efficiency by reducing one level of indirection in the std::function.
We also no longer need special marcos to register functions to the Functor.
parent 2083513f
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
#include "dtype.h" #include "dtype.h"
#include "node/node.h" #include "node/node.h"
#include "node/container.h" #include "node/container.h"
#include "node/ir_functor.h" #include "node/functor.h"
#include "runtime/c_runtime_api.h" #include "runtime/c_runtime_api.h"
namespace tvm { namespace tvm {
...@@ -487,7 +487,7 @@ class IRPrinter { ...@@ -487,7 +487,7 @@ class IRPrinter {
/*! \brief Print indent to the stream */ /*! \brief Print indent to the stream */
TVM_DLL void PrintIndent(); TVM_DLL void PrintIndent();
// Allow registration to be printer. // Allow registration to be printer.
using FType = IRFunctor<void(const ObjectRef&, IRPrinter *)>; using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
TVM_DLL static FType& vtable(); TVM_DLL static FType& vtable();
}; };
......
...@@ -24,8 +24,9 @@ ...@@ -24,8 +24,9 @@
#ifndef TVM_IR_FUNCTOR_EXT_H_ #ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_ #define TVM_IR_FUNCTOR_EXT_H_
#include "tvm/node/ir_functor.h" #include <tvm/node/functor.h>
#include "ir.h" #include <tvm/ir.h>
#include <utility> #include <utility>
namespace tvm { namespace tvm {
...@@ -104,7 +105,7 @@ template<typename R, typename ...Args> ...@@ -104,7 +105,7 @@ template<typename R, typename ...Args>
class ExprFunctor<R(const Expr& n, Args...)> { class ExprFunctor<R(const Expr& n, Args...)> {
private: private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>; using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
...@@ -213,7 +214,7 @@ template<typename R, typename ...Args> ...@@ -213,7 +214,7 @@ template<typename R, typename ...Args>
class StmtFunctor<R(const Stmt& n, Args... args)> { class StmtFunctor<R(const Stmt& n, Args... args)> {
private: private:
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>; using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
using FType = IRFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>; using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <utility> #include <utility>
#include "expr.h" #include "expr.h"
#include "ir.h" #include "ir.h"
#include "tvm/node/ir_functor.h" #include "tvm/node/functor.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -36,13 +36,13 @@ namespace ir { ...@@ -36,13 +36,13 @@ namespace ir {
* \brief a base class for mutator to iterative mutate the IR * \brief a base class for mutator to iterative mutate the IR
* *
* This IRMutator is implemented via Visitor Pattern. * This IRMutator is implemented via Visitor Pattern.
* Also you can implement via IRFunctor. * Also you can implement via NodeFunctor.
* This enables easy extensions of possible new Node. * This enables easy extensions of possible new Node.
* It also makes changing return types easier. * It also makes changing return types easier.
* *
* \note If you want to return a different type other than Expr and Stmt, * \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class. * Simply following the same pattern as IRMutator and create a seperate class.
* \sa IRFunctor * \sa NodeFunctor
*/ */
class TVM_DLL IRMutator { class TVM_DLL IRMutator {
public: public:
...@@ -65,9 +65,9 @@ class TVM_DLL IRMutator { ...@@ -65,9 +65,9 @@ class TVM_DLL IRMutator {
/*! \brief destructor */ /*! \brief destructor */
virtual ~IRMutator() {} virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */ /*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const ObjectRef&, const Expr&, IRMutator*)>; using FMutateExpr = NodeFunctor<Expr(const ObjectRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */ /*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const ObjectRef&, const Stmt&, IRMutator*)>; using FMutateStmt = NodeFunctor<Stmt(const ObjectRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */ /*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*) static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */ /*! \return internal stmt of expr */
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_
#include "ir.h" #include "ir.h"
#include "tvm/node/ir_functor.h" #include "tvm/node/functor.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -33,7 +33,7 @@ namespace ir { ...@@ -33,7 +33,7 @@ namespace ir {
/*! /*!
* \brief a base class for visitor to iterative traverse the IR * \brief a base class for visitor to iterative traverse the IR
* *
* This IRVisitor is implemented via IRFunctor * This IRVisitor is implemented via NodeFunctor
* This enables extensions of possible new Node. * This enables extensions of possible new Node.
* *
* \sa ExprFunctor, StmtFunctor, PostOrderVisit * \sa ExprFunctor, StmtFunctor, PostOrderVisit
...@@ -94,7 +94,7 @@ class TVM_DLL IRVisitor { ...@@ -94,7 +94,7 @@ class TVM_DLL IRVisitor {
/*! \brief destructor */ /*! \brief destructor */
virtual ~IRVisitor() {} virtual ~IRVisitor() {}
/*! \brief functor type of visitor */ /*! \brief functor type of visitor */
using FVisit = IRFunctor<void(const ObjectRef&, IRVisitor*)>; using FVisit = NodeFunctor<void(const ObjectRef&, IRVisitor*)>;
/*! \return internal vtable*/ /*! \return internal vtable*/
static FVisit& vtable(); static FVisit& vtable();
// overloadable visit function. // overloadable visit function.
......
...@@ -17,31 +17,33 @@ ...@@ -17,31 +17,33 @@
* under the License. * under the License.
*/ */
/*! /*!
* \file tvm/node/ir_functor.h * \file tvm/node/functor.h
* \brief Defines the IRFunctor data structures. * \brief Defines the Functor data structures.
*/ */
#ifndef TVM_NODE_IR_FUNCTOR_H_ #ifndef TVM_NODE_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_ #define TVM_NODE_FUNCTOR_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <string> #include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <vector> #include <vector>
#include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <functional>
#include "node.h"
namespace tvm { namespace tvm {
/*! /*!
* \brief A dynamically dispatched functor on ObjectRef in the first argument. * \brief A dynamically dispatched functor on the type of the first argument.
*
* This is a class that is useful to construct polymorphic dispatching
* base on the AST/IR node's type.
* *
* \code * \code
* IRFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr; * NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) { * tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
* return prefix + "Add"; * return prefix + "Add";
* }); * });
* tostr.set_dispatch<IntImm>([](const IntImm* op) { * tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
* return prefix + "IntImm" * return prefix + "IntImm"
* }); * });
* *
...@@ -57,15 +59,17 @@ namespace tvm { ...@@ -57,15 +59,17 @@ namespace tvm {
* This type if only defined for FType with function signature * This type if only defined for FType with function signature
*/ */
template<typename FType> template<typename FType>
class IRFunctor; class NodeFunctor;
template<typename R, typename ...Args> template<typename R, typename ...Args>
class IRFunctor<R(const ObjectRef& n, Args...)> { class NodeFunctor<R(const ObjectRef& n, Args...)> {
private: private:
using Function = std::function<R (const ObjectRef&n, Args...)>; /*! \brief internal function pointer type */
using TSelf = IRFunctor<R (const ObjectRef& n, Args...)>; typedef R (*FPointer)(const ObjectRef&n, Args...);
/*! \brief refer to itself. */
using TSelf = NodeFunctor<R (const ObjectRef& n, Args...)>;
/*! \brief internal function table */ /*! \brief internal function table */
std::vector<Function> func_; std::vector<FPointer> func_;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
...@@ -75,23 +79,21 @@ class IRFunctor<R(const ObjectRef& n, Args...)> { ...@@ -75,23 +79,21 @@ class IRFunctor<R(const ObjectRef& n, Args...)> {
* \param n The node to be dispatched * \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type. * \return Whether dispatching function is registered for n's type.
*/ */
inline bool can_dispatch(const ObjectRef& n) const { bool can_dispatch(const ObjectRef& n) const {
uint32_t type_index = n->type_index(); uint32_t type_index = n->type_index();
return type_index < func_.size() && func_[type_index] != nullptr; return type_index < func_.size() && func_[type_index] != nullptr;
} }
/*! /*!
* \brief invoke the functor , dispatch on type of n * \brief invoke the functor, dispatch on type of n
* \param n The Node argument * \param n The Node argument
* \param args The additional arguments * \param args The additional arguments
* \return The result. * \return The result.
*/ */
inline R operator()(const ObjectRef& n, Args... args) const { R operator()(const ObjectRef& n, Args... args) const {
uint32_t type_index = n->type_index(); CHECK(can_dispatch(n))
CHECK(type_index < func_.size() && << "NodeFunctor calls un-registered function on type "
func_[type_index] != nullptr)
<< "IRFunctor calls un-registered function on type "
<< n->GetTypeKey(); << n->GetTypeKey();
return func_[type_index](n, std::forward<Args>(args)...); return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
} }
/*! /*!
* \brief set the dispacher for type TNode * \brief set the dispacher for type TNode
...@@ -100,7 +102,7 @@ class IRFunctor<R(const ObjectRef& n, Args...)> { ...@@ -100,7 +102,7 @@ class IRFunctor<R(const ObjectRef& n, Args...)> {
* \return reference to self. * \return reference to self.
*/ */
template<typename TNode> template<typename TNode>
inline TSelf& set_dispatch(Function f) { // NOLINT(*) TSelf& set_dispatch(FPointer f) { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex(); uint32_t tindex = TNode::RuntimeTypeIndex();
if (func_.size() <= tindex) { if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr); func_.resize(tindex + 1, nullptr);
...@@ -112,55 +114,31 @@ class IRFunctor<R(const ObjectRef& n, Args...)> { ...@@ -112,55 +114,31 @@ class IRFunctor<R(const ObjectRef& n, Args...)> {
return *this; return *this;
} }
/*! /*!
* \brief set the dispacher for type TNode
* This allows f to used detailed const Node pointer to replace ObjectRef
*
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
Function fun = [f](const ObjectRef& n, Args... args) {
return f(static_cast<const TNode*>(n.get()),
std::forward<Args>(args)...);
};
return this->set_dispatch<TNode>(fun);
}
/*!
* \brief unset the dispacher for type TNode * \brief unset the dispacher for type TNode
* *
* \tparam TNode the type of Node to be dispatched. * \tparam TNode the type of Node to be dispatched.
* \return reference to self. * \return reference to self.
*/ */
template<typename TNode> template<typename TNode>
inline TSelf& clear_dispatch() { // NOLINT(*) TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex(); uint32_t tindex = TNode::RuntimeTypeIndex();
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; CHECK_LT(tindex, func_.size())
<< "clear_dispatch: index out of range";
func_[tindex] = nullptr; func_[tindex] = nullptr;
return *this; return *this;
} }
}; };
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */ #define TVM_REG_FUNC_VAR_DEF(ClsName) \
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*! /*!
* \brief Useful macro to set IRFunctor dispatch in a global static field. * \brief Useful macro to set NodeFunctor dispatch in a global static field.
* *
* \code * \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern. * // Use NodeFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows easy patch in of new Node types, without changing * // vtable allows easy patch of new Node types, without changing
* // interface of IRPrinter. * // interface of IRPrinter.
* *
* class IRPrinter { * class IRPrinter {
...@@ -172,7 +150,7 @@ class IRFunctor<R(const ObjectRef& n, Args...)> { ...@@ -172,7 +150,7 @@ class IRFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this); * f(e, this);
* } * }
* *
* using FType = IRFunctor<void (const ObjectRef&, IRPrinter *)>; * using FType = NodeFunctor<void (const ObjectRef&, IRPrinter *)>;
* // function to return global function table * // function to return global function table
* static FType& vtable(); * static FType& vtable();
* }; * };
...@@ -183,7 +161,8 @@ class IRFunctor<R(const ObjectRef& n, Args...)> { ...@@ -183,7 +161,8 @@ class IRFunctor<R(const ObjectRef& n, Args...)> {
* } * }
* *
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) { * .set_dispatch<Add>([](const ObjectRef& ref, IRPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a); * p->print(n->a);
* p->stream << '+' * p->stream << '+'
* p->print(n->b); * p->print(n->b);
...@@ -193,90 +172,10 @@ class IRFunctor<R(const ObjectRef& n, Args...)> { ...@@ -193,90 +172,10 @@ class IRFunctor<R(const ObjectRef& n, Args...)> {
* \endcode * \endcode
* *
* \param ClsName The name of the class * \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor. * \param FField The static function that returns a singleton of NodeFunctor.
*/ */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ #define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField() ClsName::FField()
/*!
* \brief A container for a list of callbacks. All callbacks are invoked when
* the object is destructed.
*/
class IRFunctorCleanList {
public:
~IRFunctorCleanList() {
for (auto &f : clean_items) {
f();
}
}
void append(std::function<void()> func) {
clean_items.push_back(func);
}
private:
std::vector< std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)> {
private:
IRFunctor<R(const ObjectRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const ObjectRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)> MakeIRFunctorStaticRegistry(
IRFunctor<R(const ObjectRef& n, Args...)> *irf) {
return IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_ #endif // TVM_NODE_FUNCTOR_H_
...@@ -48,20 +48,20 @@ using runtime::ObjectRef; ...@@ -48,20 +48,20 @@ using runtime::ObjectRef;
* Each objects that wants reflection will need to implement * Each objects that wants reflection will need to implement
* a VisitAttrs function and call visitor->Visit on each of its field. * a VisitAttrs function and call visitor->Visit on each of its field.
*/ */
class TVM_DLL AttrVisitor { class AttrVisitor {
public: public:
//! \cond Doxygen_Suppress //! \cond Doxygen_Suppress
virtual ~AttrVisitor() = default; TVM_DLL virtual ~AttrVisitor() = default;
virtual void Visit(const char* key, double* value) = 0; TVM_DLL virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0; TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0; TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0; TVM_DLL virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0; TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0; TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0; TVM_DLL virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, DataType* value) = 0; TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum, template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type> typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) { void Visit(const char* key, ENum* ptr) {
...@@ -93,13 +93,13 @@ class ReflectionVTable { ...@@ -93,13 +93,13 @@ class ReflectionVTable {
* If this is not empty then FGlobalKey must be defined for the object. * If this is not empty then FGlobalKey must be defined for the object.
* \return The created function. * \return The created function.
*/ */
using FCreate = std::function<ObjectPtr<Object>(const std::string& global_key)>; typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key);
/*! /*!
* \brief Global key function, only needed by global objects. * \brief Global key function, only needed by global objects.
* \param node The node pointer. * \param node The node pointer.
* \return node The global key to the node. * \return node The global key to the node.
*/ */
using FGlobalKey = std::function<std::string(const Object* self)>; typedef std::string (*FGlobalKey)(const Object* self);
/*! /*!
* \brief Dispatch the VisitAttrs function. * \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object. * \param self The pointer to the object.
...@@ -193,7 +193,7 @@ class ReflectionVTable::Registry { ...@@ -193,7 +193,7 @@ class ReflectionVTable::Registry {
static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \ static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \
__make_Node ## _ ## TypeName ## __ = \ __make_Node ## _ ## TypeName ## __ = \
::tvm::ReflectionVTable::Global()->Register<TypeName>() \ ::tvm::ReflectionVTable::Global()->Register<TypeName>() \
.set_creator([](const std::string&) { \ .set_creator([](const std::string&) -> ObjectPtr<Object> { \
return ::tvm::runtime::make_object<TypeName>(); \ return ::tvm::runtime::make_object<TypeName>(); \
}) })
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_
#include <tvm/node/ir_functor.h> #include <tvm/node/functor.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
...@@ -66,7 +66,7 @@ template <typename R, typename... Args> ...@@ -66,7 +66,7 @@ template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> { class ExprFunctor<R(const Expr& n, Args...)> {
private: private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>; using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
#define TVM_RELAY_PATTERN_FUNCTOR_H_ #define TVM_RELAY_PATTERN_FUNCTOR_H_
#include <tvm/node/ir_functor.h> #include <tvm/node/functor.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
...@@ -66,7 +66,7 @@ template <typename R, typename... Args> ...@@ -66,7 +66,7 @@ template <typename R, typename... Args>
class PatternFunctor<R(const Pattern& n, Args...)> { class PatternFunctor<R(const Pattern& n, Args...)> {
private: private:
using TSelf = PatternFunctor<R(const Pattern& n, Args...)>; using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
......
...@@ -391,8 +391,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") ...@@ -391,8 +391,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
TVM_REGISTER_NODE_TYPE(GraphFuncNode); TVM_REGISTER_NODE_TYPE(GraphFuncNode);
TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode); TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) { .set_dispatch<GraphFuncNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* op = static_cast<const GraphFuncNode*>(ref.get());
p->stream << "GraphFunc(name=" << op->func_name p->stream << "GraphFunc(name=" << op->func_name
<< ", addr=" << op << ")"; << ", addr=" << op << ")";
}); });
......
...@@ -101,8 +101,9 @@ GraphKey GraphKeyNode::make(Graph graph, ...@@ -101,8 +101,9 @@ GraphKey GraphKeyNode::make(Graph graph,
return GraphKey(n); return GraphKey(n);
} }
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) { .set_dispatch<GraphKeyNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* op = static_cast<const GraphKeyNode*>(ref.get());
p->stream << "GraphKeyNode("<< op << ")"; p->stream << "GraphKeyNode("<< op << ")";
}); });
......
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
namespace nnvm { namespace nnvm {
namespace compiler { namespace compiler {
using tvm::Object;
using tvm::ObjectPtr;
using tvm::runtime::TVMArgs; using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue; using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc; using tvm::runtime::PackedFunc;
......
...@@ -53,7 +53,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) { ...@@ -53,7 +53,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ConstIntBoundNode* op, IRPrinter* p) { .set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound["; p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value); PrintBoundValue(p->stream, op->min_value);
p->stream << ','; p->stream << ',';
......
...@@ -810,7 +810,8 @@ IntSet EvalSet(Range r, ...@@ -810,7 +810,8 @@ IntSet EvalSet(Range r,
TVM_REGISTER_NODE_TYPE(IntervalSetNode); TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) { .set_dispatch<IntervalSetNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet" p->stream << "IntervalSet"
<< "[" << op->min_value << ", " << "[" << op->min_value << ", "
<< op->max_value << ']'; << op->max_value << ']';
......
...@@ -45,7 +45,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { ...@@ -45,7 +45,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ModularSetNode *op, IRPrinter *p) { .set_dispatch<ModularSetNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet(" p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base=" << "coeff=" << op->coeff << ", base="
<< op->base << ')'; << op->base << ')';
......
...@@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(TargetNode); ...@@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_REGISTER_NODE_TYPE(GenericFuncNode); TVM_REGISTER_NODE_TYPE(GenericFuncNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TargetNode>([](const TargetNode *op, IRPrinter *p) { .set_dispatch<TargetNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str(); p->stream << op->str();
}); });
...@@ -654,7 +655,8 @@ tvm::BuildConfig BuildConfig::Current() { ...@@ -654,7 +655,8 @@ tvm::BuildConfig BuildConfig::Current() {
TVM_REGISTER_NODE_TYPE(BuildConfigNode); TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const BuildConfigNode *op, IRPrinter *p) { .set_dispatch<BuildConfigNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config("; p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", "; p->stream << "data_alignment=" << op->data_alignment << ", ";
p->stream << "offset_factor=" << op->offset_factor << ", "; p->stream << "offset_factor=" << op->offset_factor << ", ";
......
...@@ -26,11 +26,12 @@ ...@@ -26,11 +26,12 @@
namespace tvm { namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const EnvFuncNode *op, IRPrinter *p) { .set_dispatch<EnvFuncNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")"; p->stream << "EnvFunc(" << op->name << ")";
}); });
NodePtr<EnvFuncNode> CreateEnvNode(const std::string& name) { ObjectPtr<Object> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name); auto* f = runtime::Registry::Get(name);
CHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>(); NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>();
...@@ -62,7 +63,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc") ...@@ -62,7 +63,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc")
TVM_REGISTER_NODE_TYPE(EnvFuncNode) TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode) .set_creator(CreateEnvNode)
.set_global_key([](const Object* n) { .set_global_key([](const Object* n) -> std::string {
return static_cast<const EnvFuncNode*>(n)->name; return static_cast<const EnvFuncNode*>(n)->name;
}); });
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file attr_functor.h * \file attr_functor.h
* \brief A way to define arbitrary function signature * \brief A way to define arbitrary function signature
* with dispatch on common attributes. * with dispatch on common attributes.
...@@ -31,6 +30,7 @@ ...@@ -31,6 +30,7 @@
#ifndef TVM_LANG_ATTR_FUNCTOR_H_ #ifndef TVM_LANG_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_ #define TVM_LANG_ATTR_FUNCTOR_H_
#include <tvm/node/functor.h>
#include <utility> #include <utility>
namespace tvm { namespace tvm {
...@@ -54,7 +54,7 @@ template <typename R, typename... Args> ...@@ -54,7 +54,7 @@ template <typename R, typename... Args>
class AttrFunctor<R(const ObjectRef& n, Args...)> { class AttrFunctor<R(const ObjectRef& n, Args...)> {
private: private:
using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>; using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>;
using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
......
...@@ -61,7 +61,8 @@ Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) { ...@@ -61,7 +61,8 @@ Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const DictAttrsNode *op, IRPrinter *p) { .set_dispatch<DictAttrsNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict; p->stream << op->dict;
}); });
......
...@@ -452,7 +452,8 @@ Buffer BufferNode::make(Var data, ...@@ -452,7 +452,8 @@ Buffer BufferNode::make(Var data,
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) { .set_dispatch<BufferNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const BufferNode*>(node.get());
p->stream << "buffer(" << op->name << ", " << op << ")"; p->stream << "buffer(" << op->name << ", " << op << ")";
}); });
......
...@@ -33,7 +33,8 @@ Channel ChannelNode::make(Var handle_var, Type dtype) { ...@@ -33,7 +33,8 @@ Channel ChannelNode::make(Var handle_var, Type dtype) {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ChannelNode>([](const ChannelNode *op, IRPrinter *p) { .set_dispatch<ChannelNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const ChannelNode*>(node.get());
p->stream << "channel(" << op->handle_var << ", " << op->dtype << ")"; p->stream << "channel(" << op->handle_var << ", " << op->dtype << ")";
}); });
......
...@@ -196,7 +196,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { ...@@ -196,7 +196,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LayoutNode>([](const LayoutNode* l, IRPrinter* p) { .set_dispatch<LayoutNode>([](const ObjectRef& node, IRPrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")"; p->stream << "Layout(" << l->name << ")";
}); });
...@@ -352,7 +353,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, ...@@ -352,7 +353,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const BijectiveLayoutNode* b, IRPrinter* p) { .set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, IRPrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name() p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")"; << "->" << b->dst_layout.name() << ")";
}); });
......
...@@ -182,7 +182,8 @@ IRPrinter::FType& IRPrinter::vtable() { ...@@ -182,7 +182,8 @@ IRPrinter::FType& IRPrinter::vtable() {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntImm>([](const IntImm *op, IRPrinter *p) { .set_dispatch<IntImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const IntImm*>(node.get());
if (op->type == Int(32)) { if (op->type == Int(32)) {
p->stream << op->value; p->stream << op->value;
} else { } else {
...@@ -191,7 +192,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -191,7 +192,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) { .set_dispatch<IterVarNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
p->stream << "iter_var("; p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) { if (op->var->name_hint.length() != 0) {
p->stream << op->var->name_hint << ", "; p->stream << op->var->name_hint << ", ";
...@@ -206,7 +208,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -206,7 +208,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RangeNode>([](const RangeNode* op, IRPrinter* p) { .set_dispatch<RangeNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
}); });
......
...@@ -553,12 +553,14 @@ Stmt Evaluate::make(Expr value) { ...@@ -553,12 +553,14 @@ Stmt Evaluate::make(Expr value) {
// Printers // Printers
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<UIntImm>([](const UIntImm* op, IRPrinter* p) { .set_dispatch<UIntImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const UIntImm*>(node.get());
p->stream << "(" << op->type << ")" << op->value; p->stream << "(" << op->type << ")" << op->value;
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloatImm>([](const FloatImm* op, IRPrinter* p) { .set_dispatch<FloatImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const FloatImm*>(node.get());
auto& stream = p->stream; auto& stream = p->stream;
switch (op->type.bits()) { switch (op->type.bits()) {
case 64: case 64:
...@@ -576,7 +578,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -576,7 +578,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StringImm>([](const StringImm* op, IRPrinter* p) { .set_dispatch<StringImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const StringImm*>(node.get());
auto& stream = p->stream; auto& stream = p->stream;
stream << '"'; stream << '"';
for (size_t i = 0; i < op->value.size(); ++i) { for (size_t i = 0; i < op->value.size(); ++i) {
...@@ -611,101 +614,116 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -611,101 +614,116 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Cast>([](const Cast* op, IRPrinter* p) { .set_dispatch<Cast>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Cast*>(node.get());
p->stream << op->type << '('; p->stream << op->type << '(';
p->Print(op->value); p->Print(op->value);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<Variable>([](const Variable* op, IRPrinter* p) { .set_dispatch<Variable>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Variable*>(node.get());
// omit the type // omit the type
// stream << op->name << "." << op->type; // stream << op->name << "." << op->type;
p->stream << op->name_hint; p->stream << op->name_hint;
}) })
.set_dispatch<Add>([](const Add* op, IRPrinter* p) { .set_dispatch<Add>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Add*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " + "; p->stream << " + ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<Sub>([](const Sub* op, IRPrinter* p) { .set_dispatch<Sub>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Sub*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " - "; p->stream << " - ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<Mul>([](const Mul* op, IRPrinter* p) { .set_dispatch<Mul>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Mul*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << "*"; p->stream << "*";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<Div>([](const Div* op, IRPrinter* p) { .set_dispatch<Div>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Div*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << "/"; p->stream << "/";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<Mod>([](const Mod* op, IRPrinter* p) { .set_dispatch<Mod>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Mod*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " % "; p->stream << " % ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<Min>([](const Min* op, IRPrinter* p) { .set_dispatch<Min>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Min*>(node.get());
p->stream << "min("; p->stream << "min(";
p->Print(op->a); p->Print(op->a);
p->stream << ", "; p->stream << ", ";
p->Print(op->b); p->Print(op->b);
p->stream << ")"; p->stream << ")";
}) })
.set_dispatch<Max>([](const Max* op, IRPrinter* p) { .set_dispatch<Max>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Max*>(node.get());
p->stream << "max("; p->stream << "max(";
p->Print(op->a); p->Print(op->a);
p->stream << ", "; p->stream << ", ";
p->Print(op->b); p->Print(op->b);
p->stream << ")"; p->stream << ")";
}) })
.set_dispatch<EQ>([](const EQ* op, IRPrinter* p) { .set_dispatch<EQ>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const EQ*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " == "; p->stream << " == ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<NE>([](const NE* op, IRPrinter* p) { .set_dispatch<NE>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const NE*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " != "; p->stream << " != ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<LT>([](const LT* op, IRPrinter* p) { .set_dispatch<LT>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const LT*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " < "; p->stream << " < ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<LE>([](const LE* op, IRPrinter* p) { .set_dispatch<LE>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const LE*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " <= "; p->stream << " <= ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<GT>([](const GT* op, IRPrinter* p) { .set_dispatch<GT>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const GT*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " > "; p->stream << " > ";
p->Print(op->b); p->Print(op->b);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<GE>([](const GE* op, IRPrinter* p) { .set_dispatch<GE>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const GE*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " >= "; p->stream << " >= ";
...@@ -714,17 +732,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -714,17 +732,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorDiv>([](const FloorDiv* op, IRPrinter *p) { .set_dispatch<FloorDiv>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const FloorDiv*>(node.get());
p->stream << "floordiv(" << op->a << ", " << op->b << ")"; p->stream << "floordiv(" << op->a << ", " << op->b << ")";
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorMod>([](const FloorMod* op, IRPrinter *p) { .set_dispatch<FloorMod>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const FloorMod*>(node.get());
p->stream << "floormod(" << op->a << ", " << op->b << ")"; p->stream << "floormod(" << op->a << ", " << op->b << ")";
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<And>([](const And* op, IRPrinter* p) { .set_dispatch<And>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const And*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " && "; p->stream << " && ";
...@@ -733,7 +754,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -733,7 +754,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Or>([](const Or* op, IRPrinter* p) { .set_dispatch<Or>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Or*>(node.get());
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
p->stream << " || "; p->stream << " || ";
...@@ -742,13 +764,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -742,13 +764,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Not>([](const Not* op, IRPrinter* p) { .set_dispatch<Not>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Not*>(node.get());
p->stream << '!'; p->stream << '!';
p->Print(op->a); p->Print(op->a);
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Select>([](const Select* op, IRPrinter* p) { .set_dispatch<Select>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Select*>(node.get());
p->stream << "select("; p->stream << "select(";
p->Print(op->condition); p->Print(op->condition);
p->stream << ", "; p->stream << ", ";
...@@ -759,7 +783,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -759,7 +783,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Load>([](const Load* op, IRPrinter* p) { .set_dispatch<Load>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Load*>(node.get());
p->stream << op->buffer_var << "["; p->stream << op->buffer_var << "[";
p->Print(op->index); p->Print(op->index);
p->stream << "]"; p->stream << "]";
...@@ -770,7 +795,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -770,7 +795,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Ramp>([](const Ramp* op, IRPrinter* p) { .set_dispatch<Ramp>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Ramp*>(node.get());
p->stream << "ramp("; p->stream << "ramp(";
p->Print(op->base); p->Print(op->base);
p->stream << ", "; p->stream << ", ";
...@@ -779,14 +805,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -779,14 +805,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Broadcast>([](const Broadcast* op, IRPrinter* p) { .set_dispatch<Broadcast>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Broadcast*>(node.get());
p->stream << "x" << op->lanes << "("; p->stream << "x" << op->lanes << "(";
p->Print(op->value); p->Print(op->value);
p->stream << ")"; p->stream << ")";
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Call>([](const Call* op, IRPrinter* p) { .set_dispatch<Call>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Call*>(node.get());
p->stream << op->name << "("; p->stream << op->name << "(";
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
p->Print(op->args[i]); p->Print(op->args[i]);
...@@ -798,7 +826,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -798,7 +826,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Let>([](const Let* op, IRPrinter* p) { .set_dispatch<Let>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Let*>(node.get());
p->stream << "(let " << op->var << " = "; p->stream << "(let " << op->var << " = ";
p->Print(op->value); p->Print(op->value);
p->stream << " in "; p->stream << " in ";
...@@ -807,7 +836,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -807,7 +836,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LetStmt>([](const LetStmt* op, IRPrinter* p) { .set_dispatch<LetStmt>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const LetStmt*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "let " << op->var << " = "; p->stream << "let " << op->var << " = ";
p->Print(op->value); p->Print(op->value);
...@@ -816,7 +846,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -816,7 +846,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRPrinter* p) { .set_dispatch<AttrStmt>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const AttrStmt*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "// attr ["; p->stream << "// attr [";
p->Print(op->node); p->Print(op->node);
...@@ -828,7 +859,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -828,7 +859,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AssertStmt>([](const AssertStmt* op, IRPrinter* p) { .set_dispatch<AssertStmt>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const AssertStmt*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "assert("; p->stream << "assert(";
p->Print(op->condition); p->Print(op->condition);
...@@ -839,7 +871,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -839,7 +871,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer* op, IRPrinter* p) { .set_dispatch<ProducerConsumer>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ProducerConsumer*>(node.get());
if (op->is_producer) { if (op->is_producer) {
p->PrintIndent(); p->PrintIndent();
p->stream << "produce " << op->func->func_name() << " {\n"; p->stream << "produce " << op->func->func_name() << " {\n";
...@@ -872,7 +905,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) ...@@ -872,7 +905,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<For>([](const For* op, IRPrinter* p) { .set_dispatch<For>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const For*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << op->for_type << " (" << op->loop_var << ", "; p->stream << op->for_type << " (" << op->loop_var << ", ";
p->Print(op->min); p->Print(op->min);
...@@ -889,7 +923,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -889,7 +923,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Store>([](const Store* op, IRPrinter* p) { .set_dispatch<Store>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Store*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << op->buffer_var << "["; p->stream << op->buffer_var << "[";
p->Print(op->index); p->Print(op->index);
...@@ -903,7 +938,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -903,7 +938,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Provide>([](const Provide* op, IRPrinter* p) { .set_dispatch<Provide>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Provide*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << op->func->func_name() << "("; p->stream << op->func->func_name() << "(";
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
...@@ -920,7 +956,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -920,7 +956,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Allocate>([](const Allocate* op, IRPrinter* p) { .set_dispatch<Allocate>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Allocate*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "allocate " << op->buffer_var << "[" << op->type; p->stream << "allocate " << op->buffer_var << "[" << op->type;
for (size_t i = 0; i < op->extents.size(); ++i) { for (size_t i = 0; i < op->extents.size(); ++i) {
...@@ -937,14 +974,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -937,14 +974,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Free>([](const Free* op, IRPrinter* p) { .set_dispatch<Free>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Free*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "free " << op->buffer_var; p->stream << "free " << op->buffer_var;
p->stream << '\n'; p->stream << '\n';
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Realize>([](const Realize* op, IRPrinter* p) { .set_dispatch<Realize>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Realize*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "realize " << op->func->func_name() << "("; p->stream << "realize " << op->func->func_name() << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) { for (size_t i = 0; i < op->bounds.size(); ++i) {
...@@ -974,7 +1013,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -974,7 +1013,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Prefetch>([](const Prefetch* op, IRPrinter* p) { .set_dispatch<Prefetch>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Prefetch*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "("; p->stream << "prefetch " << op->func->func_name() << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) { for (size_t i = 0; i < op->bounds.size(); ++i) {
...@@ -992,13 +1032,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -992,13 +1032,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Block>([](const Block* op, IRPrinter* p) { .set_dispatch<Block>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Block*>(node.get());
p->Print(op->first); p->Print(op->first);
if (op->rest.defined()) p->Print(op->rest); if (op->rest.defined()) p->Print(op->rest);
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IfThenElse>([](const IfThenElse* op, IRPrinter* p) { .set_dispatch<IfThenElse>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const IfThenElse*>(node.get());
p->PrintIndent(); p->PrintIndent();
while (true) { while (true) {
p->stream << "if (" << op->condition << ") {\n"; p->stream << "if (" << op->condition << ") {\n";
...@@ -1028,7 +1070,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1028,7 +1070,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Evaluate>([](const Evaluate* op, IRPrinter* p) { .set_dispatch<Evaluate>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Evaluate*>(node.get());
p->PrintIndent(); p->PrintIndent();
p->Print(op->value); p->Print(op->value);
p->stream << "\n"; p->stream << "\n";
...@@ -1045,7 +1088,8 @@ void PrintList(const Array<T> &exprs, IRPrinter* p) { ...@@ -1045,7 +1088,8 @@ void PrintList(const Array<T> &exprs, IRPrinter* p) {
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Shuffle>([](const Shuffle* op, IRPrinter* p) { .set_dispatch<Shuffle>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Shuffle*>(node.get());
p->stream << "shuffle("; p->stream << "shuffle(";
PrintList(op->vectors, p); PrintList(op->vectors, p);
p->stream << ", "; p->stream << ", ";
...@@ -1055,7 +1099,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1055,7 +1099,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
// Container printer // Container printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ArrayNode>([](const ArrayNode* op, IRPrinter* p) { .set_dispatch<ArrayNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '['; p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) { for (size_t i = 0 ; i < op->data.size(); ++i) {
if (i != 0) { if (i != 0) {
...@@ -1067,7 +1112,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1067,7 +1112,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MapNode>([](const MapNode* op, IRPrinter* p) { .set_dispatch<MapNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{'; p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) { for (auto it = op->data.begin(); it != op->data.end(); ++it) {
if (it != op->data.begin()) { if (it != op->data.begin()) {
...@@ -1081,7 +1127,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1081,7 +1127,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StrMapNode>([](const StrMapNode* op, IRPrinter* p) { .set_dispatch<StrMapNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{'; p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) { for (auto it = op->data.begin(); it != op->data.end(); ++it) {
if (it != op->data.begin()) { if (it != op->data.begin()) {
...@@ -1094,7 +1141,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1094,7 +1141,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRPrinter* p) { .set_dispatch<Reduce>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Reduce*>(node.get());
p->stream << "reduce(combiner=" p->stream << "reduce(combiner="
<< op->combiner; << op->combiner;
p->stream << ", source=" << op->source; p->stream << ", source=" << op->source;
...@@ -1105,7 +1153,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1105,7 +1153,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode* op, IRPrinter* p) { .set_dispatch<CommReducerNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const CommReducerNode*>(node.get());
p->stream << "comm_reducer(result=" << op->result p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs << ", lhs=" << op->lhs
<< ", rhs=" << op->rhs << ", rhs=" << op->rhs
...@@ -1114,7 +1163,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -1114,7 +1163,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Any>([](const Any *op, IRPrinter *p) { .set_dispatch<Any>([](const ObjectRef& node, IRPrinter* p) {
p->stream << "?"; p->stream << "?";
}); });
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
namespace tvm { namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) { .set_dispatch<LoweredFuncNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const LoweredFuncNode*>(node.get());
p->stream << "LoweredFunc(" << op->name << ", " << op << ")"; p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
}); });
......
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
namespace tvm { namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const MemoryInfoNode *op, IRPrinter *p) { .set_dispatch<MemoryInfoNode>([](const ObjectRef& node, IRPrinter *p) {
auto* op = static_cast<const MemoryInfoNode*>(node.get());
p->stream << "mem-info(" p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", " << "unit_bits=" << op->unit_bits << ", "
<< "max_num_bits=" << op->max_num_bits << ", " << "max_num_bits=" << op->max_num_bits << ", "
......
...@@ -69,7 +69,8 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -69,7 +69,8 @@ Tensor TensorNode::make(Array<Expr> shape,
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) { .set_dispatch<TensorNode>([](const ObjectRef& node, IRPrinter *p) {
auto* t = static_cast<const TensorNode*>(node.get());
p->stream << "Tensor(shape=" << t->shape p->stream << "Tensor(shape=" << t->shape
<< ", op.name=" << t->op->name << ')'; << ", op.name=" << t->op->name << ')';
}); });
...@@ -100,8 +101,9 @@ TensorIntrin TensorIntrinNode::make(std::string name, ...@@ -100,8 +101,9 @@ TensorIntrin TensorIntrinNode::make(std::string name,
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const TensorIntrinNode *n, IRPrinter *p) { .set_dispatch<TensorIntrinNode>([](const ObjectRef& node, IRPrinter* p) {
p->stream << "TensorIntrin(name=" << n->name << ", " << n << ")"; auto* op = static_cast<const TensorIntrinNode*>(node.get());
p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
}); });
TVM_REGISTER_NODE_TYPE(TensorIntrinNode); TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
...@@ -124,7 +126,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, ...@@ -124,7 +126,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) { .set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, IRPrinter *p) {
auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
}); });
......
...@@ -40,7 +40,8 @@ namespace tvm { ...@@ -40,7 +40,8 @@ namespace tvm {
using namespace ir; using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) { .set_dispatch<ComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
p->stream << "compute(" << op->name << ", " << op << ")"; p->stream << "compute(" << op->name << ", " << op << ")";
}); });
......
...@@ -31,7 +31,8 @@ namespace tvm { ...@@ -31,7 +31,8 @@ namespace tvm {
using namespace ir; using namespace ir;
// ExternOpNode // ExternOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ExternOpNode *op, IRPrinter *p) { .set_dispatch<ExternOpNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ExternOpNode*>(node.get());
p->stream << "extern(" << op->name << ", " << op << ")"; p->stream << "extern(" << op->name << ", " << op << ")";
}); });
......
...@@ -37,7 +37,8 @@ namespace tvm { ...@@ -37,7 +37,8 @@ namespace tvm {
using namespace ir; using namespace ir;
// HybridOpNode // HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) { .set_dispatch<HybridOpNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const HybridOpNode*>(node.get());
p->stream << "hybrid(" << op->name << ", " << op << ")"; p->stream << "hybrid(" << op->name << ", " << op << ")";
}); });
......
...@@ -28,7 +28,8 @@ namespace tvm { ...@@ -28,7 +28,8 @@ namespace tvm {
// PlaceholderOpNode // PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) { .set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const PlaceholderOpNode*>(node.get());
p->stream << "placeholder(" << op->name << ", " << op << ")"; p->stream << "placeholder(" << op->name << ", " << op << ")";
}); });
......
...@@ -32,7 +32,8 @@ namespace tvm { ...@@ -32,7 +32,8 @@ namespace tvm {
using namespace ir; using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) { .set_dispatch<ScanOpNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ScanOpNode*>(node.get());
p->stream << "scan(" << op->name << ", " << op << ")"; p->stream << "scan(" << op->name << ", " << op << ")";
}); });
TVM_REGISTER_NODE_TYPE(ScanOpNode); TVM_REGISTER_NODE_TYPE(ScanOpNode);
......
...@@ -36,8 +36,8 @@ namespace tvm { ...@@ -36,8 +36,8 @@ namespace tvm {
using namespace ir; using namespace ir;
// TensorComputeOpNode // TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op, .set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
IRPrinter *p) { auto* op = static_cast<const TensorComputeOpNode*>(node.get());
p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
}); });
......
...@@ -119,8 +119,8 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { ...@@ -119,8 +119,8 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
// Mutate Stmt // Mutate Stmt
#define DISPATCH_TO_MUTATE_STMT(OP) \ #define DISPATCH_TO_MUTATE_STMT(OP) \
set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) { \ set_dispatch<OP>([](const ObjectRef& node, const Stmt& s, IRMutator* m) { \
return m->Mutate_(op, s); \ return m->Mutate_(static_cast<const OP*>(node.get()), s); \
}) })
Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
...@@ -345,8 +345,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -345,8 +345,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
// Mutate Expr // Mutate Expr
#define DISPATCH_TO_MUTATE_EXPR(OP) \ #define DISPATCH_TO_MUTATE_EXPR(OP) \
set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) { \ set_dispatch<OP>([](const ObjectRef& node, const Expr& e, IRMutator* m) { \
return m->Mutate_(op, e); \ return m->Mutate_(static_cast<const OP*>(node.get()), e); \
}) })
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
......
...@@ -238,8 +238,8 @@ DEFINE_OP_NO_VISIT_(FloatImm) ...@@ -238,8 +238,8 @@ DEFINE_OP_NO_VISIT_(FloatImm)
DEFINE_OP_NO_VISIT_(StringImm) DEFINE_OP_NO_VISIT_(StringImm)
#define DISPATCH_TO_VISIT(OP) \ #define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \ set_dispatch<OP>([](const ObjectRef& node, IRVisitor* v) { \
v->Visit_(op); \ v->Visit_(static_cast<const OP*>(node.get())); \
}) })
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <dmlc/any.h> #include <dmlc/any.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
......
...@@ -53,8 +53,9 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) { ...@@ -53,8 +53,9 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
TVM_REGISTER_API("relay._make.Closure") TVM_REGISTER_API("relay._make.Closure")
.set_body_typed(ClosureNode::make); .set_body_typed(ClosureNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) { .set_dispatch<ClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const ClosureNode*>(ref.get());
p->stream << "ClosureNode(" << node->func << ", " << node->env << ")"; p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
}); });
...@@ -71,8 +72,9 @@ RecClosure RecClosureNode::make(Closure clos, Var bind) { ...@@ -71,8 +72,9 @@ RecClosure RecClosureNode::make(Closure clos, Var bind) {
TVM_REGISTER_API("relay._make.RecClosure") TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make); .set_body_typed(RecClosureNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) { .set_dispatch<RecClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const RecClosureNode*>(ref.get());
p->stream << "RecClosureNode(" << node->clos << ")"; p->stream << "RecClosureNode(" << node->clos << ")";
}); });
...@@ -85,8 +87,9 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) { ...@@ -85,8 +87,9 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
TVM_REGISTER_API("relay._make.TupleValue") TVM_REGISTER_API("relay._make.TupleValue")
.set_body_typed(TupleValueNode::make); .set_body_typed(TupleValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) { .set_dispatch<TupleValueNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TupleValueNode*>(ref.get());
p->stream << "TupleValueNode(" << node->fields << ")"; p->stream << "TupleValueNode(" << node->fields << ")";
}); });
...@@ -96,8 +99,9 @@ TensorValue TensorValueNode::make(runtime::NDArray data) { ...@@ -96,8 +99,9 @@ TensorValue TensorValueNode::make(runtime::NDArray data) {
return TensorValue(n); return TensorValue(n);
} }
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorValueNode>([](const TensorValueNode* node, tvm::IRPrinter* p) { .set_dispatch<TensorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TensorValueNode*>(ref.get());
auto to_str = GetPackedFunc("relay._tensor_value_repr"); auto to_str = GetPackedFunc("relay._tensor_value_repr");
std::string data_str = to_str(GetRef<TensorValue>(node)); std::string data_str = to_str(GetRef<TensorValue>(node));
p->stream << "TensorValueNode(" << data_str << ")"; p->stream << "TensorValueNode(" << data_str << ")";
...@@ -117,9 +121,9 @@ TVM_REGISTER_API("relay._make.RefValue") ...@@ -117,9 +121,9 @@ TVM_REGISTER_API("relay._make.RefValue")
TVM_REGISTER_NODE_TYPE(RefValueNode); TVM_REGISTER_NODE_TYPE(RefValueNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node, .set_dispatch<RefValueNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const RefValueNode*>(ref.get());
p->stream << "RefValueNode(" << node->value << ")"; p->stream << "RefValueNode(" << node->value << ")";
}); });
...@@ -138,9 +142,9 @@ TVM_REGISTER_API("relay._make.ConstructorValue") ...@@ -138,9 +142,9 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
TVM_REGISTER_NODE_TYPE(ConstructorValueNode); TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node, .set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const ConstructorValueNode*>(ref.get());
p->stream << "ConstructorValueNode(" << node->tag << "," p->stream << "ConstructorValueNode(" << node->tag << ","
<< node->fields << ")"; << node->fields << ")";
}); });
......
...@@ -37,9 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode); ...@@ -37,9 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_API("relay._make.PatternWildcard") TVM_REGISTER_API("relay._make.PatternWildcard")
.set_body_typed(PatternWildcardNode::make); .set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node, .set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) {
p->stream << "PatternWildcardNode()"; p->stream << "PatternWildcardNode()";
}); });
...@@ -54,9 +53,9 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode); ...@@ -54,9 +53,9 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_API("relay._make.PatternVar") TVM_REGISTER_API("relay._make.PatternVar")
.set_body_typed(PatternVarNode::make); .set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node, .set_dispatch<PatternVarNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const PatternVarNode*>(ref.get());
p->stream << "PatternVarNode(" << node->var << ")"; p->stream << "PatternVarNode(" << node->var << ")";
}); });
...@@ -73,9 +72,9 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode); ...@@ -73,9 +72,9 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_API("relay._make.PatternConstructor") TVM_REGISTER_API("relay._make.PatternConstructor")
.set_body_typed(PatternConstructorNode::make); .set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node, .set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const PatternConstructorNode*>(ref.get());
p->stream << "PatternConstructorNode(" << node->constructor p->stream << "PatternConstructorNode(" << node->constructor
<< ", " << node->patterns << ")"; << ", " << node->patterns << ")";
}); });
...@@ -91,9 +90,9 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode); ...@@ -91,9 +90,9 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_API("relay._make.PatternTuple") TVM_REGISTER_API("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make); .set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node, .set_dispatch<PatternTupleNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const PatternTupleNode*>(ref.get());
p->stream << "PatternTupleNode(" << node->patterns << ")"; p->stream << "PatternTupleNode(" << node->patterns << ")";
}); });
...@@ -112,9 +111,9 @@ TVM_REGISTER_NODE_TYPE(ConstructorNode); ...@@ -112,9 +111,9 @@ TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_API("relay._make.Constructor") TVM_REGISTER_API("relay._make.Constructor")
.set_body_typed(ConstructorNode::make); .set_body_typed(ConstructorNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node, .set_dispatch<ConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", " p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")"; << node->inputs << ", " << node->belong_to << ")";
}); });
...@@ -134,9 +133,9 @@ TVM_REGISTER_NODE_TYPE(TypeDataNode); ...@@ -134,9 +133,9 @@ TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_API("relay._make.TypeData") TVM_REGISTER_API("relay._make.TypeData")
.set_body_typed(TypeDataNode::make); .set_body_typed(TypeDataNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node, .set_dispatch<TypeDataNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")"; << node->constructors << ")";
}); });
...@@ -153,9 +152,9 @@ TVM_REGISTER_NODE_TYPE(ClauseNode); ...@@ -153,9 +152,9 @@ TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_API("relay._make.Clause") TVM_REGISTER_API("relay._make.Clause")
.set_body_typed(ClauseNode::make); .set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node, .set_dispatch<ClauseNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const ClauseNode*>(ref.get());
p->stream << "ClauseNode(" << node->lhs << ", " p->stream << "ClauseNode(" << node->lhs << ", "
<< node->rhs << ")"; << node->rhs << ")";
}); });
...@@ -173,9 +172,9 @@ TVM_REGISTER_NODE_TYPE(MatchNode); ...@@ -173,9 +172,9 @@ TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_API("relay._make.Match") TVM_REGISTER_API("relay._make.Match")
.set_body_typed(MatchNode::make); .set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node, .set_dispatch<MatchNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const MatchNode*>(ref.get());
p->stream << "MatchNode(" << node->data << ", " p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ", " << node->complete << ")"; << node->clauses << ", " << node->complete << ")";
}); });
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file base.cc * \file base.cc
* \brief The core base types for Relay. * \brief The core base types for Relay.
*/ */
...@@ -31,7 +30,7 @@ namespace relay { ...@@ -31,7 +30,7 @@ namespace relay {
using tvm::IRPrinter; using tvm::IRPrinter;
using namespace tvm::runtime; using namespace tvm::runtime;
NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) { ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
// always return pointer as the reference can change as map re-allocate. // always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr // or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map; static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map;
...@@ -54,8 +53,9 @@ SourceName SourceName::Get(const std::string& name) { ...@@ -54,8 +53,9 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_API("relay._make.SourceName") TVM_REGISTER_API("relay._make.SourceName")
.set_body_typed(SourceName::Get); .set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) { .set_dispatch<SourceNameNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")"; p->stream << "SourceName(" << node->name << ", " << node << ")";
}); });
...@@ -78,8 +78,9 @@ TVM_REGISTER_NODE_TYPE(SpanNode); ...@@ -78,8 +78,9 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span") TVM_REGISTER_API("relay._make.Span")
.set_body_typed(SpanNode::make); .set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) { .set_dispatch<SpanNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")"; << node->col_offset << ")";
}); });
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/ir/expr.cc * \file src/tvm/ir/expr.cc
* \brief The expression AST nodes of Relay. * \brief The expression AST nodes of Relay.
*/ */
...@@ -41,8 +40,9 @@ TVM_REGISTER_NODE_TYPE(ConstantNode); ...@@ -41,8 +40,9 @@ TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant") TVM_REGISTER_API("relay._make.Constant")
.set_body_typed(ConstantNode::make); .set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) { .set_dispatch<ConstantNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const ConstantNode*>(ref.get());
const PackedFunc* fprint = Registry::Get("relay._constant_repr"); const PackedFunc* fprint = Registry::Get("relay._constant_repr");
CHECK(fprint) << "unable to find printing function for constants"; CHECK(fprint) << "unable to find printing function for constants";
std::string data = (*fprint)(GetRef<Constant>(node)); std::string data = (*fprint)(GetRef<Constant>(node));
...@@ -73,8 +73,9 @@ TVM_REGISTER_NODE_TYPE(TupleNode); ...@@ -73,8 +73,9 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple") TVM_REGISTER_API("relay._make.Tuple")
.set_body_typed(TupleNode::make); .set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) { .set_dispatch<TupleNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TupleNode*>(ref.get());
p->stream << "Tuple(" << node->fields << ")"; p->stream << "Tuple(" << node->fields << ")";
}); });
...@@ -97,8 +98,9 @@ TVM_REGISTER_NODE_TYPE(VarNode); ...@@ -97,8 +98,9 @@ TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make)); .set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) { .set_dispatch<VarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const VarNode*>(ref.get());
p->stream << "Var(" << node->name_hint(); p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) { if (node->type_annotation.defined()) {
p->stream << ", ty="; p->stream << ", ty=";
...@@ -118,8 +120,9 @@ TVM_REGISTER_NODE_TYPE(GlobalVarNode); ...@@ -118,8 +120,9 @@ TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar") TVM_REGISTER_API("relay._make.GlobalVar")
.set_body_typed(GlobalVarNode::make); .set_body_typed(GlobalVarNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) { .set_dispatch<GlobalVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")"; p->stream << "GlobalVar(" << node->name_hint << ")";
}); });
...@@ -217,9 +220,9 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); ...@@ -217,9 +220,9 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function") TVM_REGISTER_API("relay._make.Function")
.set_body_typed(FunctionNode::make); .set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node, .set_dispatch<FunctionNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", " << ", " << node->body << ", " << node->type_params << ", "
<< node->attrs << ")"; << node->attrs << ")";
...@@ -240,11 +243,12 @@ TVM_REGISTER_NODE_TYPE(CallNode); ...@@ -240,11 +243,12 @@ TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call") TVM_REGISTER_API("relay._make.Call")
.set_body_typed(CallNode::make); .set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) { .set_dispatch<CallNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const CallNode*>(ref.get());
p->stream << "CallNode(" << node->op << ", " << node->args << ", " p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")"; << node->attrs << ", " << node->type_args << ")";
}); });
Let LetNode::make(Var var, Expr value, Expr body) { Let LetNode::make(Var var, Expr value, Expr body) {
NodePtr<LetNode> n = make_node<LetNode>(); NodePtr<LetNode> n = make_node<LetNode>();
...@@ -259,8 +263,9 @@ TVM_REGISTER_NODE_TYPE(LetNode); ...@@ -259,8 +263,9 @@ TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let") TVM_REGISTER_API("relay._make.Let")
.set_body_typed(LetNode::make); .set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) { .set_dispatch<LetNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const LetNode*>(ref.get());
p->stream << "LetNode(" << node->var << ", " << node->value p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")"; << ", " << node->body << ")";
}); });
...@@ -278,8 +283,9 @@ TVM_REGISTER_NODE_TYPE(IfNode); ...@@ -278,8 +283,9 @@ TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If") TVM_REGISTER_API("relay._make.If")
.set_body_typed(IfNode::make); .set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) { .set_dispatch<IfNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const IfNode*>(ref.get());
p->stream << "IfNode(" << node->cond << ", " << node->true_branch p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")"; << ", " << node->false_branch << ")";
}); });
...@@ -296,8 +302,9 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode); ...@@ -296,8 +302,9 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_API("relay._make.TupleGetItem") TVM_REGISTER_API("relay._make.TupleGetItem")
.set_body_typed(TupleGetItemNode::make); .set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) { .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TupleGetItemNode*>(ref.get());
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
}); });
...@@ -312,8 +319,9 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode); ...@@ -312,8 +319,9 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_API("relay._make.RefCreate") TVM_REGISTER_API("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make); .set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) { .set_dispatch<RefCreateNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const RefCreateNode*>(ref.get());
p->stream << "RefCreateNode(" << node->value << ")"; p->stream << "RefCreateNode(" << node->value << ")";
}); });
...@@ -328,8 +336,9 @@ TVM_REGISTER_NODE_TYPE(RefReadNode); ...@@ -328,8 +336,9 @@ TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_API("relay._make.RefRead") TVM_REGISTER_API("relay._make.RefRead")
.set_body_typed(RefReadNode::make); .set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) { .set_dispatch<RefReadNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const RefReadNode*>(ref.get());
p->stream << "RefReadNode(" << node->ref << ")"; p->stream << "RefReadNode(" << node->ref << ")";
}); });
...@@ -345,8 +354,9 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode); ...@@ -345,8 +354,9 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_API("relay._make.RefWrite") TVM_REGISTER_API("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make); .set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) { .set_dispatch<RefWriteNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const RefWriteNode*>(ref.get());
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
}); });
......
...@@ -414,9 +414,9 @@ TVM_REGISTER_API("relay._module.Module_ImportFromStd") ...@@ -414,9 +414,9 @@ TVM_REGISTER_API("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path); mod->ImportFromStd(path);
});; });;
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModuleNode>( .set_dispatch<ModuleNode>([](const ObjectRef& ref, IRPrinter* p) {
[](const ModuleNode *node, tvm::IRPrinter *p) { auto* node = static_cast<const ModuleNode*>(ref.get());
p->stream << "ModuleNode( " << node->functions << ")"; p->stream << "ModuleNode( " << node->functions << ")";
}); });
......
...@@ -199,8 +199,9 @@ TVM_REGISTER_NODE_TYPE(OpNode) ...@@ -199,8 +199,9 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name; return static_cast<const OpNode*>(n)->name;
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) { .set_dispatch<OpNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")"; p->stream << "Op(" << node->name << ")";
}); });
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/ir/type.cc * \file src/tvm/ir/type.cc
* \brief The type system AST nodes of Relay. * \brief The type system AST nodes of Relay.
*/ */
...@@ -58,9 +57,9 @@ TVM_REGISTER_NODE_TYPE(TensorTypeNode); ...@@ -58,9 +57,9 @@ TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType") TVM_REGISTER_API("relay._make.TensorType")
.set_body_typed(TensorTypeNode::make); .set_body_typed(TensorTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node, .set_dispatch<TensorTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const TensorTypeNode*>(ref.get());
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
}); });
...@@ -78,9 +77,9 @@ TVM_REGISTER_API("relay._make.TypeVar") ...@@ -78,9 +77,9 @@ TVM_REGISTER_API("relay._make.TypeVar")
return TypeVarNode::make(name, static_cast<Kind>(kind)); return TypeVarNode::make(name, static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode* node, .set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVarNode(" << node->var->name_hint << ", " p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")"; << node->kind << ")";
}); });
...@@ -99,9 +98,9 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar") ...@@ -99,9 +98,9 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar")
return GlobalTypeVarNode::make(name, static_cast<Kind>(kind)); return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node, .set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter *p) { auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", " p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")"; << node->kind << ")";
}); });
...@@ -118,9 +117,9 @@ TVM_REGISTER_NODE_TYPE(TypeCallNode); ...@@ -118,9 +117,9 @@ TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_API("relay._make.TypeCall") TVM_REGISTER_API("relay._make.TypeCall")
.set_body_typed(TypeCallNode::make); .set_body_typed(TypeCallNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* node, .set_dispatch<TypeCallNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", " p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")"; << node->args << ")";
}); });
...@@ -138,10 +137,9 @@ TVM_REGISTER_API("relay._make.IncompleteType") ...@@ -138,10 +137,9 @@ TVM_REGISTER_API("relay._make.IncompleteType")
return IncompleteTypeNode::make(static_cast<Kind>(kind)); return IncompleteTypeNode::make(static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>( .set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
[](const IncompleteTypeNode* node, auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
tvm::IRPrinter* p) {
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
}); });
...@@ -162,9 +160,9 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode); ...@@ -162,9 +160,9 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType") TVM_REGISTER_API("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make); .set_body_typed(FuncTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node, .set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncTypeNode(" << node->type_params << ", " p->stream << "FuncTypeNode(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", " << node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")"; << node->type_constraints << ")";
...@@ -187,8 +185,9 @@ TVM_REGISTER_NODE_TYPE(TypeRelationNode); ...@@ -187,8 +185,9 @@ TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation") TVM_REGISTER_API("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make); .set_body_typed(TypeRelationNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) { .set_dispatch<TypeRelationNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TypeRelationNode*>(ref.get());
p->stream << "TypeRelationNode(" p->stream << "TypeRelationNode("
<< node->func->name << node->func->name
<< ", " << node->args << ")"; << ", " << node->args << ")";
...@@ -205,9 +204,9 @@ TVM_REGISTER_NODE_TYPE(TupleTypeNode); ...@@ -205,9 +204,9 @@ TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType") TVM_REGISTER_API("relay._make.TupleType")
.set_body_typed(TupleTypeNode::make); .set_body_typed(TupleTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node, .set_dispatch<TupleTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")"; p->stream << "TupleTypeNode(" << node->fields << ")";
}); });
...@@ -222,9 +221,9 @@ TVM_REGISTER_API("relay._make.RefType") ...@@ -222,9 +221,9 @@ TVM_REGISTER_API("relay._make.RefType")
TVM_REGISTER_NODE_TYPE(RefTypeNode); TVM_REGISTER_NODE_TYPE(RefTypeNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefTypeNode>([](const RefTypeNode* node, .set_dispatch<RefTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const RefTypeNode*>(ref.get());
p->stream << "RefTypeNode(" << node->value << ")"; p->stream << "RefTypeNode(" << node->value << ")";
}); });
......
...@@ -18,14 +18,13 @@ ...@@ -18,14 +18,13 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file type_functor.h * \file type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types. * \brief A way to defined arbitrary function signature with dispatch on types.
*/ */
#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_ #ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_IR_TYPE_FUNCTOR_H_ #define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#include <tvm/node/ir_functor.h> #include <tvm/node/functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/adt.h> #include <tvm/relay/adt.h>
#include <string> #include <string>
...@@ -54,7 +53,7 @@ template <typename R, typename... Args> ...@@ -54,7 +53,7 @@ template <typename R, typename... Args>
class TypeFunctor<R(const Type& n, Args...)> { class TypeFunctor<R(const Type& n, Args...)> {
private: private:
using TSelf = TypeFunctor<R(const Type& n, Args...)>; using TSelf = TypeFunctor<R(const Type& n, Args...)>;
using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public: public:
/*! \brief the result type of this functor */ /*! \brief the result type of this functor */
......
...@@ -449,9 +449,9 @@ TVM_REGISTER_API("relay._transform.Info") ...@@ -449,9 +449,9 @@ TVM_REGISTER_API("relay._transform.Info")
*ret = pass->Info(); *ret = pass->Info();
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PassInfoNode>([](const PassInfoNode* node, .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const PassInfoNode*>(ref.get());
p->stream << "The meta data of the pass: "; p->stream << "The meta data of the pass: ";
p->stream << "pass name: " << node->name; p->stream << "pass name: " << node->name;
p->stream << "opt_level: " << node->opt_level; p->stream << "opt_level: " << node->opt_level;
...@@ -475,9 +475,9 @@ TVM_REGISTER_API("relay._transform.RunPass") ...@@ -475,9 +475,9 @@ TVM_REGISTER_API("relay._transform.RunPass")
*ret = pass(mod); *ret = pass(mod);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ModulePassNode* node, .set_dispatch<ModulePassNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const ModulePassNode*>(ref.get());
const PassInfo info = node->Info(); const PassInfo info = node->Info();
p->stream << "Run Module pass: " << info->name p->stream << "Run Module pass: " << info->name
<< " at the optimization level " << info->opt_level; << " at the optimization level " << info->opt_level;
...@@ -488,9 +488,9 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode); ...@@ -488,9 +488,9 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._transform.MakeFunctionPass") TVM_REGISTER_API("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make); .set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node, .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const FunctionPassNode*>(ref.get());
const PassInfo info = node->Info(); const PassInfo info = node->Info();
p->stream << "Run Function pass: " << info->name p->stream << "Run Function pass: " << info->name
<< " at the optimization level " << info->opt_level; << " at the optimization level " << info->opt_level;
...@@ -508,9 +508,9 @@ TVM_REGISTER_API("relay._transform.Sequential") ...@@ -508,9 +508,9 @@ TVM_REGISTER_API("relay._transform.Sequential")
*ret = Sequential(passes, pass_info); *ret = Sequential(passes, pass_info);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SequentialNode>([](const SequentialNode* node, .set_dispatch<SequentialNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const SequentialNode*>(ref.get());
const PassInfo info = node->Info(); const PassInfo info = node->Info();
p->stream << "Run Sequential pass: " << info->name p->stream << "Run Sequential pass: " << info->name
<< " at the optimization level " << info->opt_level << ". "; << " at the optimization level " << info->opt_level << ". ";
...@@ -538,9 +538,9 @@ TVM_REGISTER_API("relay._transform.PassContext") ...@@ -538,9 +538,9 @@ TVM_REGISTER_API("relay._transform.PassContext")
*ret = pctx; *ret = pctx;
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* node, .set_dispatch<PassContextNode>([](const ObjectRef& ref, IRPrinter* p) {
tvm::IRPrinter* p) { auto* node = static_cast<const PassContextNode*>(ref.get());
p->stream << "Pass context information: " << "\n"; p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " p->stream << "\tfallback device: "
......
...@@ -117,7 +117,8 @@ QConfig& QConfig::Current() { ...@@ -117,7 +117,8 @@ QConfig& QConfig::Current() {
TVM_REGISTER_NODE_TYPE(QConfigNode); TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<QConfigNode>([](const QConfigNode *op, IRPrinter *p) { .set_dispatch<QConfigNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* op = static_cast<const QConfigNode*>(ref.get());
p->stream << "qconfig("; p->stream << "qconfig(";
p->stream << "nbit_input=" << op->nbit_input << ", "; p->stream << "nbit_input=" << op->nbit_input << ", ";
p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_weight=" << op->nbit_weight << ", ";
......
...@@ -800,17 +800,20 @@ TVM_REGISTER_NODE_TYPE(ScheduleNode); ...@@ -800,17 +800,20 @@ TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer // Printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) { .set_dispatch<StageNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const StageNode*>(node.get());
if (op->op.defined()) { if (op->op.defined()) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
} else { } else {
p->stream << "group-stage(" << op << ")"; p->stream << "group-stage(" << op << ")";
} }
}) })
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) { .set_dispatch<IterVarAttrNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const IterVarAttrNode*>(node.get());
p->stream << IterVarType2String(op->iter_type); p->stream << IterVarType2String(op->iter_type);
}) })
.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) { .set_dispatch<SplitNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const SplitNode*>(node.get());
p->stream << "split(parent="; p->stream << "split(parent=";
p->Print(op->parent); p->Print(op->parent);
p->stream << ", outer="; p->stream << ", outer=";
...@@ -819,7 +822,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -819,7 +822,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->Print(op->inner); p->Print(op->inner);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) { .set_dispatch<FuseNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const FuseNode*>(node.get());
p->stream << "split("; p->stream << "split(";
p->stream << "outer="; p->stream << "outer=";
p->Print(op->outer); p->Print(op->outer);
...@@ -829,7 +833,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -829,7 +833,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->Print(op->fused); p->Print(op->fused);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) { .set_dispatch<RebaseNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const RebaseNode*>(node.get());
p->stream << "rebase("; p->stream << "rebase(";
p->stream << "parent="; p->stream << "parent=";
p->Print(op->parent); p->Print(op->parent);
...@@ -837,12 +842,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -837,12 +842,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->Print(op->rebased); p->Print(op->rebased);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) { .set_dispatch<SingletonNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const SingletonNode*>(node.get());
p->stream << "singleton("; p->stream << "singleton(";
p->Print(op->iter); p->Print(op->iter);
p->stream << ')'; p->stream << ')';
}) })
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) { .set_dispatch<ScheduleNode>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")"; p->stream << "schedule(" << op << ")";
}); });
} // namespace tvm } // namespace tvm
...@@ -53,7 +53,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> { ...@@ -53,7 +53,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
TEST(Attrs, Basic) { TEST(Attrs, Basic) {
using namespace tvm; using namespace tvm;
using namespace tvm::test; using namespace tvm::test;
std::shared_ptr<TestAttrs> n = std::make_shared<TestAttrs>(); ObjectPtr<TestAttrs> n = make_object<TestAttrs>();
try { try {
n->InitBySeq("axis", 10); n->InitBySeq("axis", 10);
LOG(FATAL) << "bad"; LOG(FATAL) << "bad";
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/node/ir_functor.h> #include <tvm/node/functor.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
TEST(IRF, Basic) { TEST(IRF, Basic) {
...@@ -30,12 +30,12 @@ TEST(IRF, Basic) { ...@@ -30,12 +30,12 @@ TEST(IRF, Basic) {
Var x("x"); Var x("x");
auto z = x + 1; auto z = x + 1;
IRFunctor<int(const ObjectRef& n, int b)> f; NodeFunctor<int(const ObjectRef& n, int b)> f;
LOG(INFO) << "x"; LOG(INFO) << "x";
f.set_dispatch<Variable>([](const Variable* n, int b) { f.set_dispatch<Variable>([](const ObjectRef& n, int b) {
return b; return b;
}); });
f.set_dispatch<Add>([](const Add* n, int b) { f.set_dispatch<Add>([](const ObjectRef& n, int b) {
return b + 2; return b + 2;
}); });
CHECK_EQ(f(x, 2), 2); CHECK_EQ(f(x, 2), 2);
......
...@@ -45,7 +45,7 @@ IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*) ...@@ -45,7 +45,7 @@ IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*)
} }
TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) { .set_dispatch<Variable>([](const ObjectRef& ref, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m); IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) { if (e.same_as(vm->var)) {
return Expr(IntImm::make(Int(32), vm->int_val)); return Expr(IntImm::make(Int(32), vm->int_val));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment