Commit 32aad56c by Wei Chen Committed by Haichen Shen

[Refactor] Rename Datatype to ADT (#4156)

We think it will reduce the confusion with the meaning.

https://discuss.tvm.ai/t/discuss-consider-rename-vm-datatype/4339
parent 3c4b7cce
...@@ -121,7 +121,7 @@ AllocTensor ...@@ -121,7 +121,7 @@ AllocTensor
Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result
is saved to register `dst`. is saved to register `dst`.
AllocDatatype AllocADT
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
**Arguments**: **Arguments**:
:: ::
...@@ -176,7 +176,7 @@ GetTagi ...@@ -176,7 +176,7 @@ GetTagi
RegName object RegName object
RegName dst RegName dst
Get the object tag for Datatype object in register `object`. And saves the reult to register `dst`. Get the object tag for ADT object in register `object`. And saves the reult to register `dst`.
Fatal Fatal
^^^^^ ^^^^^
...@@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures. ...@@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures.
:: ::
VMObject VMTensor(const tvm::runtime::NDArray& data); Object Tensor(const tvm::runtime::NDArray& data);
VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields); Object ADT(size_t tag, const std::vector<Object>& fields);
VMObject VMClosure(size_t func_index, std::vector<VMObject> free_vars); Object Closure(size_t func_index, std::vector<Object> free_vars);
Stack and State Stack and State
......
...@@ -51,7 +51,7 @@ enum TypeIndex { ...@@ -51,7 +51,7 @@ enum TypeIndex {
kRoot = 0, kRoot = 0,
kVMTensor = 1, kVMTensor = 1,
kVMClosure = 2, kVMClosure = 2,
kVMDatatype = 3, kVMADT = 3,
kStaticIndexEnd, kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */ /*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd kDynamic = kStaticIndexEnd
......
...@@ -57,31 +57,31 @@ class Tensor : public ObjectRef { ...@@ -57,31 +57,31 @@ class Tensor : public ObjectRef {
/*! \brief An object representing a structure or enumeration. */ /*! \brief An object representing a structure or enumeration. */
class DatatypeObj : public Object { class ADTObj : public Object {
public: public:
/*! \brief The tag representing the constructor used. */ /*! \brief The tag representing the constructor used. */
size_t tag; size_t tag;
/*! \brief The fields of the structure. */ /*! \brief The fields of the structure. */
std::vector<ObjectRef> fields; std::vector<ObjectRef> fields;
static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype; static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.Datatype"; static constexpr const char* _type_key = "vm.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object); TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
}; };
/*! \brief reference to data type. */ /*! \brief reference to algebraic data type objects. */
class Datatype : public ObjectRef { class ADT : public ObjectRef {
public: public:
Datatype(size_t tag, std::vector<ObjectRef> fields); ADT(size_t tag, std::vector<ObjectRef> fields);
/*! /*!
* \brief construct a tuple object. * \brief construct a tuple object.
* \param fields The fields of the tuple. * \param fields The fields of the tuple.
* \return The constructed tuple type. * \return The constructed tuple type.
*/ */
static Datatype Tuple(std::vector<ObjectRef> fields); static ADT Tuple(std::vector<ObjectRef> fields);
TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj); TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
}; };
/*! \brief An object representing a closure. */ /*! \brief An object representing a closure. */
...@@ -129,7 +129,7 @@ enum class Opcode { ...@@ -129,7 +129,7 @@ enum class Opcode {
InvokePacked = 4U, InvokePacked = 4U,
AllocTensor = 5U, AllocTensor = 5U,
AllocTensorReg = 6U, AllocTensorReg = 6U,
AllocDatatype = 7U, AllocADT = 7U,
AllocClosure = 8U, AllocClosure = 8U,
GetField = 9U, GetField = 9U,
If = 10U, If = 10U,
...@@ -237,7 +237,7 @@ struct Instruction { ...@@ -237,7 +237,7 @@ struct Instruction {
/*! \brief The register to project from. */ /*! \brief The register to project from. */
RegName object; RegName object;
} get_tag; } get_tag;
struct /* AllocDatatype Operands */ { struct /* AllocADT Operands */ {
/*! \brief The datatype's constructor tag. */ /*! \brief The datatype's constructor tag. */
Index constructor_tag; Index constructor_tag;
/*! \brief The number of fields to store in the datatype. */ /*! \brief The number of fields to store in the datatype. */
...@@ -294,7 +294,7 @@ struct Instruction { ...@@ -294,7 +294,7 @@ struct Instruction {
* \param dst The register name of the destination. * \param dst The register name of the destination.
* \return The allocate instruction tensor. * \return The allocate instruction tensor.
*/ */
static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector<RegName>& fields, static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
RegName dst); RegName dst);
/*! \brief Construct an allocate closure instruction. /*! \brief Construct an allocate closure instruction.
* \param func_index The index of the function table. * \param func_index The index of the function table.
......
...@@ -31,7 +31,7 @@ from . import vmobj as _obj ...@@ -31,7 +31,7 @@ from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
Tensor = _obj.Tensor Tensor = _obj.Tensor
Datatype = _obj.Datatype ADT = _obj.ADT
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
......
...@@ -61,14 +61,14 @@ class Tensor(Object): ...@@ -61,14 +61,14 @@ class Tensor(Object):
return self.data.asnumpy() return self.data.asnumpy()
@register_object("vm.Datatype") @register_object("vm.ADT")
class Datatype(Object): class ADT(Object):
"""Datatype object. """Algebatic data type(ADT) object.
Parameters Parameters
---------- ----------
tag : int tag : int
The tag of datatype. The tag of ADT.
fields : list[Object] or tuple[Object] fields : list[Object] or tuple[Object]
The source tuple. The source tuple.
...@@ -77,22 +77,22 @@ class Datatype(Object): ...@@ -77,22 +77,22 @@ class Datatype(Object):
for f in fields: for f in fields:
assert isinstance(f, Object) assert isinstance(f, Object)
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_vmobj.Datatype, tag, *fields) _vmobj.ADT, tag, *fields)
@property @property
def tag(self): def tag(self):
return _vmobj.GetDatatypeTag(self) return _vmobj.GetADTTag(self)
def __getitem__(self, idx): def __getitem__(self, idx):
return getitem_helper( return getitem_helper(
self, _vmobj.GetDatatypeFields, len(self), idx) self, _vmobj.GetADTFields, len(self), idx)
def __len__(self): def __len__(self):
return _vmobj.GetDatatypeNumberOfFields(self) return _vmobj.GetADTNumberOfFields(self)
def tuple_object(fields): def tuple_object(fields):
"""Create a datatype object from source tuple. """Create a ADT object from source tuple.
Parameters Parameters
---------- ----------
...@@ -101,7 +101,7 @@ def tuple_object(fields): ...@@ -101,7 +101,7 @@ def tuple_object(fields):
Returns Returns
------- -------
ret : Datatype ret : ADT
The created object. The created object.
""" """
for f in fields: for f in fields:
......
...@@ -239,7 +239,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -239,7 +239,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
DLOG(INFO) << "VMCompiler::Emit: instr=" << instr; DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
switch (instr.op) { switch (instr.op) {
case Opcode::AllocDatatype: case Opcode::AllocADT:
case Opcode::AllocTensor: case Opcode::AllocTensor:
case Opcode::AllocTensorReg: case Opcode::AllocTensorReg:
case Opcode::GetField: case Opcode::GetField:
...@@ -287,7 +287,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -287,7 +287,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
} }
// TODO(@jroesch): use correct tag // TODO(@jroesch): use correct tag
Emit(Instruction::AllocDatatype( Emit(Instruction::AllocADT(
0, 0,
tuple->fields.size(), tuple->fields.size(),
fields_registers, fields_registers,
...@@ -626,7 +626,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -626,7 +626,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
for (size_t i = arity - return_count; i < arity; ++i) { for (size_t i = arity - return_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]); fields_registers.push_back(unpacked_arg_regs[i]);
} }
Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister())); Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister()));
} }
} }
...@@ -659,7 +659,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -659,7 +659,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
} }
} else if (auto constructor_node = op.as<ConstructorNode>()) { } else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node); auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers, Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister())); NewRegister()));
} else if (auto var_node = op.as<VarNode>()) { } else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node)); VisitExpr(GetRef<Var>(var_node));
......
...@@ -315,7 +315,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { ...@@ -315,7 +315,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.dst); fields.push_back(instr.dst);
break; break;
} }
case Opcode::AllocDatatype: { case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields // Number of fields = 3 + instr.num_fields
fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});
...@@ -551,7 +551,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { ...@@ -551,7 +551,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
return Instruction::AllocTensorReg(shape_register, dtype, dst); return Instruction::AllocTensorReg(shape_register, dtype, dst);
} }
case Opcode::AllocDatatype: { case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields // Number of fields = 3 + instr.num_fields
DCHECK_GE(instr.fields.size(), 3U); DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1])); DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
...@@ -561,7 +561,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { ...@@ -561,7 +561,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
RegName dst = instr.fields[2]; RegName dst = instr.fields[2];
std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields); std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);
return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); return Instruction::AllocADT(constructor_tag, num_fields, fields, dst);
} }
case Opcode::AllocClosure: { case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar // Number of fields = 3 + instr.num_freevar
......
...@@ -39,15 +39,15 @@ Tensor::Tensor(NDArray data) { ...@@ -39,15 +39,15 @@ Tensor::Tensor(NDArray data) {
data_ = std::move(ptr); data_ = std::move(ptr);
} }
Datatype::Datatype(size_t tag, std::vector<ObjectRef> fields) { ADT::ADT(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<DatatypeObj>(); auto ptr = make_object<ADTObj>();
ptr->tag = tag; ptr->tag = tag;
ptr->fields = std::move(fields); ptr->fields = std::move(fields);
data_ = std::move(ptr); data_ = std::move(ptr);
} }
Datatype Datatype::Tuple(std::vector<ObjectRef> fields) { ADT ADT::Tuple(std::vector<ObjectRef> fields) {
return Datatype(0, fields); return ADT(0, fields);
} }
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) { Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
...@@ -66,28 +66,28 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") ...@@ -66,28 +66,28 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
*rv = cell->data; *rv = cell->data;
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag") TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>(); const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr); CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->tag); *rv = static_cast<int64_t>(cell->tag);
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields") TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>(); const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr); CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->fields.size()); *rv = static_cast<int64_t>(cell->fields.size());
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields") TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
int idx = args[1]; int idx = args[1];
const auto* cell = obj.as<DatatypeObj>(); const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr); CHECK(cell != nullptr);
CHECK_LT(idx, cell->fields.size()); CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx]; *rv = cell->fields[idx];
...@@ -104,10 +104,10 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple") ...@@ -104,10 +104,10 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple")
for (auto i = 0; i < args.size(); ++i) { for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]); fields.push_back(args[i]);
} }
*rv = Datatype::Tuple(fields); *rv = ADT::Tuple(fields);
}); });
TVM_REGISTER_GLOBAL("_vmobj.Datatype") TVM_REGISTER_GLOBAL("_vmobj.ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0]; int itag = args[0];
size_t tag = static_cast<size_t>(itag); size_t tag = static_cast<size_t>(itag);
...@@ -115,11 +115,11 @@ TVM_REGISTER_GLOBAL("_vmobj.Datatype") ...@@ -115,11 +115,11 @@ TVM_REGISTER_GLOBAL("_vmobj.Datatype")
for (int i = 1; i < args.size(); i++) { for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]); fields.push_back(args[i]);
} }
*rv = Datatype(tag, fields); *rv = ADT(tag, fields);
}); });
TVM_REGISTER_OBJECT_TYPE(TensorObj); TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(DatatypeObj); TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
......
...@@ -74,7 +74,7 @@ Instruction::Instruction(const Instruction& instr) { ...@@ -74,7 +74,7 @@ Instruction::Instruction(const Instruction& instr) {
this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
return; return;
case Opcode::AllocDatatype: case Opcode::AllocADT:
this->constructor_tag = instr.constructor_tag; this->constructor_tag = instr.constructor_tag;
this->num_fields = instr.num_fields; this->num_fields = instr.num_fields;
this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields); this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
...@@ -159,7 +159,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -159,7 +159,7 @@ Instruction& Instruction::operator=(const Instruction& instr) {
this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
return *this; return *this;
case Opcode::AllocDatatype: case Opcode::AllocADT:
this->constructor_tag = instr.constructor_tag; this->constructor_tag = instr.constructor_tag;
this->num_fields = instr.num_fields; this->num_fields = instr.num_fields;
FreeIf(this->datatype_fields); FreeIf(this->datatype_fields);
...@@ -229,7 +229,7 @@ Instruction::~Instruction() { ...@@ -229,7 +229,7 @@ Instruction::~Instruction() {
case Opcode::AllocTensor: case Opcode::AllocTensor:
delete this->alloc_tensor.shape; delete this->alloc_tensor.shape;
return; return;
case Opcode::AllocDatatype: case Opcode::AllocADT:
delete this->datatype_fields; delete this->datatype_fields;
return; return;
case Opcode::AllocClosure: case Opcode::AllocClosure:
...@@ -301,10 +301,10 @@ Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype ...@@ -301,10 +301,10 @@ Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype
return instr; return instr;
} }
Instruction Instruction::AllocDatatype(Index tag, Index num_fields, Instruction Instruction::AllocADT(Index tag, Index num_fields,
const std::vector<RegName>& datatype_fields, Index dst) { const std::vector<RegName>& datatype_fields, Index dst) {
Instruction instr; Instruction instr;
instr.op = Opcode::AllocDatatype; instr.op = Opcode::AllocADT;
instr.dst = dst; instr.dst = dst;
instr.constructor_tag = tag; instr.constructor_tag = tag;
instr.num_fields = num_fields; instr.num_fields = num_fields;
...@@ -485,7 +485,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -485,7 +485,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
break; break;
} }
case Opcode::AllocDatatype: { case Opcode::AllocADT: {
os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
<< StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"; << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
break; break;
...@@ -691,7 +691,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -691,7 +691,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
const std::vector<ObjectRef>& args) { const std::vector<ObjectRef>& args) {
size_t arity = 0; size_t arity = 0;
for (Index i = 0; i < arg_count; i++) { for (Index i = 0; i < arg_count; i++) {
if (const auto* obj = args[i].as<DatatypeObj>()) { if (const auto* obj = args[i].as<ADTObj>()) {
arity += obj->fields.size(); arity += obj->fields.size();
} else { } else {
++arity; ++arity;
...@@ -703,7 +703,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -703,7 +703,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
runtime::TVMArgsSetter setter(values.data(), codes.data()); runtime::TVMArgsSetter setter(values.data(), codes.data());
int idx = 0; int idx = 0;
for (Index i = 0; i < arg_count; i++) { for (Index i = 0; i < arg_count; i++) {
if (const auto* dt_cell = args[i].as<DatatypeObj>()) { if (const auto* dt_cell = args[i].as<ADTObj>()) {
for (auto obj : dt_cell->fields) { for (auto obj : dt_cell->fields) {
const auto* tensor = obj.as<TensorObj>(); const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr); CHECK(tensor != nullptr);
...@@ -849,7 +849,7 @@ void VirtualMachine::RunLoop() { ...@@ -849,7 +849,7 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::GetField: { case Opcode::GetField: {
auto object = ReadRegister(instr.object); auto object = ReadRegister(instr.object);
const auto* tuple = object.as<DatatypeObj>(); const auto* tuple = object.as<ADTObj>();
CHECK(tuple != nullptr) CHECK(tuple != nullptr)
<< "Object is not data type object, register " << instr.object << ", Object tag " << "Object is not data type object, register " << instr.object << ", Object tag "
<< object->type_index(); << object->type_index();
...@@ -860,7 +860,7 @@ void VirtualMachine::RunLoop() { ...@@ -860,7 +860,7 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::GetTag: { case Opcode::GetTag: {
auto object = ReadRegister(instr.get_tag.object); auto object = ReadRegister(instr.get_tag.object);
const auto* data = object.as<DatatypeObj>(); const auto* data = object.as<ADTObj>();
CHECK(data != nullptr) CHECK(data != nullptr)
<< "Object is not data type object, register " << "Object is not data type object, register "
<< instr.get_tag.object << ", Object tag " << instr.get_tag.object << ", Object tag "
...@@ -925,12 +925,12 @@ void VirtualMachine::RunLoop() { ...@@ -925,12 +925,12 @@ void VirtualMachine::RunLoop() {
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocDatatype: { case Opcode::AllocADT: {
std::vector<ObjectRef> fields; std::vector<ObjectRef> fields;
for (Index i = 0; i < instr.num_fields; ++i) { for (Index i = 0; i < instr.num_fields; ++i) {
fields.push_back(ReadRegister(instr.datatype_fields[i])); fields.push_back(ReadRegister(instr.datatype_fields[i]));
} }
ObjectRef obj = Datatype(instr.constructor_tag, fields); ObjectRef obj = ADT(instr.constructor_tag, fields);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
goto main_loop; goto main_loop;
......
...@@ -49,7 +49,7 @@ def convert_to_list(x): ...@@ -49,7 +49,7 @@ def convert_to_list(x):
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor): if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.Datatype): elif isinstance(o, tvm.relay.backend.vmobj.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
...@@ -742,7 +742,7 @@ def vmobj_to_list(o): ...@@ -742,7 +742,7 @@ def vmobj_to_list(o):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.asnumpy()] return [o.asnumpy()]
elif isinstance(o, tvm.relay.backend.vmobj.Datatype): elif isinstance(o, tvm.relay.backend.vmobj.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
...@@ -63,7 +63,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -63,7 +63,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vm.Tensor): if isinstance(o, tvm.relay.backend.vm.Tensor):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vm.Datatype): elif isinstance(o, tvm.relay.backend.vm.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
...@@ -28,13 +28,13 @@ def test_tensor(): ...@@ -28,13 +28,13 @@ def test_tensor():
assert isinstance(x.data, tvm.nd.NDArray) assert isinstance(x.data, tvm.nd.NDArray)
def test_datatype(): def test_adt():
arr = tvm.nd.array([1,2,3]) arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr) x = vm.Tensor(arr)
y = vm.Datatype(0, [x, x]) y = vm.ADT(0, [x, x])
assert len(y) == 2 assert len(y) == 2
assert isinstance(y, vm.Datatype) assert isinstance(y, vm.ADT)
y[0:1][-1].data == x.data y[0:1][-1].data == x.data
assert y.tag == 0 assert y.tag == 0
assert isinstance(x.data, tvm.nd.NDArray) assert isinstance(x.data, tvm.nd.NDArray)
...@@ -43,4 +43,4 @@ def test_datatype(): ...@@ -43,4 +43,4 @@ def test_datatype():
if __name__ == "__main__": if __name__ == "__main__":
test_tensor() test_tensor()
test_datatype() test_adt()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment