Commit b8fa8f62 by Haichen Shen Committed by Jared Roesch

[Relay][VM] Add AllocTensor instruction and better instruction printer (#3306)

* Update vm print & add AllocTensor instruction

* patch

* fix invoke packed

* update cmake

* tweak move

* update invoke_closure

* lint

* add doc

* tweak
parent 59d8ba8f
...@@ -222,6 +222,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) ...@@ -222,6 +222,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
# Relay debug builds define USE_RELAY_DEBUG on the tvm target; every other
# build defines NDEBUG instead.
# target_compile_definitions appends to the target's COMPILE_DEFINITIONS, whereas
# set_target_properties(... COMPILE_DEFINITIONS ...) replaces whatever
# definitions were already attached to the target.
if(USE_RELAY_DEBUG)
  message(STATUS "Building Relay in debug mode...")
  target_compile_definitions(tvm PRIVATE "USE_RELAY_DEBUG")
else()
  target_compile_definitions(tvm PRIVATE "NDEBUG")
endif(USE_RELAY_DEBUG)
......
...@@ -56,13 +56,14 @@ enum class Opcode { ...@@ -56,13 +56,14 @@ enum class Opcode {
InvokeClosure = 3U, InvokeClosure = 3U,
InvokePacked = 4U, InvokePacked = 4U,
AllocTensor = 5U, AllocTensor = 5U,
AllocDatatype = 6U, AllocTensorReg = 6U,
AllocClosure = 7U, AllocDatatype = 7U,
GetField = 8U, AllocClosure = 8U,
If = 9U, GetField = 9U,
Select = 10U, If = 10U,
LoadConst = 11U, Select = 11U,
Goto = 12U LoadConst = 12U,
Goto = 13U
}; };
/*! \brief A single virtual machine instruction. /*! \brief A single virtual machine instruction.
...@@ -83,11 +84,19 @@ struct Instruction { ...@@ -83,11 +84,19 @@ struct Instruction {
union { union {
struct /* AllocTensor Operands */ { struct /* AllocTensor Operands */ {
/*! \brief The number of dimensions. */
uint32_t ndim;
/*! \brief The shape of tensor. */
int64_t* shape;
/*! \brief The datatype of tensor to be allocated. */
DLDataType dtype;
} alloc_tensor;
struct /* AllocTensorReg Operands */ {
/*! \brief The register to read the shape out of. */ /*! \brief The register to read the shape out of. */
RegName shape_register; RegName shape_register;
/*! \brief The datatype of tensor to be allocated. */ /*! \brief The datatype of tensor to be allocated. */
DLDataType dtype; DLDataType dtype;
}; } alloc_tensor_reg;
struct /* InvokeClosure Operands */ { struct /* InvokeClosure Operands */ {
/*! \brief The register containing the closure. */ /*! \brief The register containing the closure. */
RegName closure; RegName closure;
...@@ -192,13 +201,20 @@ struct Instruction { ...@@ -192,13 +201,20 @@ struct Instruction {
*/ */
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size, static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args); const std::vector<RegName>& args);
/*! \brief Construct an allocate tensor instruction. /*! \brief Construct an allocate tensor instruction with constant shape.
* \param shape The shape of the tensor.
* \param dtype The dtype of the tensor.
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
static Instruction AllocTensor(std::vector<int64_t> shape, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate tensor instruction with register.
* \param shape_register The register containing the shape. * \param shape_register The register containing the shape.
* \param dtype The dtype of the tensor. * \param dtype The dtype of the tensor.
* \param dst The destination register. * \param dst The destination register.
* \return The allocate tensor instruction. * \return The allocate tensor instruction.
*/ */
static Instruction AllocTensor(RegName shape_register, DLDataType dtype, RegName dst); static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate datatype instruction. /*! \brief Construct an allocate datatype instruction.
* \param tag The datatype tag. * \param tag The datatype tag.
* \param num_fields The number of fields for the datatype. * \param num_fields The number of fields for the datatype.
......
...@@ -103,13 +103,6 @@ struct ConstantPool : ExprVisitor { ...@@ -103,13 +103,6 @@ struct ConstantPool : ExprVisitor {
} }
} }
/*!
 * \brief Record the shape tensor for a tensor type unless one is already recorded.
 *
 * A new entry is paired with the current value of `index`, which is then
 * incremented. An existing entry is left untouched and `index` is not advanced.
 * \param expr The tensor type used as the lookup key.
 * \param value The shape tensor to store for that type.
 */
void AddConstantTensorShape(TensorType expr, NDArray value) {
  auto& shape_map = this->const_tensor_shape_map;
  if (shape_map.find(expr) != shape_map.end()) {
    return;
  }
  shape_map.insert({expr, std::make_pair(index++, value)});
}
void VisitExpr_(const ConstantNode* const_node) { void VisitExpr_(const ConstantNode* const_node) {
auto konst = GetRef<Constant>(const_node); auto konst = GetRef<Constant>(const_node);
auto it = this->const_map.find(konst); auto it = this->const_map.find(konst);
...@@ -117,48 +110,6 @@ struct ConstantPool : ExprVisitor { ...@@ -117,48 +110,6 @@ struct ConstantPool : ExprVisitor {
this->const_map.insert({konst, index++}); this->const_map.insert({konst, index++});
} }
} }
/*!
 * \brief Materialize the shape of a tensor type as a 1-D int64 tensor on the CPU.
 *
 * Every dimension is downcast to tvm::Integer, so the type is assumed to have a
 * fully constant shape.
 * \param ttype The tensor type whose shape is read.
 * \return A tensor with one int64 element per dimension of ttype.
 */
NDArray GetTensorConstant(const TensorTypeNode* ttype) {
  // Collect the extents before allocating, matching the order of side effects
  // of the downcasts and the allocation.
  std::vector<int64_t> extents;
  for (auto dim : ttype->shape) {
    extents.push_back(Downcast<tvm::Integer>(dim)->value);
  }

  DLContext host_ctx;
  host_ctx.device_type = kDLCPU;
  host_ctx.device_id = 0;

  int64_t rank = extents.size();
  auto result = NDArray::Empty({rank}, Type2TVMType(Int(64)), host_ctx);

  int64_t* out = static_cast<int64_t*>(result->data);
  size_t pos = 0;
  for (int64_t extent : extents) {
    out[pos++] = extent;
  }
  return result;
}
void VisitExpr_(const CallNode* call_node) {
for (auto arg : call_node->args) {
this->VisitExpr(arg);
}
Expr op = call_node->op;
auto func_node = op.as<FunctionNode>();
if (func_node) {
auto ret_type = call_node->checked_type();
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
auto shape = GetTensorConstant(ttype);
auto tensor_type = GetRef<TensorType>(ttype);
AddConstantTensorShape(tensor_type, shape);
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < ttype->fields.size(); ++i) {
auto f = ttype->fields[i];
auto f_type = f.as<TensorTypeNode>();
auto shape = GetTensorConstant(f_type);
auto tensor_type = GetRef<TensorType>(f_type);
AddConstantTensorShape(tensor_type, shape);
}
}
}
}
}; };
std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) { std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
...@@ -206,6 +157,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -206,6 +157,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
switch (instr.op) { switch (instr.op) {
case Opcode::AllocDatatype: case Opcode::AllocDatatype:
case Opcode::AllocTensor: case Opcode::AllocTensor:
case Opcode::AllocTensorReg:
case Opcode::GetField: case Opcode::GetField:
case Opcode::LoadConst: case Opcode::LoadConst:
case Opcode::Select: case Opcode::Select:
...@@ -259,14 +211,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -259,14 +211,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const MatchNode* match_node) { void VisitExpr_(const MatchNode* match_node) {
auto match = GetRef<Match>(match_node); auto match = GetRef<Match>(match_node);
LOG(FATAL) << "translation of match nodes to the VM is " LOG(FATAL) << "translation of match nodes to the VM is"
<< "currently unsupported" << std::endl; << "currently unsupported";
} }
void VisitExpr_(const LetNode* let_node) { void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << let_node->value << std::endl; DLOG(INFO) << let_node->value;
this->VisitExpr(let_node->value); this->VisitExpr(let_node->value);
DLOG(INFO) << this->last_register << std::endl; DLOG(INFO) << this->last_register;
var_register_map.insert({let_node->var, this->last_register}); var_register_map.insert({let_node->var, this->last_register});
this->VisitExpr(let_node->body); this->VisitExpr(let_node->body);
} }
...@@ -327,18 +279,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -327,18 +279,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
} }
/*!
 * \brief Build an AllocTensor instruction for a tensor type with a constant shape.
 *
 * Each dimension is downcast to tvm::Integer, so the shape is assumed to be
 * fully static. A fresh register is requested for the destination.
 * \param ttype The tensor type to allocate.
 * \return The AllocTensor instruction.
 */
Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
  TVMType dltype = Type2TVMType(ttype->dtype);

  std::vector<int64_t> dims;
  for (auto extent : ttype->shape) {
    dims.push_back(Downcast<tvm::Integer>(extent)->value);
  }

  return Instruction::AllocTensor(dims, dltype, NewRegister());
}
void EmitInvokePrimitive(const Function& func, void EmitInvokePrimitive(const Function& func,
...@@ -532,7 +479,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs, ...@@ -532,7 +479,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
} }
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl; DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false);
size_t params = func->params.size(); size_t params = func->params.size();
VMCompiler compiler(context); VMCompiler compiler(context);
compiler.Compile(func); compiler.Compile(func);
......
...@@ -67,8 +67,14 @@ Instruction::Instruction(const Instruction& instr) { ...@@ -67,8 +67,14 @@ Instruction::Instruction(const Instruction& instr) {
this->result = instr.result; this->result = instr.result;
return; return;
case Opcode::AllocTensor: case Opcode::AllocTensor:
this->shape_register = instr.shape_register; this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
this->dtype = instr.dtype; this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
instr.alloc_tensor.ndim);
this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
return;
case Opcode::AllocTensorReg:
this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
return; return;
case Opcode::AllocDatatype: case Opcode::AllocDatatype:
this->constructor_tag = instr.constructor_tag; this->constructor_tag = instr.constructor_tag;
...@@ -142,8 +148,14 @@ Instruction& Instruction::operator=(const Instruction& instr) { ...@@ -142,8 +148,14 @@ Instruction& Instruction::operator=(const Instruction& instr) {
this->result = instr.result; this->result = instr.result;
return *this; return *this;
case Opcode::AllocTensor: case Opcode::AllocTensor:
this->shape_register = instr.shape_register; this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
this->dtype = instr.dtype; this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
instr.alloc_tensor.ndim);
this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
return *this;
case Opcode::AllocTensorReg:
this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
return *this; return *this;
case Opcode::AllocDatatype: case Opcode::AllocDatatype:
this->constructor_tag = instr.constructor_tag; this->constructor_tag = instr.constructor_tag;
...@@ -203,12 +215,15 @@ Instruction::~Instruction() { ...@@ -203,12 +215,15 @@ Instruction::~Instruction() {
case Opcode::Move: case Opcode::Move:
case Opcode::Select: case Opcode::Select:
case Opcode::Ret: case Opcode::Ret:
case Opcode::AllocTensor: case Opcode::AllocTensorReg:
case Opcode::If: case Opcode::If:
case Opcode::LoadConst: case Opcode::LoadConst:
case Opcode::GetField: case Opcode::GetField:
case Opcode::Goto: case Opcode::Goto:
return; return;
case Opcode::AllocTensor:
delete this->alloc_tensor.shape;
return;
case Opcode::AllocDatatype: case Opcode::AllocDatatype:
delete this->datatype_fields; delete this->datatype_fields;
return; return;
...@@ -226,8 +241,7 @@ Instruction::~Instruction() { ...@@ -226,8 +241,7 @@ Instruction::~Instruction() {
return; return;
default: default:
std::ostringstream out; std::ostringstream out;
LOG(FATAL) << "Invalid instruction " << static_cast<int>(this->op) LOG(FATAL) << "Invalid instruction " << static_cast<int>(this->op);
<< "\n";
} }
} }
...@@ -252,12 +266,25 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out ...@@ -252,12 +266,25 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out
return instr; return instr;
} }
/*!
 * \brief Build an AllocTensor instruction whose shape is embedded in the instruction.
 *
 * The shape values are copied into a newly allocated array stored in the
 * instruction's alloc_tensor operands.
 * \param shape The constant shape of the tensor.
 * \param dtype The element type of the tensor.
 * \param dst The destination register.
 * \return The constructed instruction.
 */
Instruction Instruction::AllocTensor(std::vector<int64_t> shape, DLDataType dtype, Index dst) {
  Instruction instr;
  instr.op = Opcode::AllocTensor;
  instr.dst = dst;

  const size_t rank = shape.size();
  int64_t* dims = new int64_t[rank];
  for (size_t axis = 0; axis < rank; ++axis) {
    dims[axis] = shape[axis];
  }

  instr.alloc_tensor.ndim = rank;
  instr.alloc_tensor.shape = dims;
  instr.alloc_tensor.dtype = dtype;
  return instr;
}

/*!
 * \brief Build an AllocTensorReg instruction that reads its shape from a register.
 * \param shape_register The register holding the shape.
 * \param dtype The element type of the tensor.
 * \param dst The destination register.
 * \return The constructed instruction.
 */
Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype, Index dst) {
  Instruction instr;
  instr.op = Opcode::AllocTensorReg;
  instr.dst = dst;

  auto& operands = instr.alloc_tensor_reg;
  operands.shape_register = shape_register;
  operands.dtype = dtype;
  return instr;
}
...@@ -381,85 +408,92 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { ...@@ -381,85 +408,92 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
break; break;
} }
os << dtype.bits; os << int(dtype.bits);
if (dtype.lanes != 0) { if (dtype.lanes != 1) {
os << "[" << dtype.lanes << "]"; os << "x" << dtype.lanes;
} }
} }
/*!
 * \brief Join a contiguous slice of an array into a delimited string.
 * \tparam T Element type; must be streamable with operator<<.
 * \param items Pointer to the first element of the underlying array.
 * \param offset Index of the first element to emit.
 * \param cnt Number of elements to emit, starting at offset.
 * \param delim Separator written between consecutive elements.
 * \return The joined string; empty when items is null or cnt is not positive.
 */
template<typename T>
std::string StrJoin(T* items, int offset, int cnt, const std::string& delim = ",") {
  // Guard non-positive counts (not just zero) and null arrays so that a
  // malformed count can never cause a read of items[offset].
  if (items == nullptr || cnt <= 0) {
    return "";
  }
  std::ostringstream oss;
  oss << items[offset];
  for (int i = 1; i < cnt; ++i) {
    oss << delim << items[offset + i];
  }
  return oss.str();
}
void InstructionPrint(std::ostream& os, const Instruction& instr) { void InstructionPrint(std::ostream& os, const Instruction& instr) {
switch (instr.op) { switch (instr.op) {
case Opcode::Move: { case Opcode::Move: {
os << "move " << instr.from << " " << instr.dst; os << "move $" << instr.dst << " $" << instr.from;
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
os << "ret " << instr.result; os << "ret $" << instr.result;
break; break;
} }
case Opcode::InvokePacked: { case Opcode::InvokePacked: {
os << "invoke_packed "; os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $"
os << instr.packed_index; << StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ",$")
os << " " << instr.arity; << ", out: $"
os << "("; << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
for (Index i = 0; i < instr.arity; ++i) { instr.output_size, ",$")
os << instr.packed_args[i] << ","; << ")";
}
os << ")";
os << " " << instr.output_size;
break; break;
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
os << "alloc_tensor "; os << "alloc_tensor $" << instr.dst << " ["
os << instr.dst << " "; << StrJoin<int64_t>(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim)
os << instr.shape_register << " "; << "] ";
DLDatatypePrint(os, instr.dtype); DLDatatypePrint(os, instr.alloc_tensor.dtype);
break;
}
case Opcode::AllocTensorReg: {
os << "alloc_tensor_reg $" << instr.dst << " $"
<< instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
break; break;
} }
case Opcode::AllocDatatype: { case Opcode::AllocDatatype: {
os << "alloc_data "; os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
os << instr.dst << " "; << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
os << instr.constructor_tag << " ";
os << instr.num_fields;
break; break;
} }
case Opcode::AllocClosure: { case Opcode::AllocClosure: {
os << "alloc_closure "; os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
os << instr.dst << " "; << "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
os << instr.clo_index << " "; << ")";
os << instr.num_freevar << "(";
for (Index i = 0; i < instr.num_freevar; ++i) {
os << instr.free_vars[i] << ",";
}
os << ")";
break; break;
} }
case Opcode::If: { case Opcode::If: {
os << "if " os << "if " << "$" << instr.if_cond << " " << instr.true_offset << " "
<< "$" << instr.if_cond << " " << instr.true_offset << " " << instr.false_offset; << instr.false_offset;
break; break;
} }
case Opcode::Invoke: { case Opcode::Invoke: {
os << "invoke " os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
<< "$" << instr.dst << " " << instr.func_index << " " << instr.num_args << "("; << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
for (Index i = 0; i < instr.num_args; ++i) { << ")";
os << instr.invoke_args_registers[i] << ",";
}
os << ")";
break; break;
} }
case Opcode::InvokeClosure: { case Opcode::InvokeClosure: {
os << "invoke_closure " os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
<< "$" << instr.dst << " " << instr.closure << " " << instr.closure_args_num << "()"; << StrJoin<RegName>(instr.closure_args, 0, instr.closure_args_num, ",$")
<< ")";
break; break;
} }
case Opcode::LoadConst: { case Opcode::LoadConst: {
os << "load_const " os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
<< "$" << instr.dst << " " << instr.const_index;
break; break;
} }
case Opcode::GetField: { case Opcode::GetField: {
os << "get_field " << instr.dst << " " << instr.object << " " << instr.field_index; os << "get_field $" << instr.dst << " $" << instr.object << "["
<< instr.field_index << "]";
break; break;
} }
case Opcode::Goto: { case Opcode::Goto: {
...@@ -467,8 +501,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -467,8 +501,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break; break;
} }
case Opcode::Select: { case Opcode::Select: {
os << "select " << instr.dst << " " << instr.select_cond << " " << instr.select_op1 << " " os << "select $" << instr.dst << " $" << instr.select_cond << " $"
<< instr.select_op2; << instr.select_op1 << " $" << instr.select_op2;
break; break;
} }
default: default:
...@@ -513,48 +547,64 @@ Index VirtualMachine::PopFrame() { ...@@ -513,48 +547,64 @@ Index VirtualMachine::PopFrame() {
} }
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) { void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) {
DLOG(INFO) << "===================\nInvoking global " << func.name << " " << args.size() DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
<< std::endl;
PushFrame(func.params, this->pc + 1, func); PushFrame(func.params, this->pc + 1, func);
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
WriteRegister(i, args[i]); WriteRegister(i, args[i]);
} }
DLOG(INFO) << "func.params= " << func.params << std::endl; DLOG(INFO) << "func.params= " << func.params;
code = func.instructions.data(); code = func.instructions.data();
pc = 0; pc = 0;
} }
/*!
 * \brief Execute a VM function with the given arguments.
 * \param func The function to execute.
 * \param args The arguments to pass.
 * \return The value left in return_register after Run() finishes.
 */
Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>& args) {
  DLOG(INFO) << "Executing Function: " << std::endl << func;

  // Enter the function, then run the interpreter loop.
  this->InvokeGlobal(func, args);
  this->Run();

  // The allocator is looked up unconditionally, even when DLOG is compiled out.
  auto allocator = MemoryManager::Global()->GetAllocator(this->ctxs[0]);
  DLOG(INFO) << "Memory used: " << allocator->UsedMemory() << " B";
  return this->return_register;
}
/*!
 * \brief Execute a VM function looked up by its global name.
 *
 * An unknown name is reported as a fatal error. Indexing the map with
 * operator[] would instead insert the name with a default index and silently
 * run whichever function sits at that index.
 * \param name The global name of the function.
 * \param args The arguments to pass.
 * \return The result of invoking the function.
 */
Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) {
  auto it = this->global_map_.find(name);
  if (it == this->global_map_.end()) {
    LOG(FATAL) << "Cannot find global function " << name << " in the virtual machine";
  }
  auto func_index = it->second;
  DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
  return Invoke(this->functions[func_index], args);
}
void InvokePacked(const PackedFunc& func, Index arg_count, Index output_size, void InvokePacked(const PackedFunc& func, Index arg_count, Index output_size,
const std::vector<Object>& args) { const std::vector<Object>& args) {
std::vector<TVMValue> values(arg_count); size_t arity = 0;
std::vector<int> codes(arg_count); for (Index i = 0; i < arg_count; i++) {
runtime::TVMArgsSetter setter(values.data(), codes.data()); if (args[i].ptr_->tag == ObjectTag::kDatatype) {
arity += args[i].AsDatatype()->fields.size();
} else {
++arity;
}
}
std::vector<TVMValue> values(arity);
std::vector<int> codes(arity);
runtime::TVMArgsSetter setter(values.data(), codes.data());
int idx = 0;
for (Index i = 0; i < arg_count; i++) { for (Index i = 0; i < arg_count; i++) {
NDArray data = ToNDArray(args[i]); if (args[i].ptr_->tag == ObjectTag::kDatatype) {
setter(i, data); auto dt_cell = args[i].AsDatatype();
for (auto obj : dt_cell->fields) {
NDArray data = ToNDArray(obj);
setter(idx++, data);
}
} else {
NDArray data = ToNDArray(args[i]);
setter(idx++, data);
}
} }
TVMRetValue rv; TVMRetValue rv;
func.CallPacked(TVMArgs(values.data(), codes.data(), arg_count), &rv); func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
} }
/*!
 * \brief Store the contexts the virtual machine should use.
 * \param ctxs The contexts to copy into the VM.
 */
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
  // The parameter shadows the member, so the member is qualified explicitly.
  this->ctxs = ctxs;
}
...@@ -574,7 +624,7 @@ void VirtualMachine::Run() { ...@@ -574,7 +624,7 @@ void VirtualMachine::Run() {
while (true) { while (true) {
main_loop: main_loop:
auto const& instr = this->code[this->pc]; auto const& instr = this->code[this->pc];
DLOG(INFO) << "\nExecuting(" << pc << "): "; DLOG(INFO) << "Executing(" << pc << "): ";
#if USE_RELAY_DEBUG #if USE_RELAY_DEBUG
InstructionPrint(std::cout, instr); InstructionPrint(std::cout, instr);
#endif // USE_RELAY_DEBUG #endif // USE_RELAY_DEBUG
...@@ -669,11 +719,23 @@ void VirtualMachine::Run() { ...@@ -669,11 +719,23 @@ void VirtualMachine::Run() {
goto main_loop; goto main_loop;
} }
case Opcode::AllocTensor: { case Opcode::AllocTensor: {
auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim);
for (uint i = 0; i < instr.alloc_tensor.ndim; ++i) {
shape[i] = instr.alloc_tensor.shape[i];
}
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
auto obj = Object::Tensor(data);
WriteRegister(instr.dst, obj);
pc++;
goto main_loop;
}
case Opcode::AllocTensorReg: {
DLContext cpu_ctx; DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.shape_register); auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
NDArray shape_tensor = ToNDArray(shape_tensor_obj).CopyTo(cpu_ctx); NDArray shape_tensor = ToNDArray(shape_tensor_obj).CopyTo(cpu_ctx);
int64_t* dims = static_cast<int64_t*>(shape_tensor->data); int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
...@@ -681,7 +743,7 @@ void VirtualMachine::Run() { ...@@ -681,7 +743,7 @@ void VirtualMachine::Run() {
auto shape = std::vector<int64_t>(shape_tensor->shape[0]); auto shape = std::vector<int64_t>(shape_tensor->shape[0]);
shape.assign(dims, dims + num_dims); shape.assign(dims, dims + num_dims);
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
auto obj = Object::Tensor(data); auto obj = Object::Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment