Unverified Commit 3cc49719 by Tianqi Chen Committed by GitHub

[RUNTIME][OBJECT] Introduce static slots for common objects. (#5423)

The _type_child_slots can be used to enable quick type checking optimization
by checking the whether the type index is within the bound.

This PR enables these static slots:

- Introduce a static assert to avoid the scenario when a developer forget to
  _type_child_slots when the field is set for the type's parent.
- Revamp and assign static type index to common runtime objects
- Add a DumpTypeTable call to allow developer monitor the current situation
  of type table and offers suggestions for the slots(ideally the slots equals
  the number of children so there is no overflow.
parent 96873076
...@@ -42,9 +42,10 @@ namespace tvm { ...@@ -42,9 +42,10 @@ namespace tvm {
*/ */
class BaseExprNode : public Object { class BaseExprNode : public Object {
public: public:
static constexpr const char* _type_key = "Expr"; static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 58;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
}; };
...@@ -88,6 +89,7 @@ class PrimExprNode : public BaseExprNode { ...@@ -88,6 +89,7 @@ class PrimExprNode : public BaseExprNode {
DataType dtype; DataType dtype;
static constexpr const char* _type_key = "PrimExpr"; static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 34;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
}; };
...@@ -161,7 +163,8 @@ class RelayExprNode : public BaseExprNode { ...@@ -161,7 +163,8 @@ class RelayExprNode : public BaseExprNode {
template<typename TTypeNode> template<typename TTypeNode>
inline const TTypeNode* type_as() const; inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr"; static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
}; };
......
...@@ -140,6 +140,7 @@ class BaseFuncNode : public RelayExprNode { ...@@ -140,6 +140,7 @@ class BaseFuncNode : public RelayExprNode {
} }
static constexpr const char* _type_key = "BaseFunc"; static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
}; };
......
...@@ -36,6 +36,7 @@ namespace tvm { ...@@ -36,6 +36,7 @@ namespace tvm {
class BaseTensorTypeNode : public TypeNode { class BaseTensorTypeNode : public TypeNode {
public: public:
static constexpr const char* _type_key = "relay.BaseTensorType"; static constexpr const char* _type_key = "relay.BaseTensorType";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode); TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
}; };
......
...@@ -81,6 +81,7 @@ class TypeNode : public Object { ...@@ -81,6 +81,7 @@ class TypeNode : public Object {
static constexpr const char* _type_key = "Type"; static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 14;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
}; };
...@@ -391,6 +392,7 @@ inline bool IsVoidType(const Type& type) { ...@@ -391,6 +392,7 @@ inline bool IsVoidType(const Type& type) {
class TypeConstraintNode : public TypeNode { class TypeConstraintNode : public TypeNode {
public: public:
static constexpr const char* _type_key = "TypeConstraint"; static constexpr const char* _type_key = "TypeConstraint";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
}; };
......
...@@ -630,6 +630,7 @@ class TempExprNode : public ExprNode { ...@@ -630,6 +630,7 @@ class TempExprNode : public ExprNode {
static constexpr const char* _type_key = "relay.TempExpr"; static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false;
static constexpr const uint32_t _type_child_slots = 0;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
}; };
......
...@@ -200,8 +200,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> { ...@@ -200,8 +200,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size; uint32_t size;
// The fields of the structure follows directly in memory. // The fields of the structure follows directly in memory.
static constexpr const uint32_t _type_index = TypeIndex::kVMADT; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
static constexpr const char* _type_key = "vm.ADT"; static constexpr const char* _type_key = "runtime.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
private: private:
...@@ -314,7 +314,7 @@ class StringObj : public Object { ...@@ -314,7 +314,7 @@ class StringObj : public Object {
/*! \brief The length of the string object. */ /*! \brief The length of the string object. */
uint64_t size; uint64_t size;
static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
static constexpr const char* _type_key = "runtime.String"; static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);
......
...@@ -288,10 +288,10 @@ class NDArray::Container : ...@@ -288,10 +288,10 @@ class NDArray::Container :
using Object::IncRef; using Object::IncRef;
// Information for object protocol. // Information for object protocol.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray;
static constexpr const uint32_t _type_child_slots = 0; static constexpr const uint32_t _type_child_slots = 0;
static constexpr const uint32_t _type_child_slots_can_overflow = true; static constexpr const uint32_t _type_child_slots_can_overflow = true;
static constexpr const char* _type_key = "NDArray"; static constexpr const char* _type_key = "runtime.NDArray";
TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object); TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);
protected: protected:
......
...@@ -46,17 +46,31 @@ ...@@ -46,17 +46,31 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*! \brief list of the type index. */ /*!
enum TypeIndex { * \brief Namespace for the list of type index.
/*! \brief Root object type. */ * \note Use struct so that we have to use TypeIndex::ENumName to refer to
kRoot = 0, * the constant, but still able to use enum.
kClosure = 1, */
kVMADT = 2, struct TypeIndex {
kRuntimeModule = 3, enum {
kStaticIndexEnd, /*! \brief Root object type. */
/*! \brief Type index is allocated during runtime. */ kRoot = 0,
kDynamic = kStaticIndexEnd // Standard static index assignments,
}; // Frontends can take benefit of these constants.
/*! \brief runtime::Module. */
kRuntimeModule = 1,
/*! \brief runtime::NDArray. */
kRuntimeNDArray = 2,
/*! \brief runtime::String. */
kRuntimeString = 3,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
}; // namespace TypeIndex
/*! /*!
* \brief base class of all object containers. * \brief base class of all object containers.
...@@ -198,7 +212,7 @@ class Object { ...@@ -198,7 +212,7 @@ class Object {
using RefCounterType = int32_t; using RefCounterType = int32_t;
#endif #endif
static constexpr const char* _type_key = "Object"; static constexpr const char* _type_key = "runtime.Object";
static uint32_t _GetOrAllocRuntimeTypeIndex() { static uint32_t _GetOrAllocRuntimeTypeIndex() {
return TypeIndex::kRoot; return TypeIndex::kRoot;
...@@ -675,6 +689,10 @@ struct ObjectEqual { ...@@ -675,6 +689,10 @@ struct ObjectEqual {
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static uint32_t RuntimeTypeIndex() { \ static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || \
ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \ return TypeName::_type_index; \
} \ } \
...@@ -690,6 +708,7 @@ struct ObjectEqual { ...@@ -690,6 +708,7 @@ struct ObjectEqual {
return tidx; \ return tidx; \
} \ } \
/*! /*!
* \brief helper macro to declare type information in a final class. * \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type. * \param TypeName The name of the current type.
......
...@@ -1268,6 +1268,8 @@ struct unpack_call_dispatcher<void, 0, index, F> { ...@@ -1268,6 +1268,8 @@ struct unpack_call_dispatcher<void, 0, index, F> {
template<typename R, int nargs, typename F> template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size())
<< "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv); unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
} }
......
...@@ -44,8 +44,8 @@ namespace vm { ...@@ -44,8 +44,8 @@ namespace vm {
*/ */
class ClosureObj : public Object { class ClosureObj : public Object {
public: public:
static constexpr const uint32_t _type_index = TypeIndex::kClosure; static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
static constexpr const char* _type_key = "Closure"; static constexpr const char* _type_key = "runtime.Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
}; };
......
...@@ -40,6 +40,7 @@ class StmtNode : public Object { ...@@ -40,6 +40,7 @@ class StmtNode : public Object {
static constexpr const char* _type_key = "Stmt"; static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
}; };
......
...@@ -78,6 +78,7 @@ class VarNode : public PrimExprNode { ...@@ -78,6 +78,7 @@ class VarNode : public PrimExprNode {
} }
static constexpr const char* _type_key = "tir.Var"; static constexpr const char* _type_key = "tir.Var";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
}; };
......
...@@ -121,7 +121,7 @@ class QConfig(Object): ...@@ -121,7 +121,7 @@ class QConfig(Object):
return self return self
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self) _quantize._ExitQConfigScope()
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in QConfig._node_defaults: if name in QConfig._node_defaults:
......
...@@ -61,7 +61,7 @@ def getitem_helper(obj, elem_getter, length, idx): ...@@ -61,7 +61,7 @@ def getitem_helper(obj, elem_getter, length, idx):
return elem_getter(obj, idx) return elem_getter(obj, idx)
@tvm._ffi.register_object("vm.ADT") @tvm._ffi.register_object("runtime.ADT")
class ADT(Object): class ADT(Object):
"""Algebatic data type(ADT) object. """Algebatic data type(ADT) object.
......
...@@ -36,7 +36,7 @@ except (RuntimeError, ImportError): ...@@ -36,7 +36,7 @@ except (RuntimeError, ImportError):
from tvm._ffi._ctypes.ndarray import NDArrayBase from tvm._ffi._ctypes.ndarray import NDArrayBase
@tvm._ffi.register_object @tvm._ffi.register_object("runtime.NDArray")
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
...@@ -57,6 +57,7 @@ class CanonicalExprNode : public PrimExprNode { ...@@ -57,6 +57,7 @@ class CanonicalExprNode : public PrimExprNode {
} }
static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const char* _type_key = "arith.CanonicalExpr";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
}; };
......
...@@ -188,5 +188,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") ...@@ -188,5 +188,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, std::string name, std::string fmt) { .set_body_typed([](Module mod, std::string name, std::string fmt) {
mod->SaveToFile(name, fmt); mod->SaveToFile(name, fmt);
}); });
TVM_REGISTER_OBJECT_TYPE(ModuleNode);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -86,7 +86,8 @@ class TypeContext { ...@@ -86,7 +86,8 @@ class TypeContext {
return it->second; return it->second;
} }
// try to allocate from parent's type table. // try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size()); CHECK_LT(parent_tindex, type_table_.size())
<< " skey= " << skey << "static_index=" << static_tindex;
TypeInfo& pinfo = type_table_[parent_tindex]; TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex); CHECK_EQ(pinfo.index, parent_tindex);
...@@ -108,7 +109,7 @@ class TypeContext { ...@@ -108,7 +109,7 @@ class TypeContext {
<< " between " << type_table_[allocated_tindex].name << " between " << type_table_[allocated_tindex].name
<< " and " << " and "
<< skey; << skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) { } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
// allocate the slot from parent's reserved pool // allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots; allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state // update parent's state
...@@ -119,8 +120,8 @@ class TypeContext { ...@@ -119,8 +120,8 @@ class TypeContext {
// allocate new entries. // allocate new entries.
allocated_tindex = type_counter_; allocated_tindex = type_counter_;
type_counter_ += num_slots; type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex); CHECK_LE(type_table_.size(), type_counter_);
type_table_.resize(allocated_tindex + 1, TypeInfo()); type_table_.resize(type_counter_, TypeInfo());
} }
CHECK_GT(allocated_tindex, parent_tindex); CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot. // initialize the slot.
...@@ -161,6 +162,25 @@ class TypeContext { ...@@ -161,6 +162,25 @@ class TypeContext {
return it->second; return it->second;
} }
void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
}
}
for (const auto& info : type_table_) {
if (info.index != 0 && num_children[info.index] >= min_children_count) {
std::cerr <<'[' << info.index << "] "<< info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
}
}
}
static TypeContext* Global() { static TypeContext* Global() {
static TypeContext inst; static TypeContext inst;
return &inst; return &inst;
...@@ -169,6 +189,7 @@ class TypeContext { ...@@ -169,6 +189,7 @@ class TypeContext {
private: private:
TypeContext() { TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo()); type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
type_table_[0].name = "runtime.Object";
} }
// mutex to avoid registration from multiple threads. // mutex to avoid registration from multiple threads.
std::mutex mutex_; std::mutex mutex_;
...@@ -208,6 +229,11 @@ TVM_REGISTER_GLOBAL("runtime.ObjectHash") ...@@ -208,6 +229,11 @@ TVM_REGISTER_GLOBAL("runtime.ObjectHash")
.set_body_typed([](ObjectRef obj) { .set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj)); return static_cast<int64_t>(ObjectHash()(obj));
}); });
TVM_REGISTER_GLOBAL("runtime.DumpTypeTable")
.set_body_typed([](int min_child_count) {
TypeContext::Global()->Dump(min_child_count);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -46,6 +46,7 @@ runtime::Module Build(IRModule mod, const Target& target) { ...@@ -46,6 +46,7 @@ runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) { if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod); mod = tir::transform::SkipAssert()(mod);
} }
std::string build_f_name = "target.build." + target->target_name; std::string build_f_name = "target.build." + target->target_name;
// the build function. // the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name); const PackedFunc* bf = runtime::Registry::Get(build_f_name);
......
...@@ -127,7 +127,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { ...@@ -127,7 +127,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
return ptx; return ptx;
} }
runtime::Module BuildCUDA(IRModule mod) { runtime::Module BuildCUDA(IRModule mod, std::string target) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenCUDA cg; CodeGenCUDA cg;
......
...@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO ...@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
} }
} }
runtime::Module BuildOpenCL(IRModule mod) { runtime::Module BuildOpenCL(IRModule mod, std::string target) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenOpenCL cg; CodeGenOpenCL cg;
......
...@@ -289,7 +289,7 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) { ...@@ -289,7 +289,7 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n"; this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
} }
runtime::Module BuildOpenGL(IRModule mod) { runtime::Module BuildOpenGL(IRModule mod, std::string target) {
bool output_ssa = false; bool output_ssa = false;
CodeGenOpenGL cg; CodeGenOpenGL cg;
cg.Init(output_ssa); cg.Init(output_ssa);
......
...@@ -70,7 +70,7 @@ class SPIRVTools { ...@@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_; spv_context ctx_;
}; };
runtime::Module BuildSPIRV(IRModule mod) { runtime::Module BuildSPIRV(IRModule mod, std::string target) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
using tvm::runtime::VulkanShader; using tvm::runtime::VulkanShader;
......
...@@ -517,7 +517,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) { ...@@ -517,7 +517,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->body); this->Push(op->body);
} }
runtime::Module BuildStackVM(const IRModule& mod) { runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) {
std::unordered_map<std::string, StackVM> fmap; std::unordered_map<std::string, StackVM> fmap;
std::string entry_func; std::string entry_func;
......
...@@ -155,9 +155,9 @@ TEST(BuildModule, Heterogeneous) { ...@@ -155,9 +155,9 @@ TEST(BuildModule, Heterogeneous) {
auto c_val = auto c_val =
runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pa = (float*)a_val.ToDLPack()->dl_tensor.data; auto pa = (float*)(a_val->data);
auto pb = (float*)b_val.ToDLPack()->dl_tensor.data; auto pb = (float*)(b_val->data);
auto pc = (float*)c_val.ToDLPack()->dl_tensor.data; auto pc = (float*)(c_val->data);
// Assign values. // Assign values.
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
...@@ -186,7 +186,7 @@ TEST(BuildModule, Heterogeneous) { ...@@ -186,7 +186,7 @@ TEST(BuildModule, Heterogeneous) {
run(); run();
tvm::runtime::NDArray out = get_output(0); tvm::runtime::NDArray out = get_output(0);
float* p_out = (float*)out.ToDLPack()->dl_tensor.data; float* p_out = (float*)out->data;
// Check correctness. // Check correctness.
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
......
...@@ -47,6 +47,7 @@ class ObjA : public ObjBase { ...@@ -47,6 +47,7 @@ class ObjA : public ObjBase {
class ObjB : public ObjBase { class ObjB : public ObjBase {
public: public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const char* _type_key = "test.ObjB"; static constexpr const char* _type_key = "test.ObjB";
TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase); TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase);
}; };
......
...@@ -87,9 +87,9 @@ TEST(Relay, BuildModule) { ...@@ -87,9 +87,9 @@ TEST(Relay, BuildModule) {
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pA = (float*)A.ToDLPack()->dl_tensor.data; auto pA = (float*)A->data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data; auto pB = (float*)B->data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data; auto pC = (float*)C->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
pA[i] = i; pA[i] = i;
...@@ -132,7 +132,7 @@ TEST(Relay, BuildModule) { ...@@ -132,7 +132,7 @@ TEST(Relay, BuildModule) {
set_input_f("c", &C.ToDLPack()->dl_tensor); set_input_f("c", &C.ToDLPack()->dl_tensor);
run_f(); run_f();
tvm::runtime::NDArray Y = get_output_f(0); tvm::runtime::NDArray Y = get_output_f(0);
auto pY = (float*)Y.ToDLPack()->dl_tensor.data; auto pY = (float*)Y->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
} }
...@@ -142,20 +142,20 @@ TEST(Relay, BuildModule) { ...@@ -142,20 +142,20 @@ TEST(Relay, BuildModule) {
} }
run_f(); run_f();
tvm::runtime::NDArray Y2 = get_output_f(0); tvm::runtime::NDArray Y2 = get_output_f(0);
auto pY2 = (float*)Y2.ToDLPack()->dl_tensor.data; auto pY2 = (float*)Y2->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4); CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4);
} }
// attach a different input and run it again // attach a different input and run it again
auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pC2 = (float*)C2.ToDLPack()->dl_tensor.data; auto pC2 = (float*)C2->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
pC2[i] = i + 4; pC2[i] = i + 4;
} }
set_input_f("c", &C2.ToDLPack()->dl_tensor); set_input_f("c", &C2.ToDLPack()->dl_tensor);
run_f(); run_f();
tvm::runtime::NDArray Y3 = get_output_f(0); tvm::runtime::NDArray Y3 = get_output_f(0);
auto pY3 = (float*)Y3.ToDLPack()->dl_tensor.data; auto pY3 = (float*)Y3->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4); CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4);
} }
......
...@@ -63,9 +63,9 @@ TEST(MicroStandaloneRuntime, BuildModule) { ...@@ -63,9 +63,9 @@ TEST(MicroStandaloneRuntime, BuildModule) {
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pA = (float*)A.ToDLPack()->dl_tensor.data; auto pA = (float*)A->data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data; auto pB = (float*)B->data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data; auto pC = (float*)C->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
pA[i] = i; pA[i] = i;
...@@ -118,7 +118,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { ...@@ -118,7 +118,7 @@ TEST(MicroStandaloneRuntime, BuildModule) {
UTVMRuntimeRun(handle); UTVMRuntimeRun(handle);
auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
UTVMRuntimeGetOutput(handle, 0, &Y.ToDLPack()->dl_tensor); UTVMRuntimeGetOutput(handle, 0, &Y.ToDLPack()->dl_tensor);
auto* pY = (float*)Y.ToDLPack()->dl_tensor.data; auto* pY = (float*)Y->data;
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
} }
......
...@@ -41,7 +41,7 @@ def test_scan(): ...@@ -41,7 +41,7 @@ def test_scan():
def test_fix_pt(): def test_fix_pt():
body = tvm.te.schedule.ScanGetBody(s_scan.op) body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body) fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.spatial_axis_[0]].value != 0) assert(fxpt[s_scan.spatial_axis_[0]].value != 0)
def test_scan_fix_point(): def test_scan_fix_point():
...@@ -57,7 +57,7 @@ def test_scan_fix_point(): ...@@ -57,7 +57,7 @@ def test_scan_fix_point():
lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update") lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update")
s_scan = tvm.te.scan(s_init, s_update, s_state) s_scan = tvm.te.scan(s_init, s_update, s_state)
body = tvm.te.schedule.ScanGetBody(s_scan.op) body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body) fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1) assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1)
...@@ -66,7 +66,7 @@ def test_scan_fix_point(): ...@@ -66,7 +66,7 @@ def test_scan_fix_point():
lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update") lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update")
s_scan = tvm.te.scan(s_init, s_update, s_state) s_scan = tvm.te.scan(s_init, s_update, s_state)
body = tvm.te.schedule.ScanGetBody(s_scan.op) body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body) fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment