/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_stack_vm.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <limits>
#include "./codegen_stack_vm.h"

namespace tvm {
namespace codegen {

using namespace ir;

// Compile a lowered function into a StackVM program.
// Function arguments are registered first, so argument i occupies heap slot i
// (enforced by the CHECK below).
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
  for (size_t i = 0; i < f->args.size(); ++i) {
    Var v = f->args[i];
    int vid = AllocVarID(v.get());
    CHECK_EQ(static_cast<size_t>(vid), i);
  }
  this->Push(f->body);
  return std::move(vm_);
}

// Emit code for a statement. In debug mode an ASSERT_SP with operand 0 is
// appended after every statement (presumably verifying the statement left the
// stack balanced — see the StackVM interpreter).
void CodeGenStackVM::Push(const Stmt& n) {
  VisitStmt(n);
  if (debug_) {
    this->PushOp(StackVM::ASSERT_SP, 0);
  }
}

// Append an opcode that carries no operand.
void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
}

// Overwrite the operand slot at operand_index (as returned by the two-argument
// PushOp). Used to back-patch jump offsets once the target is known.
// The operand must fit in an int.
void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
  CHECK(operand >= std::numeric_limits<int>::min() &&
        operand <= std::numeric_limits<int>::max());
  vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}

// Append an opcode followed by one int operand slot.
// Returns the index of the operand slot so it can be patched via SetOperand.
int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
  int64_t pc = static_cast<int64_t>(vm_.code.size());
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
  code.v_int = operand;
  vm_.code.push_back(code);
  return pc + 1;
}

// Intern a string into the program's string table and return its id.
// Repeated keys return the id assigned on first use.
int CodeGenStackVM::GetStrID(const std::string& key) {
  auto it = str_idmap_.find(key);
  if (it != str_idmap_.end()) return it->second;
  int sid = static_cast<int>(vm_.str_data.size());
  vm_.str_data.push_back(key);
  str_idmap_[key] = sid;
  return sid;
}

// Assign the next heap slot to variable v. Each variable may be allocated
// only once; heap slots are handed out densely in allocation order.
int CodeGenStackVM::AllocVarID(const Variable* v) {
  CHECK(!var_idmap_.count(v));
  int vid = static_cast<int>(vm_.heap_size);
  CHECK_EQ(vm_.heap_size, var_idmap_.size());
  vm_.heap_id_name.push_back(v->name_hint);
  ++vm_.heap_size;
  var_idmap_[v] = vid;
  return vid;
}

// Look up the heap slot of an already-allocated variable; fatal if unknown.
int CodeGenStackVM::GetVarID(const Variable* v) const {
  auto it = var_idmap_.find(v);
  CHECK(it != var_idmap_.end())
      << "Find undefined Variable " << v->name_hint;
  return it->second;
}

// Load from buffer_var[index]. A constant index is folded into the load
// opcode's operand; otherwise the byte address buffer + index * elem_bytes is
// computed on the stack and the load uses operand 0.
void CodeGenStackVM::VisitExpr_(const Load* op) {
  this->Push(op->buffer_var);
  StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
  if (const IntImm* index = op->index.as<IntImm>()) {
    this->PushOp(code, index->value);
  } else {
    this->Push(op->index);
    this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->PushOp(code, 0);
  }
}

// Store value into buffer_var[index], mirroring the addressing scheme of Load.
// The value is pushed after the address in both paths.
void CodeGenStackVM::VisitStmt_(const Store* op) {
  this->Push(op->buffer_var);
  StackVM::OpCode code = StackVM::GetStore(Type2TVMType(op->value.type()));
  if (const IntImm* index = op->index.as<IntImm>()) {
    this->Push(op->value);
    this->PushOp(code, index->value);
  } else {
    this->Push(op->index);
    this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->Push(op->value);
    this->PushOp(code, 0);
  }
}

// Allocation is only supported through an explicit new_expr (with a "nop"
// free function); its result is stored into the buffer variable's heap slot.
void CodeGenStackVM::VisitStmt_(const Allocate* op) {
  CHECK(!is_zero(op->condition));
  int vid = AllocVarID(op->buffer_var.get());
  if (op->new_expr.defined()) {
    // Prefer global static allocation for the program
    CHECK_EQ(op->free_function, "nop");
    this->Push(op->new_expr);
    this->PushOp(StackVM::STORE_HEAP, vid);
  } else {
    LOG(FATAL) << "Dynamic allocation not supported";
  }
}

// Lower the intrinsic calls this backend understands; any other call is fatal.
void CodeGenStackVM::VisitExpr_(const Call* op) {
  if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    // address_of(buffer[index]) -> buffer + index * elem_bytes
    CHECK_EQ(op->args.size(), 1U);
    const Load* l = op->args[0].as<Load>();
    CHECK(l != nullptr);
    this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
    this->Push(l->index);
    this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
  } else if (op->is_intrinsic(Call::reinterpret)) {
    // Emitted as a plain push of the operand; no conversion code is generated.
    this->Push(op->args[0]);
  } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
    CHECK_EQ(op->args.size(), 3U);
    const IntImm* index = op->args[1].as<IntImm>();
    const IntImm* kind = op->args[2].as<IntImm>();
    CHECK(index != nullptr);
    CHECK(kind != nullptr);
    this->Push(op->args[0]);
    // Opcode followed by two operand slots: field index, then field kind.
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_GET;
    vm_.code.push_back(code);
    code.v_int = static_cast<int>(index->value);
    vm_.code.push_back(code);
    code.v_int = static_cast<int>(kind->value);
    vm_.code.push_back(code);
  } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
    CHECK_GE(op->args.size(), 5U);
    const StringImm* s = op->args[0].as<StringImm>();
    CHECK(s != nullptr)
        << "tvm_call_packed_lowered expects first argument as function name";
    const IntImm* begin_imm = op->args[3].as<IntImm>();
    const IntImm* end_imm = op->args[4].as<IntImm>();
    CHECK(begin_imm != nullptr);
    CHECK(end_imm != nullptr);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    int begin = static_cast<int>(begin_imm->value);
    int end = static_cast<int>(end_imm->value);
    // Find (or register) the function id for this extern function name.
    const std::string& func_name = s->value;
    int fid;
    auto it = extern_fun_idmap_.find(func_name);
    if (it != extern_fun_idmap_.end()) {
      fid = it->second;
    } else {
      fid = static_cast<int>(vm_.extern_func_name.size());
      vm_.extern_func_name.push_back(func_name);
      extern_fun_idmap_[func_name] = fid;
    }
    // CALL_PACKED_LOWERED followed by three operand slots: fid, begin, end.
    StackVM::Code code;
    code.op_code = StackVM::CALL_PACKED_LOWERED;
    vm_.code.push_back(code);
    code.v_int = fid;
    vm_.code.push_back(code);
    code.v_int = begin;
    vm_.code.push_back(code);
    code.v_int = end;
    vm_.code.push_back(code);
  } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
    CHECK_EQ(op->args.size(), 2U);
    const StringImm* type_imm = op->args[0].as<StringImm>();
    CHECK(type_imm != nullptr);
    const std::string& type = type_imm->value;
    const IntImm* num = op->args[1].as<IntImm>();
    CHECK(num != nullptr);
    static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
    // static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
    // Size is measured in TVMValue-sized units, rounded up.
    size_t unit = sizeof(TVMValue);
    size_t size = 0;
    if (type == "shape") {
      size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
    } else if (type == "arg_value") {
      size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
    } else if (type == "arg_tcode") {
      size = (num->value * sizeof(int) + unit - 1) / unit;
    } else if (type == "array") {
      size = (num->value * sizeof(TVMArray) + unit - 1) / unit;
    } else {
      LOG(FATAL) << "Unknown stack alloca type " << type;
    }
    // add stack size to be safe.
    vm_.stack_size += size;
    this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    CHECK_EQ(op->args.size(), 1U);
    this->Push(op->args[0]);
    this->PushOp(StackVM::PUSH_I64, 0);
    this->PushOp(StackVM::EQ_HANDLE);
  } else {
    LOG(FATAL) << "unknown function call " << op->name;
  }
}

// Emit a binary operation. Signed ints and uints of <= 32 bits use the int64
// opcode; uint64 is rejected; any other type uses the float64 variant of the
// opcode (via CodeI64ToF64).
void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
                                const Expr& a,
                                const Expr& b) {
  this->Push(a);
  this->Push(b);
  Type t = a.type();
  if (t.is_int()) {
    this->PushOp(op_int64);
  } else if (t.is_uint()) {
    if (t.bits() <= 32) {
      this->PushOp(op_int64);
    } else {
      LOG(FATAL) << "Cannot handle uint64_t in StackVM";
    }
  } else {
    this->PushOp(StackVM::CodeI64ToF64(op_int64));
  }
}

// Emit code for a cast. Integer<->integer and float->float casts need no
// instruction. NOTE(review): every other combination (e.g. int->float) also
// falls through without emitting a conversion — confirm whether that should
// be a fatal error instead.
void CodeGenStackVM::PushCast(Type dst, Type src) {
  if (dst.is_int()) {
    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_uint()) {
    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_float()) {
    if (src.is_float()) return;
  }
}

// String constants are pushed as their string-table id.
void CodeGenStackVM::VisitExpr_(const StringImm *op) {
  int sid = this->GetStrID(op->value);
  this->PushOp(StackVM::PUSH_I64, sid);
}

// Integer constants must fit in the int operand slot.
void CodeGenStackVM::VisitExpr_(const IntImm *op) {
  CHECK(op->value >= std::numeric_limits<int>::min() &&
        op->value <= std::numeric_limits<int>::max())
      << "Int constant exceed bound";
  this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}

void CodeGenStackVM::VisitExpr_(const UIntImm *op) {
  CHECK(op->value <= std::numeric_limits<int>::max())
      << "Int constant exceed bound";
  this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}

void CodeGenStackVM::VisitExpr_(const FloatImm *op) {
  LOG(FATAL) << "Float Imm is not supported";
}

// A variable reference loads its heap slot.
void CodeGenStackVM::VisitExpr_(const Variable *op) {
  int vid = this->GetVarID(op);
  this->PushOp(StackVM::LOAD_HEAP, vid);
}

void CodeGenStackVM::VisitExpr_(const Cast *op) {
  this->Push(op->value);
  PushCast(op->type, op->value.type());
}

void CodeGenStackVM::VisitExpr_(const Add *op) {
  PushBinary(StackVM::ADD_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const Sub *op) {
  PushBinary(StackVM::SUB_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const Mul *op) {
  PushBinary(StackVM::MUL_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const Div *op) {
  PushBinary(StackVM::DIV_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const Mod *op) {
  PushBinary(StackVM::MOD_I64, op->a, op->b);
}

// min(a, b) as select(a < b, a, b) using stack copies.
// Assuming PUSH_VALUE k copies the element k slots below the top:
// [a, b] -> [a, b, a] -> [a, b, a, b] -> LT -> [a, b, a<b] -> SELECT.
void CodeGenStackVM::VisitExpr_(const Min *op) {
  this->Push(op->a);
  this->Push(op->b);
  this->PushOp(StackVM::PUSH_VALUE, -1);
  this->PushOp(StackVM::PUSH_VALUE, -1);
  this->PushOp(StackVM::LT_I64);
  this->PushOp(StackVM::SELECT);
}

// max(a, b) as select(b < a, a, b), under the same PUSH_VALUE assumption:
// [a, b] -> [a, b, b] -> [a, b, b, a] -> LT -> [a, b, b<a] -> SELECT.
void CodeGenStackVM::VisitExpr_(const Max *op) {
  this->Push(op->a);
  this->Push(op->b);
  this->PushOp(StackVM::PUSH_VALUE, 0);
  this->PushOp(StackVM::PUSH_VALUE, -2);
  this->PushOp(StackVM::LT_I64);
  this->PushOp(StackVM::SELECT);
}

void CodeGenStackVM::VisitExpr_(const EQ *op) {
  PushBinary(StackVM::EQ_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const LE *op) {
  PushBinary(StackVM::LE_I64, op->a, op->b);
}

// a != b  ==  !(a == b)
void CodeGenStackVM::VisitExpr_(const NE *op) {
  PushBinary(StackVM::EQ_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitExpr_(const LT *op) {
  PushBinary(StackVM::LT_I64, op->a, op->b);
}

// a >= b  ==  !(a < b)
void CodeGenStackVM::VisitExpr_(const GE *op) {
  PushBinary(StackVM::LT_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

// a > b  ==  !(a <= b)
void CodeGenStackVM::VisitExpr_(const GT *op) {
  PushBinary(StackVM::LE_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

// Short-circuit AND. The conditional jump leaves `a` on the stack: if `a` is
// false we jump to the end with `a` as the result; otherwise `a` is popped
// and `b` becomes the result.
void CodeGenStackVM::VisitExpr_(const And *op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}

// Short-circuit OR, symmetric to And: if `a` is true we jump to the end with
// `a` as the result; otherwise `a` must be popped before `b` is evaluated so
// exactly one value remains on the stack.
void CodeGenStackVM::VisitExpr_(const Or *op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}

// Logical not: evaluate the operand, then negate it.
void CodeGenStackVM::VisitExpr_(const Not* op) {
  this->Push(op->a);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) {
  this->Push(op->body);
}

// Counted loop for (i = 0; i < extent; ++i) body. Only loops starting at 0 are
// supported. The loop counter lives in a fresh heap slot; the running value is
// carried on the stack into the STORE_HEAP at loop_head.
void CodeGenStackVM::VisitStmt_(const For *op) {
  CHECK(is_zero(op->min));
  int vid = this->AllocVarID(op->loop_var.get());
  this->PushOp(StackVM::PUSH_I64, 0);
  int64_t loop_head = this->GetPC();
  this->PushOp(StackVM::STORE_HEAP, vid);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->Push(op->extent);
  this->PushOp(StackVM::LT_I64);
  int64_t label_fjump = this->GetPC();
  int64_t forward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  // Loop continues: drop the comparison result and run the body.
  this->PushOp(StackVM::POP);
  this->Push(op->body);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->PushOp(StackVM::PUSH_I64, 1);
  this->PushOp(StackVM::ADD_I64);
  int64_t label_bjump = this->GetPC();
  int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0);
  // Loop exit: drop the (false) comparison result left by RJUMP_IF_FALSE.
  int64_t loop_end = this->GetPC();
  this->PushOp(StackVM::POP);
  this->SetOperand(forward_jump, loop_end - label_fjump);
  this->SetOperand(backward_jump, loop_head - label_bjump);
}

void CodeGenStackVM::VisitStmt_(const Block *op) {
  this->Push(op->first);
  if (op->rest.defined()) this->Push(op->rest);
}

// Evaluate an expression for its side effects. Constant expressions emit
// nothing; tvm_struct_set is lowered directly; anything else is evaluated and
// its result popped.
void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
  if (is_const(ev->value)) return;
  const Call* op = ev->value.as<Call>();
  if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
    CHECK_EQ(op->args.size(), 4U);
    const IntImm* index = op->args[1].as<IntImm>();
    const IntImm* kind = op->args[2].as<IntImm>();
    CHECK(index != nullptr);
    CHECK(kind != nullptr);
    this->Push(op->args[0]);
    this->Push(op->args[3]);
    // Opcode followed by two operand slots: field index, then field kind.
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_SET;
    vm_.code.push_back(code);
    code.v_int = static_cast<int>(index->value);
    vm_.code.push_back(code);
    code.v_int = static_cast<int>(kind->value);
    vm_.code.push_back(code);
  } else {
    this->Push(ev->value);
    this->PushOp(StackVM::POP);
  }
}

// if (cond) then_case [else else_case].
// The conditional jump leaves `cond` on the stack, so each path pops it
// exactly once:
//   cond; RJUMP_IF_FALSE else; POP; then_case; RJUMP end;
//   else: POP; [else_case]; end:
// The RJUMP over the else path is emitted even without an else_case so the
// then-path does not pop the condition a second time.
void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
  this->Push(op->condition);
  int64_t label_ejump = this->GetPC();
  int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->then_case);
  int64_t label_then_jump = this->GetPC();
  int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
  int64_t else_begin = this->GetPC();
  this->SetOperand(else_jump, else_begin - label_ejump);
  this->PushOp(StackVM::POP);
  if (op->else_case.defined()) {
    this->Push(op->else_case);
  }
  int64_t if_end = this->GetPC();
  this->SetOperand(then_jump, if_end - label_then_jump);
}

// Bind the value to a fresh heap slot, then emit the body.
void CodeGenStackVM::VisitStmt_(const LetStmt *op) {
  this->Push(op->value);
  int vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, vid);
  this->Push(op->body);
}

void CodeGenStackVM::VisitExpr_(const Ramp *op) {
  LOG(FATAL) << "Ramp is not supported";
}

void CodeGenStackVM::VisitExpr_(const Broadcast *op) {
  LOG(FATAL) << "Broadcast is not supported";
}

// select(cond, t, f): operands pushed as t, f, cond, then SELECT.
void CodeGenStackVM::VisitExpr_(const Select *op) {
  this->Push(op->true_value);
  this->Push(op->false_value);
  this->Push(op->condition);
  this->PushOp(StackVM::SELECT);
}

// Only asserts whose message is a string literal are emitted; others emit
// nothing.
void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
  if (const StringImm* msg = op->message.as<StringImm>()) {
    int sid = this->GetStrID(msg->value);
    this->Push(op->condition);
    this->PushOp(StackVM::ASSERT, sid);
  }
}

// Attributes carry no code of their own; emit only the body.
void CodeGenStackVM::VisitStmt_(const AttrStmt *op) {
  this->Push(op->body);
}

// Expression-level let: bind the value to a fresh heap slot, then emit body.
void CodeGenStackVM::VisitExpr_(const Let *op) {
  this->Push(op->value);
  int vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, vid);
  this->Push(op->body);
}
}  // namespace codegen
}  // namespace tvm