Unverified commit 3cc49719 by Tianqi Chen, committed by GitHub

[RUNTIME][OBJECT] Introduce static slots for common objects. (#5423)

The _type_child_slots field can be used to enable a quick type-checking optimization
by checking whether the runtime type index falls within the reserved bound.

This PR enables these static slots:

- Introduce a static assert to catch the case where a developer forgets to set
  _type_child_slots when the field is set for the type's parent.
- Revamp and assign static type indices to common runtime objects.
- Add a DumpTypeTable call that lets developers monitor the current state of the
  type table and suggests values for the slots (ideally the number of slots equals
  the number of children, so there is no overflow).
parent 96873076
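As a rough illustration of the range check described in the message above, here is a simplified, hand-written sketch. It is not the actual TVM implementation; the constants and the helper name are hypothetical. When a parent type reserves a contiguous block of child indices and none of its children overflow, an IsInstance-style query becomes a single integer range test.

```cpp
// Simplified sketch of the optimization that _type_child_slots enables.
// Assumption: descendants of the parent are assigned indices in
// [kParentIndex + 1, kParentIndex + kParentChildSlots] and none overflow.
#include <cstdint>

constexpr uint32_t kParentIndex = 10;       // hypothetical static index of the parent
constexpr uint32_t kParentChildSlots = 34;  // e.g. PrimExprNode reserves 34 slots

// type_index is the integer stored in an object's header at runtime.
inline bool IsParentOrDerived(uint32_t type_index) {
  // One pair of comparisons; no hash lookup or walk up the type hierarchy.
  return type_index >= kParentIndex &&
         type_index <= kParentIndex + kParentChildSlots;
}
```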
......@@ -42,9 +42,10 @@ namespace tvm {
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 58;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
......@@ -88,6 +89,7 @@ class PrimExprNode : public BaseExprNode {
DataType dtype;
static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 34;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};
......@@ -161,7 +163,8 @@ class RelayExprNode : public BaseExprNode {
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};
......
......@@ -140,6 +140,7 @@ class BaseFuncNode : public RelayExprNode {
}
static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
......
......@@ -36,6 +36,7 @@ namespace tvm {
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};
......
......@@ -81,6 +81,7 @@ class TypeNode : public Object {
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 14;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
......@@ -391,6 +392,7 @@ inline bool IsVoidType(const Type& type) {
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "TypeConstraint";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
......
......@@ -630,6 +630,7 @@ class TempExprNode : public ExprNode {
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
static constexpr const uint32_t _type_child_slots = 0;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
......
......@@ -200,8 +200,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size;
// The fields of the structure follows directly in memory.
static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
static constexpr const char* _type_key = "runtime.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
private:
......@@ -314,7 +314,7 @@ class StringObj : public Object {
/*! \brief The length of the string object. */
uint64_t size;
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);
......
......@@ -288,10 +288,10 @@ class NDArray::Container :
using Object::IncRef;
// Information for object protocol.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const uint32_t _type_child_slots_can_overflow = true;
static constexpr const char* _type_key = "NDArray";
static constexpr const char* _type_key = "runtime.NDArray";
TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);
protected:
......
......@@ -46,17 +46,31 @@
namespace tvm {
namespace runtime {
/*! \brief list of the type index. */
enum TypeIndex {
/*!
* \brief Namespace for the list of type index.
* \note Use struct so that we have to use TypeIndex::ENumName to refer to
* the constant, but still able to use enum.
*/
struct TypeIndex {
enum {
/*! \brief Root object type. */
kRoot = 0,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
// Standard static index assignments,
// Frontends can take benefit of these constants.
/*! \brief runtime::Module. */
kRuntimeModule = 1,
/*! \brief runtime::NDArray. */
kRuntimeNDArray = 2,
/*! \brief runtime::String. */
kRuntimeString = 3,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
};
}; // namespace TypeIndex
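To show what the stable static indices buy a frontend, a small hedged sketch (the helper below is my own, not part of the PR): the most common runtime objects can now be identified with a single integer comparison instead of a type-key string lookup.

```cpp
// Hypothetical helper, not part of this PR: identify an NDArray container
// purely by its statically assigned type index.
#include <tvm/runtime/object.h>

inline bool IsRuntimeNDArray(const tvm::runtime::Object* obj) {
  return obj != nullptr &&
         obj->type_index() == tvm::runtime::TypeIndex::kRuntimeNDArray;
}
```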
/*!
* \brief base class of all object containers.
......@@ -198,7 +212,7 @@ class Object {
using RefCounterType = int32_t;
#endif
static constexpr const char* _type_key = "Object";
static constexpr const char* _type_key = "runtime.Object";
static uint32_t _GetOrAllocRuntimeTypeIndex() {
return TypeIndex::kRoot;
......@@ -675,6 +689,10 @@ struct ObjectEqual {
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || \
ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
......@@ -690,6 +708,7 @@ struct ObjectEqual {
return tidx; \
} \
/*!
* \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type.
......
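The rule enforced by the new static_assert can be illustrated with a minimal, hypothetical hierarchy (the demo types below are not in the codebase): once a parent reserves _type_child_slots, each child must declare its own budget, either a smaller reservation or zero, mirroring the test.ObjB change further down in this diff.

```cpp
// Hypothetical demo hierarchy showing the constraint checked by the
// static_assert added to TVM_DECLARE_BASE_OBJECT_INFO.
#include <tvm/runtime/object.h>

class DemoBaseNode : public tvm::runtime::Object {
 public:
  static constexpr const char* _type_key = "demo.Base";
  // Reserve two child indices right after this type's own index.
  static constexpr const uint32_t _type_child_slots = 2;
  TVM_DECLARE_BASE_OBJECT_INFO(DemoBaseNode, tvm::runtime::Object);
};

class DemoLeafNode : public DemoBaseNode {
 public:
  static constexpr const char* _type_key = "demo.Leaf";
  // Required: omitting this line makes DemoLeafNode inherit
  // _type_child_slots = 2 from its parent, and the static_assert fires.
  static constexpr const uint32_t _type_child_slots = 0;
  TVM_DECLARE_BASE_OBJECT_INFO(DemoLeafNode, DemoBaseNode);
};
```

As usual, these declarations would also need TVM_REGISTER_OBJECT_TYPE in a .cc file to be registered with the runtime.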
......@@ -1268,6 +1268,8 @@ struct unpack_call_dispatcher<void, 0, index, F> {
template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size())
<< "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}
......
......@@ -44,8 +44,8 @@ namespace vm {
*/
class ClosureObj : public Object {
public:
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
static constexpr const char* _type_key = "runtime.Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};
......
......@@ -40,6 +40,7 @@ class StmtNode : public Object {
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};
......
......@@ -78,6 +78,7 @@ class VarNode : public PrimExprNode {
}
static constexpr const char* _type_key = "tir.Var";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};
......
......@@ -121,7 +121,7 @@ class QConfig(Object):
return self
def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self)
_quantize._ExitQConfigScope()
def __setattr__(self, name, value):
if name in QConfig._node_defaults:
......
......@@ -61,7 +61,7 @@ def getitem_helper(obj, elem_getter, length, idx):
return elem_getter(obj, idx)
@tvm._ffi.register_object("vm.ADT")
@tvm._ffi.register_object("runtime.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
......
......@@ -36,7 +36,7 @@ except (RuntimeError, ImportError):
from tvm._ffi._ctypes.ndarray import NDArrayBase
@tvm._ffi.register_object
@tvm._ffi.register_object("runtime.NDArray")
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
......
......@@ -57,6 +57,7 @@ class CanonicalExprNode : public PrimExprNode {
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};
......
......@@ -188,5 +188,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, std::string name, std::string fmt) {
mod->SaveToFile(name, fmt);
});
TVM_REGISTER_OBJECT_TYPE(ModuleNode);
} // namespace runtime
} // namespace tvm
......@@ -86,7 +86,8 @@ class TypeContext {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
CHECK_LT(parent_tindex, type_table_.size())
<< " skey= " << skey << "static_index=" << static_tindex;
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);
......@@ -108,7 +109,7 @@ class TypeContext {
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
} else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
......@@ -119,8 +120,8 @@ class TypeContext {
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
CHECK_LE(type_table_.size(), type_counter_);
type_table_.resize(type_counter_, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
......@@ -161,6 +162,25 @@ class TypeContext {
return it->second;
}
void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
}
}
for (const auto& info : type_table_) {
if (info.index != 0 && num_children[info.index] >= min_children_count) {
std::cerr << '[' << info.index << "] " << info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
}
}
}
static TypeContext* Global() {
static TypeContext inst;
return &inst;
......@@ -169,6 +189,7 @@ class TypeContext {
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
type_table_[0].name = "runtime.Object";
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
......@@ -208,6 +229,11 @@ TVM_REGISTER_GLOBAL("runtime.ObjectHash")
.set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj));
});
TVM_REGISTER_GLOBAL("runtime.DumpTypeTable")
.set_body_typed([](int min_child_count) {
TypeContext::Global()->Dump(min_child_count);
});
} // namespace runtime
} // namespace tvm
......
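For completeness, a hedged usage sketch (my own example, not part of the PR): the global registered above can be fetched through the PackedFunc registry and invoked from C++ the same way the Python frontend would call it.

```cpp
// Look up the "runtime.DumpTypeTable" global and dump all types that currently
// have at least two registered children, so the counts can be compared against
// each type's _type_child_slots reservation.
#include <tvm/runtime/registry.h>

void DumpCrowdedTypes() {
  const tvm::runtime::PackedFunc* dump =
      tvm::runtime::Registry::Get("runtime.DumpTypeTable");
  if (dump != nullptr) {
    (*dump)(2);  // min_child_count = 2
  }
}
```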
......@@ -46,6 +46,7 @@ runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod);
}
std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
......
......@@ -127,7 +127,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
return ptx;
}
runtime::Module BuildCUDA(IRModule mod) {
runtime::Module BuildCUDA(IRModule mod, std::string target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
......
......@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
}
}
runtime::Module BuildOpenCL(IRModule mod) {
runtime::Module BuildOpenCL(IRModule mod, std::string target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
......
......@@ -289,7 +289,7 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}
runtime::Module BuildOpenGL(IRModule mod) {
runtime::Module BuildOpenGL(IRModule mod, std::string target) {
bool output_ssa = false;
CodeGenOpenGL cg;
cg.Init(output_ssa);
......
......@@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_;
};
runtime::Module BuildSPIRV(IRModule mod) {
runtime::Module BuildSPIRV(IRModule mod, std::string target) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;
......
......@@ -517,7 +517,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->body);
}
runtime::Module BuildStackVM(const IRModule& mod) {
runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) {
std::unordered_map<std::string, StackVM> fmap;
std::string entry_func;
......
......@@ -155,9 +155,9 @@ TEST(BuildModule, Heterogeneous) {
auto c_val =
runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pa = (float*)a_val.ToDLPack()->dl_tensor.data;
auto pb = (float*)b_val.ToDLPack()->dl_tensor.data;
auto pc = (float*)c_val.ToDLPack()->dl_tensor.data;
auto pa = (float*)(a_val->data);
auto pb = (float*)(b_val->data);
auto pc = (float*)(c_val->data);
// Assign values.
for (int i = 0; i < n; i++) {
......@@ -186,7 +186,7 @@ TEST(BuildModule, Heterogeneous) {
run();
tvm::runtime::NDArray out = get_output(0);
float* p_out = (float*)out.ToDLPack()->dl_tensor.data;
float* p_out = (float*)out->data;
// Check correctness.
for (int i = 0; i < n; ++i) {
......
......@@ -47,6 +47,7 @@ class ObjA : public ObjBase {
class ObjB : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const char* _type_key = "test.ObjB";
TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase);
};
......
......@@ -87,9 +87,9 @@ TEST(Relay, BuildModule) {
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pA = (float*)A.ToDLPack()->dl_tensor.data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data;
auto pA = (float*)A->data;
auto pB = (float*)B->data;
auto pC = (float*)C->data;
for (int i = 0; i < 6; ++i) {
pA[i] = i;
......@@ -132,7 +132,7 @@ TEST(Relay, BuildModule) {
set_input_f("c", &C.ToDLPack()->dl_tensor);
run_f();
tvm::runtime::NDArray Y = get_output_f(0);
auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
auto pY = (float*)Y->data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
}
......@@ -142,20 +142,20 @@ TEST(Relay, BuildModule) {
}
run_f();
tvm::runtime::NDArray Y2 = get_output_f(0);
auto pY2 = (float*)Y2.ToDLPack()->dl_tensor.data;
auto pY2 = (float*)Y2->data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4);
}
// attach a different input and run it again
auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pC2 = (float*)C2.ToDLPack()->dl_tensor.data;
auto pC2 = (float*)C2->data;
for (int i = 0; i < 6; ++i) {
pC2[i] = i + 4;
}
set_input_f("c", &C2.ToDLPack()->dl_tensor);
run_f();
tvm::runtime::NDArray Y3 = get_output_f(0);
auto pY3 = (float*)Y3.ToDLPack()->dl_tensor.data;
auto pY3 = (float*)Y3->data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4);
}
......
......@@ -63,9 +63,9 @@ TEST(MicroStandaloneRuntime, BuildModule) {
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto pA = (float*)A.ToDLPack()->dl_tensor.data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data;
auto pA = (float*)A->data;
auto pB = (float*)B->data;
auto pC = (float*)C->data;
for (int i = 0; i < 6; ++i) {
pA[i] = i;
......@@ -118,7 +118,7 @@ TEST(MicroStandaloneRuntime, BuildModule) {
UTVMRuntimeRun(handle);
auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
UTVMRuntimeGetOutput(handle, 0, &Y.ToDLPack()->dl_tensor);
auto* pY = (float*)Y.ToDLPack()->dl_tensor.data;
auto* pY = (float*)Y->data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
}
......
......@@ -41,7 +41,7 @@ def test_scan():
def test_fix_pt():
body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.spatial_axis_[0]].value != 0)
def test_scan_fix_point():
......@@ -57,7 +57,7 @@ def test_scan_fix_point():
lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update")
s_scan = tvm.te.scan(s_init, s_update, s_state)
body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1)
......@@ -66,7 +66,7 @@ def test_scan_fix_point():
lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update")
s_scan = tvm.te.scan(s_init, s_update, s_state)
body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
......