Unverified Commit 78ca6fc8 by Tianqi Chen Committed by GitHub

[NODE][REFACTOR] Refactor reflection system in node. (#4189)

* [NODE][REFACTOR] Refactor reflection system in node.

- Removed the old Node, Node is now just an alias of runtime::Object
- Introduce ReflectionVTable, a new columnar dispatcher to support reflection
  - This allows us to remove vtable from most node objects
  - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE,
    they are no longer virtual.
- Consolidated serialization and reflection features into node.

* Explicit type qualification when calling destructor.

* Fix SPIRV, more comments
parent 324a9607
......@@ -58,7 +58,7 @@ class EnvFuncNode : public Node {
/*! \brief constructor */
EnvFuncNode() {}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
}
......
......@@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node {
int64_t min_value;
int64_t max_value;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}
......@@ -162,7 +162,7 @@ class ModularSetNode : public Node {
/*! \brief The base */
int64_t base;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coeff", &coeff);
v->Visit("base", &base);
}
......@@ -351,7 +351,7 @@ enum SignType {
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
};
/*!
......
......@@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node {
/*! \brief detailed description of the type */
std::string description;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
......@@ -197,7 +197,7 @@ class AttrsHash {
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
......@@ -221,6 +221,8 @@ class BaseAttrsNode : public Node {
public:
using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue;
// visit function
virtual void VisitAttrs(AttrVisitor* v) {}
/*!
* \brief Initialize the attributes by sequence of arguments
* \param args The postional arguments in the form
......@@ -753,12 +755,12 @@ class AttrNonDefaultVisitor {
template<typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void VisitNonDefaultAttrs(AttrVisitor* v) final {
void VisitNonDefaultAttrs(AttrVisitor* v) {
::tvm::detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
......
......@@ -19,89 +19,16 @@
/*!
* \file tvm/base.h
* \brief Defines the base data structure
* \brief Base utilities
*/
#ifndef TVM_BASE_H_
#define TVM_BASE_H_
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node/node.h>
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "runtime/registry.h"
namespace tvm {
using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
......@@ -146,100 +73,6 @@ class With {
ContextType ctx_;
};
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);
/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
ObjectPtr<Object> LoadJSON_(std::string json_str);
/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}
/*!
* \brief Registry entry for NodeFactory.
*
* There are two types of Nodes that can be serialized.
* The normal node requires a registration a creator function that
* constructs an empty Node of the corresponding type.
*
* The global singleton(e.g. global operator) where only global_key need to be serialized,
* in this case, FGlobalKey need to be defined.
*/
struct NodeFactoryReg {
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey
* \return The created function.
*/
using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/
using FGlobalKey = std::function<std::string(const Node* node)>;
/*! \brief registered name */
std::string name;
/*!
* \brief The creator function
*/
FCreate fcreator = nullptr;
/*!
* \brief The global key function.
*/
FGlobalKey fglobal_key = nullptr;
// setter of creator
NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*)
this->fcreator = f;
return *this;
}
// setter of creator
NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*)
this->fglobal_key = f;
return *this;
}
// global registry singleton
TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
};
/*!
* \brief Register a Node type
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
#define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
......
......@@ -135,7 +135,7 @@ class BufferNode : public Node {
/*! \brief constructor */
BufferNode() {}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
......
......@@ -61,7 +61,7 @@ class TargetNode : public Node {
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
......@@ -229,7 +229,7 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
......@@ -473,6 +473,8 @@ class GenericFuncNode : public Node {
/* \brief map from keys to registered functions */
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
};
......
......@@ -54,7 +54,7 @@ struct ChannelNode : public Node {
/*! \brief default data type in read/write */
Type dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}
......
......@@ -104,7 +104,7 @@ class LayoutNode : public Node {
*/
Array<IterVar> axes;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("axes", &axes);
}
......@@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node {
/*! \brief The destination layout */
Layout dst_layout;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
......
......@@ -27,8 +27,10 @@
#include <string>
#include <algorithm>
#include <unordered_map>
#include <iostream>
#include "base.h"
#include "dtype.h"
#include "node/node.h"
#include "node/container.h"
#include "node/ir_functor.h"
#include "runtime/c_runtime_api.h"
......@@ -110,7 +112,7 @@ class Variable : public ExprNode {
static Var make(DataType dtype, std::string name_hint);
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("name", &name_hint);
}
......@@ -164,7 +166,7 @@ class IntImm : public ExprNode {
/*! \brief the Internal value. */
int64_t value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
......@@ -230,7 +232,7 @@ class RangeNode : public Node {
RangeNode() {}
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min);
v->Visit("extent", &extent);
}
......@@ -406,7 +408,7 @@ class IterVarNode : public Node {
*/
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
......@@ -490,7 +492,7 @@ class IRPrinter {
};
// default print function for all nodes
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
......
......@@ -45,7 +45,7 @@ class UIntImm : public ExprNode {
/*! \brief The constant value content. */
uint64_t value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
......@@ -62,7 +62,7 @@ class FloatImm : public ExprNode {
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
......@@ -79,7 +79,7 @@ class StringImm : public ExprNode {
/*! \brief The constant value content. */
std::string value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
......@@ -99,7 +99,7 @@ class Cast : public ExprNode {
/*! \brief Original data type. */
Expr value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
......@@ -122,7 +122,7 @@ class BinaryOpNode : public ExprNode {
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->type));
v->Visit("a", &a);
v->Visit("b", &b);
......@@ -214,7 +214,7 @@ class CmpOpNode : public ExprNode {
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->type));
v->Visit("a", &a);
v->Visit("b", &b);
......@@ -278,7 +278,7 @@ class And : public ExprNode {
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->type));
v->Visit("a", &a);
v->Visit("b", &b);
......@@ -298,7 +298,7 @@ class Or : public ExprNode {
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("a", &a);
v->Visit("b", &b);
......@@ -316,7 +316,7 @@ class Not : public ExprNode {
/*! \brief The input operand. */
Expr a;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("a", &a);
}
......@@ -343,7 +343,7 @@ class Select : public ExprNode {
/*! \brief value to be returned when condition is false. */
Expr false_value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("condition", &condition);
v->Visit("true_value", &true_value);
......@@ -380,7 +380,7 @@ class Load : public ExprNode {
/*! \brief The predicate to mask which lanes would be loaded. */
Expr predicate;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("buffer_var", &buffer_var);
v->Visit("index", &index);
......@@ -411,7 +411,7 @@ class Ramp : public ExprNode {
/*! \brief Total number of lanes. */
int lanes;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("base", &base);
v->Visit("stride", &stride);
......@@ -432,7 +432,7 @@ class Broadcast : public ExprNode {
/*! \brief The numerb of lanes. */
int lanes;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
v->Visit("lanes", &lanes);
......@@ -456,7 +456,7 @@ class Let : public ExprNode {
/*! \brief The result expression. */
Expr body;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("var", &var);
v->Visit("value", &value);
......@@ -522,7 +522,7 @@ class Call : public ExprNode {
/*! \brief The output value index if func's value is a tuple. */
int value_index{0};
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("name", &name);
v->Visit("args", &args);
......@@ -592,7 +592,7 @@ class Shuffle : public ExprNode {
/*! \brief The indices of each element. */
Array<Expr> indices;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("vectors", &vectors);
v->Visit("indices", &indices);
}
......@@ -652,7 +652,7 @@ class CommReducerNode : public Node {
Array<Expr> result,
Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
v->Visit("result", &result);
......@@ -694,7 +694,7 @@ class Reduce : public ExprNode {
Expr condition,
int value_index);
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("combiner", &combiner);
v->Visit("source", &source);
......@@ -710,7 +710,7 @@ class Reduce : public ExprNode {
/*! \brief Any shape. */
class Any : public ExprNode {
public:
void VisitAttrs(AttrVisitor* v) final {}
void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */
Var ToVar() const {
return Variable::make(Int(32), "any_dim");
......@@ -735,7 +735,7 @@ class LetStmt : public StmtNode {
/*! \brief The body block. */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
......@@ -768,7 +768,7 @@ class AttrStmt : public StmtNode {
/*! \brief The body statement to be executed */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("node", &node);
v->Visit("attr_key", &attr_key);
v->Visit("value", &value);
......@@ -799,7 +799,7 @@ class AssertStmt : public StmtNode {
*/
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("message", &message);
v->Visit("body", &body);
......@@ -822,7 +822,7 @@ class ProducerConsumer : public StmtNode {
/*! \brief Body to be executed. */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("is_producer", &is_producer);
v->Visit("body", &body);
......@@ -863,7 +863,7 @@ class Store : public StmtNode {
/*! \brief The predicate to mask which lanes would be stored. */
Expr predicate;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("value", &value);
v->Visit("index", &index);
......@@ -893,7 +893,7 @@ class Provide : public StmtNode {
/*! \brief The index arguments of the function. */
Array<Expr> args;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("value", &value);
......@@ -929,7 +929,7 @@ class Allocate : public StmtNode {
Expr new_expr;
std::string free_function;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &type);
v->Visit("extents", &extents);
......@@ -972,7 +972,7 @@ class Free : public StmtNode {
/*! \brief The buffer variable. */
Var buffer_var;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
}
......@@ -1001,7 +1001,7 @@ class Realize : public StmtNode {
/*! \brief The body of realization. */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("dtype", &type);
......@@ -1031,7 +1031,7 @@ class Block : public StmtNode {
/*! \brief The restof statments. */
Stmt rest;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("first", &first);
v->Visit("rest", &rest);
}
......@@ -1055,7 +1055,7 @@ class IfThenElse : public StmtNode {
/*! \brief The branch to be executed when condition is false, can be null. */
Stmt else_case;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("then_case", &then_case);
v->Visit("else_case", &else_case);
......@@ -1078,7 +1078,7 @@ class Evaluate : public StmtNode {
/*! \brief The expression to be evaluated. */
Expr value;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
}
......@@ -1142,7 +1142,7 @@ class For : public StmtNode {
DeviceAPI device_api,
Stmt body);
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
v->Visit("extent", &extent);
......@@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode {
/*! \brief Bounds to be prefetched. */
Region bounds;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("type", &type);
......
......@@ -119,7 +119,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
......
......@@ -40,8 +40,7 @@ class ArrayNode : public Node {
/*! \brief the data content */
std::vector<ObjectRef> data;
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to array have no effect.
void VisitAttrs(AttrVisitor* visitor) {
}
static constexpr const char* _type_key = "Array";
......@@ -51,9 +50,9 @@ class ArrayNode : public Node {
/*! \brief map node content */
class MapNode : public Node {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
ObjectRef,
......@@ -71,12 +70,12 @@ class MapNode : public Node {
/*! \brief specialized map node with string as key */
class StrMapNode : public Node {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief the data content */
ContainerType data;
......
......@@ -18,113 +18,68 @@
*/
/*!
* \file tvm/node/node.h
* \brief Node system data structure.
* \brief Definitions and helper macros for IR/AST nodes.
*
* The node folder contains base utilities for IR/AST nodes,
* invariant of which specific language dialect.
*
* We implement AST/IR nodes as sub-classes of runtime::Object.
* The base class Node is just an alias of runtime::Object.
*
* Besides the runtime type checking provided by Object,
* node folder contains additional functionalities such as
* reflection and serialization, which are important features
* for building a compiler infra.
*/
#ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/node/reflection.h>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
namespace tvm {
// forward declaration
class DataType;
class Node;
class NodeRef;
/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual ~AttrVisitor() = default;
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
using runtime::TypeIndex;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
using runtime::GetRef;
using runtime::Downcast;
using runtime::ObjectHash;
using runtime::ObjectEqual;
using runtime::make_object;
/*! \brief Reuse the type index in he runtime. */
using TypeIndex = runtime::TypeIndex;
using NodeHash = ObjectHash;
using NodeEqual = ObjectEqual;
using Node = Object;
/*!
* \brief base class of node container in DSL AST.
* \brief Base class of all references to AST/IR nodes.
*/
class Node : public runtime::Object {
class NodeRef : public ObjectRef {
public:
/*! \brief virtual destructor */
virtual ~Node() {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
static constexpr const char* _type_key = "Node";
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object);
NodeRef() {}
explicit NodeRef(ObjectPtr<Object> n) : ObjectRef(n) {}
};
/*!
* \brief Base class of all node reference object
* NodeRef is just a alias of ObjectRef.
*/
class NodeRef : public runtime::ObjectRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Node;
/*! \return the internal node pointer */
const Node* get() const {
return static_cast<const Node*>(ObjectRef::get());
}
/*! \return the internal node pointer */
const Node* operator->() const {
return get();
}
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
* \note This function is an alias of make_object.
*/
template<typename T>
const T *as_derived() const {
return as<T>();
}
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(runtime::ObjectPtr<runtime::Object> ptr) : ObjectRef(ptr) {}
};
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
return runtime::make_object<T>(std::forward<Args>(args)...);
}
/*!
* \brief helper macro to declare type information in a base node.
......@@ -139,27 +94,67 @@ class NodeRef : public runtime::ObjectRef {
TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent);
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
using runtime::GetRef;
using runtime::Downcast;
using runtime::make_object;
using runtime::ObjectHash;
using runtime::ObjectEqual;
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
using NodeHash = ObjectHash;
using NodeEqual = ObjectEqual;
/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
return runtime::make_object<T>(std::forward<Args>(args)...);
}
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
} // namespace tvm
#endif // TVM_NODE_NODE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/reflection.h
* \brief Reflection and serialization of compiler IR/AST nodes.
*/
#ifndef TVM_NODE_REFLECTION_H_
#define TVM_NODE_REFLECTION_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
#include <string>
namespace tvm {
// forward declaration
class DataType;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
/*!
* \brief Visitor class for to get the attributesof a AST/IR node.
* The content is going to be called for each field.
*
* Each objects that wants reflection will need to implement
* a VisitAttrs function and call visitor->Visit on each of its field.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual ~AttrVisitor() = default;
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief Virtual function table to support IR/AST node reflection.
*
* Functions are stored in columar manner.
* Each column is a vector indexed by Object's type_index.
*/
class ReflectionVTable {
public:
/*!
* \brief Visitor function.
* \note We use function pointer, instead of std::function
* to reduce the dispatch overhead as field visit
* does not need as much customization.
*/
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
* \return The created function.
*/
using FCreate = std::function<ObjectPtr<Object>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/
using FGlobalKey = std::function<std::string(const Object* self)>;
/*!
* \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object.
* \param visitor The attribute visitor.
*/
inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
/*!
* \brief Get global key of the object, if any.
* \param self The pointer to the object.
* \return the global key if object has one, otherwise return empty string.
*/
inline std::string GetGlobalKey(Object* self) const;
/*!
* \brief Create an initial object using default constructor
* by type_key and global key.
*
* \param type_key The type key of the object.
* \param global_key A global key that can be used to uniquely identify the object if any.
*/
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
const std::string& global_key = "") const;
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
* \param attr_name The name of the field.
* \return The corresponding attribute value.
* \note This function will throw an exception if the object does not contain the field.
*/
TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const;
/*!
* \brief List all the fields in the object.
* \return All the fields.
*/
TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
/*! \return The global singleton. */
TVM_DLL static ReflectionVTable* Global();
class Registry;
template<typename T>
inline Registry Register();
private:
/*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
std::vector<FGlobalKey> fglobal_key_;
};
/*! \brief Registry of a reflection table. */
class ReflectionVTable::Registry {
public:
Registry(ReflectionVTable* parent, uint32_t type_index)
: parent_(parent), type_index_(type_index) { }
/*!
* \brief Set fcreate function.
* \param f The creator function.
* \return rference to self.
*/
Registry& set_creator(FCreate f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fcreate_.size());
parent_->fcreate_[type_index_] = f;
return *this;
}
/*!
* \brief Set global_key function.
* \param f The creator function.
* \return rference to self.
*/
Registry& set_global_key(FGlobalKey f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fglobal_key_.size());
parent_->fglobal_key_[type_index_] = f;
return *this;
}
private:
ReflectionVTable* parent_;
uint32_t type_index_;
};
/*!
* \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type.
* \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \
__make_Node ## _ ## TypeName ## __ = \
::tvm::ReflectionVTable::Global()->Register<TypeName>() \
.set_creator([](const std::string&) { \
return ::tvm::runtime::make_object<TypeName>(); \
})
// Implementation details
template<typename T>
inline ReflectionVTable::Registry
ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex();
if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
struct Functor {
static void VisitAttrs(Object* self, AttrVisitor* v) {
static_cast<T*>(self)->VisitAttrs(v);
}
};
fvisit_attrs_[tindex] = Functor::VisitAttrs;
return Registry(this, tindex);
}
inline void ReflectionVTable::
VisitAttrs(Object* self, AttrVisitor* visitor) const {
uint32_t tindex = self->type_index();
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
fvisit_attrs_[tindex](self, visitor);
}
inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
uint32_t tindex = self->type_index();
if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
return fglobal_key_[tindex](self);
} else {
return std::string();
}
}
} // namespace tvm
#endif // TVM_NODE_REFLECTION_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Utility functions for serialization.
* \file tvm/node/serialization.h
*/
#ifndef TVM_NODE_SERIALIZATION_H_
#define TVM_NODE_SERIALIZATION_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <string>
namespace tvm {
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node);
/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
TVM_DLL runtime::ObjectRef LoadJSON(std::string json_str);
} // namespace tvm
#endif // TVM_NODE_SERIALIZATION_H_
......@@ -188,7 +188,7 @@ class PlaceholderOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
......@@ -259,7 +259,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
......@@ -312,7 +312,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
......@@ -394,7 +394,7 @@ class ScanOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
......@@ -461,7 +461,7 @@ class ExternOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
......@@ -529,7 +529,7 @@ class HybridOpNode : public OperationNode {
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
......
......@@ -20,7 +20,7 @@
/*!
* \file tvm/packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass NodeRef types into/from PackedFunc.
* This enales pass ObjectRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
......@@ -129,18 +129,18 @@ inline std::string ObjectTypeName() {
// extensions for tvm arg value
template<typename TNodeRef>
inline TNodeRef TVMArgValue::AsNodeRef() const {
template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TNodeRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TNodeRef>()
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TNodeRef(ObjectPtr<Node>(ptr));
return TObjectRef(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Expr() const {
......@@ -184,28 +184,28 @@ inline TVMArgValue::operator tvm::Integer() const {
return Integer(ObjectPtr<Node>(ptr));
}
template<typename TNodeRef, typename>
template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TNodeRef>::Check(ptr);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}
// extensions for TVMRetValue
template<typename TNodeRef>
inline TNodeRef TVMRetValue::AsNodeRef() const {
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TNodeRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TNodeRef>()
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TNodeRef(ObjectPtr<Object>(ptr));
return TObjectRef(ObjectPtr<Object>(ptr));
}
// type related stuffs
......
......@@ -66,7 +66,7 @@ class PatternWildcardNode : public PatternNode {
TVM_DLL static PatternWildcard make();
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}
......@@ -88,7 +88,7 @@ class PatternVarNode : public PatternNode {
TVM_DLL static PatternVar make(tvm::relay::Var var);
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("span", &span);
}
......@@ -122,7 +122,7 @@ class ConstructorNode : public ExprNode {
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
......@@ -151,7 +151,7 @@ class PatternConstructorNode : public PatternNode {
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
v->Visit("span", &span);
......@@ -175,7 +175,7 @@ class PatternTupleNode : public PatternNode {
TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}
......@@ -213,7 +213,7 @@ class TypeDataNode : public TypeNode {
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
......@@ -240,7 +240,7 @@ class ClauseNode : public Node {
/*! \brief The resulting value. */
Expr rhs;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
}
......@@ -269,7 +269,7 @@ class MatchNode : public ExprNode {
*/
bool complete;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
......
......@@ -107,7 +107,7 @@ class SourceNameNode : public Node {
/*! \brief The source name. */
std::string name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
......@@ -160,7 +160,7 @@ class SpanNode : public Node {
/*! \brief column offset */
int col_offset;
// override attr visitor
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("lineno", &lineno);
v->Visit("col_offset", &col_offset);
......@@ -204,7 +204,7 @@ class IdNode : public Node {
*/
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
}
......
......@@ -95,7 +95,7 @@ class ConstantNode : public ExprNode {
return data->ndim == 0;
}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -117,7 +117,7 @@ class TupleNode : public ExprNode {
/*! \brief the fields of the tuple */
tvm::Array<relay::Expr> fields;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -165,7 +165,7 @@ class VarNode : public ExprNode {
return vid->name_hint;
}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
......@@ -197,7 +197,7 @@ class GlobalVarNode : public ExprNode {
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -243,7 +243,7 @@ class FunctionNode : public ExprNode {
*/
tvm::Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
......@@ -327,7 +327,7 @@ class CallNode : public ExprNode {
*/
tvm::Array<Type> type_args;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
......@@ -369,7 +369,7 @@ class LetNode : public ExprNode {
/*! \brief The body of the let binding */
Expr body;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
......@@ -407,7 +407,7 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
Expr false_branch;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
......@@ -432,7 +432,7 @@ class TupleGetItemNode : public ExprNode {
/*! \brief which value to get */
int index;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("span", &span);
......@@ -454,7 +454,7 @@ class RefCreateNode : public ExprNode {
/*! \brief The initial value of the Reference. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -475,7 +475,7 @@ class RefReadNode : public ExprNode {
/*! \brief The Reference Expression. */
Expr ref;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -498,7 +498,7 @@ class RefWriteNode : public ExprNode {
/*! \brief The value to write into. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
......
......@@ -106,7 +106,7 @@ class ClosureNode : public ValueNode {
ClosureNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}
......@@ -154,7 +154,7 @@ struct TupleValueNode : ValueNode {
TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<Value> value);
......@@ -173,7 +173,7 @@ struct TensorValueNode : ValueNode {
TensorValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
......@@ -192,7 +192,7 @@ struct RefValueNode : ValueNode {
RefValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}
......@@ -215,7 +215,7 @@ struct ConstructorValueNode : ValueNode {
/*! \brief Optional field tracking ADT constructor. */
Constructor constructor;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tag", &tag);
v->Visit("fields", &fields);
v->Visit("constructor", &constructor);
......
......@@ -68,7 +68,7 @@ class ModuleNode : public RelayNode {
ModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_);
......
......@@ -24,6 +24,8 @@
#ifndef TVM_RELAY_OP_H_
#define TVM_RELAY_OP_H_
#include <dmlc/registry.h>
#include <functional>
#include <limits>
#include <string>
......@@ -82,7 +84,7 @@ class OpNode : public relay::ExprNode {
*/
int32_t support_level = 10;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
......
......@@ -101,7 +101,7 @@ class PassContextNode : public RelayNode {
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
......@@ -196,7 +196,7 @@ class PassInfoNode : public RelayNode {
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
......@@ -221,6 +221,7 @@ class Pass;
*/
class PassNode : public RelayNode {
public:
virtual ~PassNode() {}
/*!
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
......@@ -247,7 +248,7 @@ class PassNode : public RelayNode {
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
......
......@@ -96,7 +96,7 @@ class TensorTypeNode : public BaseTensorTypeNode {
/*! \brief The content data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
......@@ -159,7 +159,7 @@ class TypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("kind", &kind);
v->Visit("span", &span);
......@@ -188,7 +188,7 @@ class GlobalTypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("kind", &kind);
v->Visit("span", &span);
......@@ -216,7 +216,7 @@ class TypeCallNode : public TypeNode {
/*! \brief The arguments. */
tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
......@@ -245,7 +245,7 @@ class IncompleteTypeNode : public TypeNode {
public:
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("span", &span);
}
......@@ -297,7 +297,7 @@ class FuncTypeNode : public TypeNode {
*/
tvm::Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
......@@ -330,7 +330,7 @@ class TupleTypeNode : public TypeNode {
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
......@@ -357,7 +357,7 @@ class RefTypeNode : public TypeNode {
RefTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}
......@@ -417,7 +417,7 @@ class TypeReporterNode : public Node {
TVM_DLL virtual Module GetModule() = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
......@@ -488,7 +488,7 @@ class TypeRelationNode : public TypeConstraintNode {
/*! \brief Attributes to the relation function */
Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
......
......@@ -230,6 +230,7 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*)
os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
return os;
}
#endif
} // namespace runtime
} // namespace tvm
......
......@@ -82,6 +82,8 @@ class SimpleObjAllocator :
template<typename T>
class Handler {
public:
using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
template<typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) {
// NOTE: the first argument is not needed for SimpleObjAllocator
......@@ -91,7 +93,15 @@ class SimpleObjAllocator :
// In the case of an object pool, an allocator needs to create
// a special chunk memory that hides reference to the allocator
// and call allocator's release function in the deleter.
return new T(std::forward<Args>(args)...);
// NOTE2: Use inplace new to allocate
// This is used to get rid of warning when deleting a virtual
// class with non-virtual destructor.
// We are fine here as we captured the right deleter during construction.
// This is also the right way to get storage type for an object pool.
StorageType* data = new StorageType();
new (data) T(std::forward<Args>(args)...);
return reinterpret_cast<T*>(data);
}
static Object::FDeleter Deleter() {
......@@ -99,8 +109,17 @@ class SimpleObjAllocator :
}
private:
static void Deleter_(Object* ptr) {
delete static_cast<T*>(ptr);
static void Deleter_(Object* objptr) {
// NOTE: this is important to cast back to T*
// because objptr and tptr may not be the same
// depending on how sub-class allocates the space.
T* tptr = static_cast<T*>(objptr);
// It is important to do tptr->T::~T(),
// so that we explicitly call the specific destructor
// instead of tptr->~T(), which could mean the intention
// call a virtual destructor(which may not be available and is not required).
tptr->T::~T();
delete reinterpret_cast<StorageType*>(tptr);
}
};
};
......
......@@ -23,6 +23,7 @@
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <type_traits>
#include <string>
#include <utility>
......@@ -189,7 +190,7 @@ class Object {
* \param key The type key.
* \return the result.
*/
TVM_DLL static uint32_t TypeKey2Index(const char* key);
TVM_DLL static uint32_t TypeKey2Index(const std::string& key);
#if TVM_OBJECT_ATOMIC_REF_COUNTER
using RefCounterType = std::atomic<int32_t>;
......@@ -197,18 +198,24 @@ class Object {
using RefCounterType = int32_t;
#endif
// Object type properties
static constexpr const char* _type_key = "Object";
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
static uint32_t _GetOrAllocRuntimeTypeIndex() {
return 0;
return TypeIndex::kRoot;
}
static uint32_t RuntimeTypeIndex() {
return 0;
return TypeIndex::kRoot;
}
// Default object type properties for sub-classes
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
// NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
// Default constructor and copy constructor
Object() {}
// Override the copy and assign constructors to do nothing.
......@@ -262,13 +269,12 @@ class Object {
* \return The allocated type index.
*/
TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
const char* key,
const std::string& key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t type_child_slots,
bool type_child_slots_can_overflow);
private:
// reference counter related operations
/*! \brief developer function, increases reference counter. */
inline void IncRef();
......@@ -621,8 +627,8 @@ struct ObjectEqual {
*/
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static const uint32_t RuntimeTypeIndex() { \
if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return _type_index; \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
return _GetOrAllocRuntimeTypeIndex(); \
} \
......
......@@ -51,8 +51,6 @@ namespace tvm {
class Integer;
class DataType;
class Expr;
class Node;
class NodeRef;
namespace runtime {
......@@ -516,9 +514,9 @@ class TVMPODValue_ {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
template<typename TNodeRef,
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
std::is_class<TObjectRef>::value>::type>
inline bool IsObjectRef() const;
int type_code() const {
return type_code_;
......@@ -620,8 +618,8 @@ class TVMArgValue : public TVMPODValue_ {
return value_;
}
// Deferred extension handler.
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
......@@ -834,13 +832,13 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_;
}
// NodeRef related extenstions: in tvm/packed_func_ext.h
// ObjectRef related extenstions: in tvm/packed_func_ext.h
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// type related
inline operator tvm::DataType() const;
inline TVMRetValue& operator=(const tvm::DataType& other);
......@@ -1306,7 +1304,7 @@ template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast {
static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
return self->template AsNodeRef<T>();
return self->template AsObjectRef<T>();
}
};
......
......@@ -168,7 +168,7 @@ class Registry {
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Used when calling a method on a Node subclass through a ObjectRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
......@@ -191,15 +191,15 @@ class Registry {
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TObjectRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
......@@ -208,7 +208,7 @@ class Registry {
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Used when calling a method on a Node subclass through a ObjectRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
......@@ -231,15 +231,15 @@ class Registry {
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TObjectRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
......
......@@ -495,7 +495,7 @@ class StageNode : public Node {
/*! \brief Number of direct child stages, only used for group stage.*/
int num_child_stages{0};
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
......@@ -540,7 +540,7 @@ class ScheduleNode : public Node {
*/
std::unordered_map<const Node*, Stage> op2stage_cache_;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("groups", &groups);
......@@ -617,7 +617,7 @@ class IterVarAttrNode : public Node {
*/
Array<Expr> pragma_values;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_type", &iter_type);
v->Visit("bind_thread", &bind_thread);
v->Visit("prefetch_data", &prefetch_data);
......@@ -657,7 +657,7 @@ class SplitNode : public IterVarRelationNode {
/*! \brief Number of parts, only factor or nparts can be given */
Expr nparts;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
......@@ -687,7 +687,7 @@ class FuseNode : public IterVarRelationNode {
/*! \brief The target domain */
IterVar fused;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("fused", &fused);
......@@ -712,7 +712,7 @@ class RebaseNode : public IterVarRelationNode {
/*! \brief The inner domain */
IterVar rebased;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
v->Visit("rebased", &rebased);
}
......@@ -732,7 +732,7 @@ class SingletonNode : public IterVarRelationNode {
/*! \brief The singleton iterator */
IterVar iter;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("iter", &iter);
}
......
......@@ -47,7 +47,7 @@ struct MemoryInfoNode : public Node {
*/
Expr head_address;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("unit_bits", &unit_bits);
v->Visit("max_num_bits", &max_num_bits);
v->Visit("max_simd_bits", &max_simd_bits);
......
......@@ -171,7 +171,7 @@ class TensorNode : public Node {
/*! \brief constructor */
TensorNode() {}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
......
......@@ -87,7 +87,7 @@ class TensorIntrinNode : public Node {
/*! \brief constructor */
TensorIntrinNode() {}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op", &op);
v->Visit("inputs", &inputs);
......@@ -152,7 +152,7 @@ class TensorIntrinCallNode : public Node {
/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
......
......@@ -55,7 +55,7 @@ struct GraphFuncNode : public tvm::Node {
/*! \brief The lowered functions */
tvm::Array<tvm::LoweredFunc> funcs;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("target", &target);
v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs);
......@@ -78,7 +78,7 @@ struct GraphCacheEntryNode : public tvm::Node {
/*! \brief Index of the master node for calling schedule*/
int master_idx;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("graph_func", &graph_func);
v->Visit("use_count", &use_count);
v->Visit("master_idx", &master_idx);
......
......@@ -48,7 +48,7 @@ struct GraphKeyNode : public tvm::Node {
// The graph hash key is ensured always not to be 0
mutable size_t cache_hash_key_{0};
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("inputs", &inputs);
v->Visit("target", &target);
}
......
......@@ -18,11 +18,12 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc
* \brief Interface code with TVM graph runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <utility>
#include "graph_runtime.h"
......
......@@ -61,13 +61,13 @@ struct NDArrayWrapperNode : public ::tvm::Node {
std::string name;
tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("array", &array);
}
static constexpr const char* _type_key = "NDArrayWrapper";
TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node);
TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, tvm::Node);
};
TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode);
......
......@@ -22,6 +22,8 @@ There can be internal header files within each module that sit in src.
## Modules
- common: Internal common utilities.
- runtime: Minimum runtime related codes.
- node: base infra for IR/AST nodes that is dialect independent.
- api: API function registration.
- lang: The definition of DSL related data structure.
- arithmetic: Arithmetic expression and set simplification.
......@@ -29,7 +31,6 @@ There can be internal header files within each module that sit in src.
- schedule: The operations on the schedule graph before converting to IR.
- pass: The optimization pass on the IR structure.
- codegen: The code generator.
- runtime: Minimum runtime related codes.
- autotvm: The auto-tuning module.
- relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks.
- contrib: Contrib extension libraries.
......@@ -26,6 +26,7 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
#include <tvm/node/serialization.h>
namespace tvm {
TVM_REGISTER_API("_format_str")
......@@ -43,10 +44,10 @@ TVM_REGISTER_API("_raw_ptr")
});
TVM_REGISTER_API("_save_json")
.set_body_typed<std::string(NodeRef)>(SaveJSON);
.set_body_typed<std::string(ObjectRef)>(SaveJSON);
TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
.set_body_typed<ObjectRef(std::string)>(LoadJSON);
TVM_REGISTER_API("_TVMSetStream")
.set_body_typed(TVMSetStream);
......
......@@ -53,17 +53,17 @@ class VariablePathFinder: public IRVisitor {
if (!found_) path_.pop_back();
}
std::vector<const Node*> path_;
std::vector<const Object*> path_;
private:
bool found_{false};
Expr target_;
std::unordered_set<const Node*> visited_;
std::unordered_set<const Object*> visited_;
};
// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Expr target, Expr expr) {
std::vector<const Object*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
return v.path_;
......@@ -189,7 +189,7 @@ class BoundDeducer: public IRVisitor {
const std::unordered_map<const Variable*, IntSet>& hint_map_;
const std::unordered_map<const Variable*, IntSet>& relax_map_;
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
std::vector<const Object*> path_;
size_t iter_{0};
// internal analzyer
Analyzer analyzer_;
......
......@@ -43,6 +43,7 @@ class SplitExpr;
*/
class CanonicalExprNode : public BaseExprNode {
public:
virtual ~CanonicalExprNode() {}
/*!
* \brief Return the normal Expr that is equivalent to self.
* \note Can mutate the internal data structure.
......@@ -51,7 +52,7 @@ class CanonicalExprNode : public BaseExprNode {
virtual Expr Normalize() const = 0;
// overrides
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
......@@ -485,7 +486,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
* \return Normalized expr.
*/
Expr Normalize(Expr expr) {
if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
if (const auto* op = expr.as<CanonicalExprNode>()) {
return op->Normalize();
} else {
return expr;
......@@ -503,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
if (const auto* op = expr.as<SumExprNode>()) {
if (op->base == 0 && op->args.size() == 1) return op->args[0];
}
if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
if (const auto* op = expr.as<CanonicalExprNode>()) {
expr = op->Normalize();
}
NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
......
......@@ -807,6 +807,8 @@ IntSet EvalSet(Range r,
return EvalSet(r, ConvertDomMap(dom_map));
}
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
p->stream << "IntervalSet"
......
......@@ -47,7 +47,7 @@ class IntervalSetNode : public IntSetNode {
Expr max_value;
// visitor overload.
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}
......
......@@ -18,9 +18,9 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <GLSL.std.450.h>
......
......@@ -62,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc")
TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode)
.set_global_key([](const Node* n) {
.set_global_key([](const Object* n) {
return static_cast<const EnvFuncNode*>(n)->name;
});
......
......@@ -1150,6 +1150,8 @@ TVM_REGISTER_NODE_TYPE(Select);
TVM_REGISTER_NODE_TYPE(Load);
TVM_REGISTER_NODE_TYPE(Ramp);
TVM_REGISTER_NODE_TYPE(Broadcast);
TVM_REGISTER_NODE_TYPE(Shuffle);
TVM_REGISTER_NODE_TYPE(Prefetch);
TVM_REGISTER_NODE_TYPE(Call);
TVM_REGISTER_NODE_TYPE(Let);
TVM_REGISTER_NODE_TYPE(LetStmt);
......
......@@ -18,9 +18,9 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file target_info.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/target_info.h>
#include <tvm/packed_func_ext.h>
......
......@@ -18,22 +18,27 @@
*/
/*!
* Implementation of DSL API
* \file dsl_api.cc
* Reflection utilities.
* \file node/reflection.cc
*/
#include <dmlc/logging.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/attrs.h>
#include <tvm/expr.h>
#include <vector>
#include <string>
namespace tvm {
namespace runtime {
struct APIAttrGetter : public AttrVisitor {
std::string skey;
// Attr getter.
class AttrGetter : public AttrVisitor {
public:
const std::string& skey;
TVMRetValue* ret;
AttrGetter(const std::string &skey,
TVMRetValue* ret)
: skey(skey), ret(ret) {}
bool found_ref_object{false};
void Visit(const char* key, double* value) final {
......@@ -62,12 +67,7 @@ struct APIAttrGetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, NodeRef* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
void Visit(const char* key, runtime::NDArray* value) final {
if (skey == key) {
*ret = value[0];
......@@ -82,7 +82,39 @@ struct APIAttrGetter : public AttrVisitor {
}
};
struct APIAttrDir : public AttrVisitor {
runtime::TVMRetValue ReflectionVTable::GetAttr(
Object* self, const std::string& field_name) const {
runtime::TVMRetValue ret;
AttrGetter getter(field_name, &ret);
bool success;
if (getter.skey == "type_key") {
ret = self->GetTypeKey();
success = true;
} else if (!self->IsInstance<DictAttrsNode>()) {
VisitAttrs(self, &getter);
success = getter.found_ref_object || ret.type_code() != kNull;
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
auto it = dnode->dict.find(getter.skey);
if (it != dnode->dict.end()) {
success = true;
ret = (*it).second;
} else {
success = false;
}
}
if (!success) {
LOG(FATAL) << "AttributeError: " << self->GetTypeKey()
<< " object has no attributed " << getter.skey;
}
return ret;
}
// List names;
class AttrDir : public AttrVisitor {
public:
std::vector<std::string>* names;
void Visit(const char* key, double* value) final {
......@@ -109,9 +141,6 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, std::string* value) final {
names->push_back(key);
}
void Visit(const char* key, NodeRef* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
......@@ -120,54 +149,122 @@ struct APIAttrDir : public AttrVisitor {
}
};
struct NodeAPI {
static void GetAttr(TVMArgs args, TVMRetValue* ret) {
NodeRef ref = args[0];
Node* tnode = const_cast<Node*>(ref.get());
APIAttrGetter getter;
getter.skey = args[1].operator std::string();
getter.ret = ret;
std::vector<std::string>
ReflectionVTable::ListAttrNames(Object* self) const {
std::vector<std::string> names;
AttrDir dir;
dir.names = &names;
bool success;
if (getter.skey == "type_key") {
*ret = tnode->GetTypeKey();
success = true;
} else if (!tnode->IsInstance<DictAttrsNode>()) {
tnode->VisitAttrs(&getter);
success = getter.found_ref_object || ret->type_code() != kNull;
if (!self->IsInstance<DictAttrsNode>()) {
VisitAttrs(self, &dir);
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
auto it = dnode->dict.find(getter.skey);
if (it != dnode->dict.end()) {
success = true;
*ret = (*it).second;
} else {
success = false;
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
for (const auto& kv : dnode->dict) {
names.push_back(kv.first);
}
}
if (!success) {
LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey()
<< " object has no attributed " << getter.skey;
return names;
}
ReflectionVTable* ReflectionVTable::Global() {
static ReflectionVTable inst;
return &inst;
}
ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fcreate_[tindex](global_key);
}
class NodeAttrSetter : public AttrVisitor {
public:
std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
void Visit(const char* key, double* value) final {
*value = GetAttr(key).operator double();
}
void Visit(const char* key, int64_t* value) final {
*value = GetAttr(key).operator int64_t();
}
void Visit(const char* key, uint64_t* value) final {
*value = GetAttr(key).operator uint64_t();
}
void Visit(const char* key, int* value) final {
*value = GetAttr(key).operator int();
}
void Visit(const char* key, bool* value) final {
*value = GetAttr(key).operator bool();
}
void Visit(const char* key, std::string* value) final {
*value = GetAttr(key).operator std::string();
}
void Visit(const char* key, void** value) final {
*value = GetAttr(key).operator void*();
}
void Visit(const char* key, DataType* value) final {
*value = GetAttr(key).operator DataType();
}
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator ObjectRef();
}
static void ListAttrNames(TVMArgs args, TVMRetValue* ret) {
NodeRef ref = args[0];
Node* tnode = const_cast<Node*>(ref.get());
auto names = std::make_shared<std::vector<std::string> >();
APIAttrDir dir;
dir.names = names.get();
private:
runtime::TVMArgValue GetAttr(const char* key) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
runtime::TVMArgValue v = it->second;
attrs.erase(it);
return v;
}
};
if (!tnode->IsInstance<DictAttrsNode>()) {
tnode->VisitAttrs(&dir);
} else {
// specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
for (const auto& kv : dnode->dict) {
names->push_back(kv.first);
void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
NodeAttrSetter setter;
setter.type_key = n->GetTypeKey();
CHECK_EQ(args.size() % 2, 0);
for (int i = 0; i < args.size(); i += 2) {
setter.attrs.emplace(args[i].operator std::string(),
args[i + 1]);
}
auto* reflection = ReflectionVTable::Global();
reflection->VisitAttrs(n, &setter);
if (setter.attrs.size() != 0) {
std::ostringstream os;
os << setter.type_key << " does not contain field ";
for (const auto &kv : setter.attrs) {
os << " " << kv.first;
}
LOG(FATAL) << os.str();
}
}
// Expose to FFI APIs.
void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle);
*ret = ReflectionVTable::Global()->GetAttr(self, args[1]);
}
void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle);
auto names = std::make_shared<std::vector<std::string> >(
ReflectionVTable::Global()->ListAttrNames(self));
*ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
int64_t i = args[0];
......@@ -177,14 +274,33 @@ struct NodeAPI {
*rv = (*names)[i];
}
});
}
// API function to make node.
// args format:
// key1, value1, ..., key_n, value_n
void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0];
std::string empty_str;
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
auto* reflection = ReflectionVTable::Global();
ObjectPtr<Object> n = reflection->CreateInitObject(type_key);
if (n->IsInstance<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else {
InitNodeByPackedArgs(n.get(), kwargs);
}
};
*rv = ObjectRef(n);
}
TVM_REGISTER_GLOBAL("_NodeGetAttr")
.set_body(NodeAPI::GetAttr);
.set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("_NodeListAttrNames")
.set_body(NodeAPI::ListAttrNames);
.set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("make._Node")
.set_body(MakeNode);
} // namespace runtime
} // namespace tvm
......@@ -59,7 +59,7 @@ struct CachedFuncNode : public Node {
/*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("target", &target);
v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs);
......@@ -84,7 +84,7 @@ class CCacheKeyNode : public Node {
/*! \brief The hardware target.*/
Target target;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source_func", &source_func);
v->Visit("target", &target);
}
......@@ -141,7 +141,7 @@ class CCacheValueNode : public Node {
/*! \brief usage statistics */
int use_count{0};
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cached_func", &cached_func);
v->Visit("use_count", &use_count);
}
......@@ -191,7 +191,7 @@ class CompileEngineNode : public Node {
virtual void Clear() = 0;
// VisitAttrs
void VisitAttrs(AttrVisitor*) final {}
void VisitAttrs(AttrVisitor*) {}
static constexpr const char* _type_key = "relay.CompileEngine";
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
......@@ -116,6 +115,8 @@ RefValue RefValueNode::make(Value value) {
TVM_REGISTER_API("relay._make.RefValue")
.set_body_typed(RefValueNode::make);
TVM_REGISTER_NODE_TYPE(RefValueNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node,
tvm::IRPrinter* p) {
......@@ -135,6 +136,8 @@ ConstructorValue ConstructorValueNode::make(int32_t tag,
TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body_typed(ConstructorValueNode::make);
TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
tvm::IRPrinter* p) {
......@@ -207,7 +210,7 @@ class InterpreterStateNode : public Node {
/*! \brief The call stack of the interpreter. */
Stack stack;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("current_expr", &current_expr);
v->Visit("stack", &stack);
}
......
......@@ -18,19 +18,21 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file param_dict.cc
* \brief Implementation and registration of parameter dictionary
* serializing/deserializing functions.
*/
#include "param_dict.h"
#include <tvm/runtime/registry.h>
#include <dmlc/memory_io.h>
#include <string>
#include <vector>
#include <utility>
#include "param_dict.h"
namespace tvm {
namespace relay {
......
......@@ -45,7 +45,7 @@ struct NamedNDArrayNode : public ::tvm::Node {
std::string name;
tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("array", &array);
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/tvm/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs).
*/
......
......@@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_global_key([](const Node* n) {
.set_global_key([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
......@@ -88,7 +88,7 @@ TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
auto rn = node_ref.as_derived<RelayNode>();
auto rn = node_ref.as<RelayNode>();
CHECK(rn);
rn->span = sp;
});
......
......@@ -195,7 +195,7 @@ NodePtr<Node> CreateOp(const std::string& name) {
TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp)
.set_global_key([](const Node* n) {
.set_global_key([](const Object* n) {
return static_cast<const OpNode*>(n)->name;
});
......
......@@ -32,7 +32,7 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <dmlc/json.h>
#include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
......@@ -214,7 +214,7 @@ class PrettyPrinter :
}
Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) {
if (node.as<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}
......@@ -237,13 +237,13 @@ class PrettyPrinter :
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) {
if (node.as<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
} else if (node.as<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<PatternNode>()) {
} else if (node.as<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
} else if (node.as<ModuleNode>()) {
return PrintMod(Downcast<Module>(node));
} else {
Doc doc;
......@@ -924,14 +924,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, DataType* value) final {
PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
}
void Visit(const char* key, NodeRef* value) final {
PrintKV(key, parent_->PrintAttr(*value));
}
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::ObjectRef* obj) final {
LOG(FATAL) << "do not allow Object as argument";
PrintKV(key, parent_->PrintAttr(*obj));
}
private:
......
......@@ -132,7 +132,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
LOG(FATAL) << new_type_param << std::endl;
LOG(FATAL) << new_type_param;
}
}
......@@ -141,10 +141,10 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
auto new_type_cs = VisitType(type_cs);
changed = changed || !new_type_cs.same_as(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
new_type_cs.as<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
LOG(FATAL) << new_type_cs << std::endl;
LOG(FATAL) << new_type_cs;
}
}
......
......@@ -140,7 +140,7 @@ class LayoutAlternatedExprNode : public TempExprNode {
return tmp_memorizer.Transform(value, new_layout, old_layout);
}
void VisitAttrs(AttrVisitor *v) final {
void VisitAttrs(AttrVisitor *v) {
v->Visit("value", &value);
v->Visit("old_layout", &old_layout);
v->Visit("new_layout", &new_layout);
......
......@@ -18,8 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file deivce_annotation.cc
* \brief Passes to rewrite annotated program and retrieve the device allocation
* of expression.
......@@ -46,13 +44,15 @@ namespace relay {
namespace {
bool IsOnDeviceNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<OnDeviceAttrs>();
if (!node->IsInstance<CallNode>()) return false;
const auto* call_node = static_cast<const CallNode*>(node);
return call_node->attrs.as<OnDeviceAttrs>();
}
bool IsDeviceCopyNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<DeviceCopyAttrs>();
if (!node->IsInstance<CallNode>()) return false;
const auto* call_node = static_cast<const CallNode*>(node);
return call_node->attrs.as<DeviceCopyAttrs>();
}
} // namespace
......@@ -447,7 +447,8 @@ class DeviceInfo {
static const ExprNode* GetDeviceCopyNode(const ExprNode* node) {
if (IsDeviceCopyNode(node)) {
return node;
} else if (const auto* call_node = dynamic_cast<const CallNode*>(node)) {
} else if (node->IsInstance<CallNode>()) {
const auto* call_node = static_cast<const CallNode*>(node);
if (const auto* fn = call_node->op.as<FunctionNode>()) {
const ExprNode* body = fn->body.operator->();
if (IsDeviceCopyNode(body)) {
......@@ -472,7 +473,8 @@ class DeviceInfo {
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
CHECK(node->IsInstance<CallNode>());
last_copy_node = static_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
......
......@@ -37,14 +37,14 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
Type ret_type;
if (e->IsInstance<GlobalVarNode>()) {
auto gvar_node = e.as_derived<GlobalVarNode>();
auto gvar_node = e.as<GlobalVarNode>();
auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
original_params = func->params;
original_type_params = func->type_params;
ret_type = func->ret_type;
} else {
CHECK(e->IsInstance<FunctionNode>());
auto func = GetRef<Function>(e.as_derived<FunctionNode>());
auto func = GetRef<Function>(e.as<FunctionNode>());
original_params = func->params;
original_type_params = func->type_params;
ret_type = func->ret_type;
......
......@@ -176,7 +176,7 @@ class ScaledExprNode : public TempExprNode {
return value;
}
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("axes", &axes);
v->Visit("scale", &scale);
......@@ -664,7 +664,7 @@ class BackwardTransformerNode :
}
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);
......
......@@ -47,7 +47,7 @@ class TempRealizer : private ExprMutator {
return it->second;
} else {
Expr res;
if (const auto* temp = expr.as_derived<TempExprNode>()) {
if (const auto* temp = expr.as<TempExprNode>()) {
res = temp->Realize();
} else {
......
......@@ -102,7 +102,7 @@ class ModulePassNode : public PassNode {
ModulePassNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
}
......@@ -156,7 +156,7 @@ class FunctionPassNode : public PassNode {
FunctionPassNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
}
......@@ -211,7 +211,7 @@ class SequentialNode : public PassNode {
/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
}
......
......@@ -41,7 +41,7 @@ class QAnnotateExprNode : public TempExprNode {
Expr expr;
QAnnotateKind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("expr", &expr);
v->Visit("kind", &kind);
}
......
......@@ -42,7 +42,7 @@ class QPartitionExprNode : public TempExprNode {
/*! \brief The original expression */
Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("expr", &expr);
}
......
......@@ -18,8 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file quantize.cc
*
* \brief transform a graph to a low-bit graph
......
......@@ -76,7 +76,7 @@ class QConfigNode : public Node {
bool round_for_shift = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("nbit_input", &nbit_input);
v->Visit("nbit_weight", &nbit_weight);
v->Visit("nbit_activation", &nbit_activation);
......
......@@ -56,7 +56,7 @@ class QRealizeIntExprNode : public QRealizeExprNode {
Expr dom_scale;
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype);
......
......@@ -153,7 +153,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
// default: unify only if alpha-equal
Type VisitTypeDefault_(const Node* op, const Type& tn) final {
NodeRef nr = GetRef<NodeRef>(op);
Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) {
return Type(nullptr);
}
......@@ -411,7 +411,7 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
void VisitTypeDefault_(const Node* op) override {
NodeRef nr = GetRef<NodeRef>(op);
Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
UpdateRelSet(t);
}
......@@ -495,7 +495,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
void VisitTypeDefault_(const Node* op) override {
NodeRef nr = GetRef<NodeRef>(op);
Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
TransferLinks(t);
}
......
......@@ -280,7 +280,7 @@ TVM_REGISTER_API("relay._analysis.free_vars")
TVM_REGISTER_API("relay._analysis.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as_derived<ExprNode>()) {
if (x.as<ExprNode>()) {
*ret = BoundVars(Downcast<Expr>(x));
} else {
*ret = BoundVars(Downcast<Pattern>(x));
......@@ -294,7 +294,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) {
if (x.as<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x), mod);
} else {
*ret = FreeTypeVars(Downcast<Expr>(x), mod);
......@@ -305,7 +305,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) {
if (x.as<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x), mod);
} else {
*ret = BoundTypeVars(Downcast<Expr>(x), mod);
......@@ -316,7 +316,7 @@ TVM_REGISTER_API("relay._analysis.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) {
if (x.as<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x), mod);
} else {
*ret = AllTypeVars(Downcast<Expr>(x), mod);
......
......@@ -73,13 +73,12 @@ class TypeContext {
return child_tindex == parent_tindex;
}
uint32_t GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
std::lock_guard<std::mutex> lock(mutex_);
std::string skey = key;
auto it = type_key2index_.find(skey);
if (it != type_key2index_.end()) {
return it->second;
......@@ -106,7 +105,7 @@ class TypeContext {
<< "Conflicting static index " << static_tindex
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< key;
<< skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
......@@ -152,11 +151,10 @@ class TypeContext {
return type_table_[tindex].name_hash;
}
uint32_t TypeKey2Index(const char* key) {
std::string skey = key;
uint32_t TypeKey2Index(const std::string& skey) {
auto it = type_key2index_.find(skey);
CHECK(it != type_key2index_.end())
<< "Cannot find type " << key;
<< "Cannot find type " << skey;
return it->second;
}
......@@ -176,7 +174,7 @@ class TypeContext {
std::unordered_map<std::string, uint32_t> type_key2index_;
};
uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
......@@ -198,7 +196,7 @@ size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
return TypeContext::Global()->TypeIndex2KeyHash(tindex);
}
uint32_t Object::TypeKey2Index(const char* key) {
uint32_t Object::TypeKey2Index(const std::string& key) {
return TypeContext::Global()->TypeKey2Index(key);
}
......@@ -210,7 +208,7 @@ class TVMObjectCAPI {
}
}
static uint32_t TypeKey2Index(const char* type_key) {
static uint32_t TypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key);
}
};
......
......@@ -21,6 +21,7 @@
#include <gtest/gtest.h>
#include <topi/cuda/injective.h>
#include <tvm/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/build_module.h>
......
......@@ -20,6 +20,7 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment