/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_stackvm.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <limits>
#include <utility>
#include "codegen_stackvm.h"
#include "../../runtime/stackvm/stackvm_module.h"

namespace tvm {
namespace codegen {

using namespace tir;

// map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
  auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
  switch (val) {
    case intrinsic::kArrData: return StackVM::kArrData;
    case intrinsic::kArrShape: return StackVM::kArrShape;
    case intrinsic::kArrAddr: return StackVM::kArrAddr;
    case intrinsic::kArrStrides: return StackVM::kArrStrides;
    case intrinsic::kArrNDim: return StackVM::kArrNDim;
    case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
    case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
    case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
    case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
    case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
    case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
    case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
    default: LOG(FATAL) << "Do not know how to map field " << kind;
  }
  return StackVM::kArrData;
}

StackVM CodeGenStackVM::Compile(LoweredFunc f) {
  // Function arguments occupy the first heap slots, in declaration order.
  for (size_t i = 0; i < f->args.size(); ++i) {
    Var v = f->args[i];
    int vid = AllocVarID(v.get());
    CHECK_EQ(static_cast<size_t>(vid), i);
  }
  this->Push(f->body);
  vm_.InitCache();
  return std::move(vm_);
}

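// Emit code for a single statement. In debug mode an ASSERT_SP instruction is
// appended after each statement as a sanity check on the stack pointer.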
void CodeGenStackVM::Push(const Stmt& n) {
  VisitStmt(n);
  if (debug_) {
    this->PushOp(StackVM::ASSERT_SP, 0);
  }
}

void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
}

void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
  CHECK(operand >= std::numeric_limits<int>::min() &&
        operand <= std::numeric_limits<int>::max());
  vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}

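// Emit an instruction together with an immediate operand. The opcode and the
// operand occupy consecutive slots in vm_.code; the index of the operand slot
// is returned so that jump offsets can be backpatched later via SetOperand.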
int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
  int64_t pc = static_cast<int64_t>(vm_.code.size());
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
  code.v_int = operand;
  vm_.code.push_back(code);
  return pc + 1;
}

int CodeGenStackVM::GetStrID(const std::string& key) {
  auto it = str_idmap_.find(key);
  if (it != str_idmap_.end()) return it->second;
  int sid = static_cast<int>(vm_.str_data.size());
  vm_.str_data.push_back(key);
  str_idmap_[key] = sid;
  return sid;
}

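// Allocate a heap slot for a TIR variable. Slots are assigned densely in
// allocation order, so the returned id equals the heap size at allocation time.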
int CodeGenStackVM::AllocVarID(const VarNode* v) {
  CHECK(!var_idmap_.count(v));
  int vid = static_cast<int>(vm_.heap_size);
  CHECK_EQ(vm_.heap_size, var_idmap_.size());
  vm_.heap_id_name.push_back(v->name_hint);
  ++vm_.heap_size;
  var_idmap_[v] = vid;
  return vid;
}

int CodeGenStackVM::GetVarID(const VarNode* v) const {
  auto it = var_idmap_.find(v);
  CHECK(it != var_idmap_.end())
      << "Cannot find variable " << v->name_hint;
  return it->second;
}

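// Loads address the buffer as buffer_var + index * element_bytes. A constant
// index is folded into the load instruction's operand; otherwise the byte
// offset is computed on the stack and combined with the base via ADDR_ADD.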
void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
  this->Push(op->buffer_var);
  StackVM::OpCode code = StackVM::GetLoad(op->dtype);
  if (const IntImmNode* index = op->index.as<IntImmNode>()) {
    this->PushOp(code, index->value);
  } else {
    this->Push(op->index);
    this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->PushOp(code, 0);
  }
}

void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
  this->Push(op->buffer_var);
  StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
  if (const IntImmNode* index = op->index.as<IntImmNode>()) {
    this->Push(op->value);
    this->PushOp(code, index->value);
  } else {
    this->Push(op->index);
    this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->Push(op->value);
    this->PushOp(code, 0);
  }
}

void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
  CHECK(!is_zero(op->condition));
  int vid = AllocVarID(op->buffer_var.get());
  if (op->new_expr.defined()) {
    // Prefer global static allocation for the program
    CHECK_EQ(op->free_function, "nop");
    this->Push(op->new_expr);
    this->PushOp(StackVM::STORE_HEAP, vid);
  } else {
    LOG(FATAL) << "Dynamic allocation not supported";
  }
}

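// Lower the intrinsic calls understood by the StackVM: tvm_address_of,
// reinterpret, tvm_struct_get, tvm_call_packed_lowered, tvm_stack_alloca,
// workspace allocation/free, tvm_throw_last_error and tvm_handle_is_null.
// Any other call is a fatal error.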
void CodeGenStackVM::VisitExpr_(const CallNode* op) {
  if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    const LoadNode* l = op->args[0].as<LoadNode>();
    CHECK(op->args.size() == 1 && l);
    this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
    this->Push(l->index);
    this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
  } else if (op->is_intrinsic(CallNode::reinterpret)) {
    this->Push(op->args[0]);
  } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
    CHECK_EQ(op->args.size(), 3U);
    int kind = op->args[2].as<IntImmNode>()->value;
    this->Push(op->args[0]);
    const IntImmNode* index = op->args[1].as<IntImmNode>();
    CHECK(index != nullptr);
    // TVM_STRUCT_GET is followed by two operand slots:
    // the struct index and the field kind.
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_GET;
    vm_.code.push_back(code);
    code.v_int = index->value;
    vm_.code.push_back(code);
    code.v_int = MapFieldKind(kind);
    vm_.code.push_back(code);
  } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
    CHECK_GE(op->args.size(), 5U);
    const StringImmNode* s = op->args[0].as<StringImmNode>();
    CHECK(s != nullptr) << "tvm_call_packed_lowered expects the first argument to be the function name";
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    int begin = op->args[3].as<IntImmNode>()->value;
    int end = op->args[4].as<IntImmNode>()->value;
    // find the function id.
    const std::string& func_name = s->value;
    auto it = extern_fun_idmap_.find(func_name);
    int fid;
    if (it != extern_fun_idmap_.end()) {
      fid = it->second;
    } else {
      fid = static_cast<int>(vm_.extern_func_name.size());
      vm_.extern_func_name.push_back(func_name);
      extern_fun_idmap_[func_name] = fid;
    }
    // CALL_PACKED_LOWERED: operands are the extern function id and the
    // begin/end indices of the packed argument region.
    StackVM::Code code;
    code.op_code = StackVM::CALL_PACKED_LOWERED;
    vm_.code.push_back(code);
    code.v_int = fid;
    vm_.code.push_back(code);
    code.v_int = begin;
    vm_.code.push_back(code);
    code.v_int = end;
    vm_.code.push_back(code);
  } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
    CHECK_EQ(op->args.size(), 2U);
    const std::string& type = op->args[0].as<StringImmNode>()->value;
    const IntImmNode* num = op->args[1].as<IntImmNode>();
    CHECK(num != nullptr);
    static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
    // static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
    // Sizes are rounded up to whole TVMValue-sized units.
    size_t unit = sizeof(TVMValue);
    size_t size = 0;
    if (type == "shape") {
      size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
    } else if (type == "arg_value") {
      size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
    } else if (type == "arg_tcode") {
      size = (num->value * sizeof(int) + unit - 1) / unit;
    } else if (type == "array") {
      size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
    } else {
      LOG(FATAL) << "Unknown stack alloca type " << type;
    }
    // Add to the reserved stack size to be safe.
    vm_.stack_size += size;
    this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
  } else if (op->name == "TVMBackendAllocWorkspace") {
    CHECK_EQ(op->args.size(), 5U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    this->Push(op->args[3]);
    this->Push(op->args[4]);
    this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
  } else if (op->name == "TVMBackendFreeWorkspace") {
    CHECK_EQ(op->args.size(), 3U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    this->PushOp(StackVM::TVM_DEVICE_FREE);
  } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
    this->PushOp(StackVM::TVM_THROW_LAST_ERROR);
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    CHECK_EQ(op->args.size(), 1U);
    this->Push(op->args[0]);
    this->PushOp(StackVM::PUSH_I64, 0);
    this->PushOp(StackVM::EQ_HANDLE);
  } else {
    LOG(FATAL) << "unknown function call " << op->name;
  }
}

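// Push both operands, then the opcode. Signed and unsigned integers share the
// 64-bit integer opcode; other types use the float64 variant obtained through
// CodeI64ToF64.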
void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
                                const PrimExpr& a,
                                const PrimExpr& b) {
  this->Push(a);
  this->Push(b);
  DataType t = a.dtype();
  if (t.is_int()) {
    this->PushOp(op_int64);
  } else if (t.is_uint()) {
    this->PushOp(op_int64);
  } else {
    this->PushOp(StackVM::CodeI64ToF64(op_int64));
  }
}

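// Casts between integer types, as well as float-to-float casts, require no
// opcode on the StackVM value stack, so nothing is emitted for them.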
void CodeGenStackVM::PushCast(DataType dst, DataType src) {
  if (dst.is_int()) {
    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_uint()) {
    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_float()) {
    if (src.is_float()) return;
  }
}

void CodeGenStackVM::VisitExpr_(const StringImmNode* op) {
  int sid = this->GetStrID(op->value);
  this->PushOp(StackVM::PUSH_I64, sid);
}

void CodeGenStackVM::VisitExpr_(const IntImmNode* op) {
  CHECK(op->value >= std::numeric_limits<int>::min() &&
        op->value <= std::numeric_limits<int>::max())
      << "Int constant exceeds bound";
  this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}

void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
  LOG(FATAL) << "Float Imm is not supported";
}

void CodeGenStackVM::VisitExpr_(const VarNode* op) {
  int vid = this->GetVarID(op);
  this->PushOp(StackVM::LOAD_HEAP, vid);
}

void CodeGenStackVM::VisitExpr_(const CastNode* op) {
  this->Push(op->value);
  PushCast(op->dtype, op->value.dtype());
}

void CodeGenStackVM::VisitExpr_(const AddNode* op) {
  PushBinary(StackVM::ADD_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const SubNode* op) {
  PushBinary(StackVM::SUB_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const MulNode* op) {
  PushBinary(StackVM::MUL_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const DivNode* op) {
  PushBinary(StackVM::DIV_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const ModNode* op) {
  PushBinary(StackVM::MOD_I64, op->a, op->b);
}

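// min(a, b): duplicate both operands with PUSH_VALUE, compare them with
// LT_I64, and let SELECT keep the smaller value.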
void CodeGenStackVM::VisitExpr_(const MinNode* op) {
  this->Push(op->a);
  this->Push(op->b);
  this->PushOp(StackVM::PUSH_VALUE, -1);
  this->PushOp(StackVM::PUSH_VALUE, -1);
  this->PushOp(StackVM::LT_I64);
  this->PushOp(StackVM::SELECT);
}

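// max(a, b): same idea as min, but the operands are duplicated in the order
// that makes SELECT keep the larger value.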
void CodeGenStackVM::VisitExpr_(const MaxNode* op) {
  this->Push(op->a);
  this->Push(op->b);
  this->PushOp(StackVM::PUSH_VALUE, 0);
  this->PushOp(StackVM::PUSH_VALUE, -2);
  this->PushOp(StackVM::LT_I64);
  this->PushOp(StackVM::SELECT);
}

void CodeGenStackVM::VisitExpr_(const EQNode* op) {
  PushBinary(StackVM::EQ_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const LENode* op) {
  PushBinary(StackVM::LE_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const NENode* op) {
  // a != b  ==>  !(a == b)
  PushBinary(StackVM::EQ_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitExpr_(const LTNode* op) {
  PushBinary(StackVM::LT_I64, op->a, op->b);
}

void CodeGenStackVM::VisitExpr_(const GENode* op) {
  // a >= b  ==>  !(a < b)
  PushBinary(StackVM::LT_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitExpr_(const GTNode* op) {
  // a > b  ==>  !(a <= b)
  PushBinary(StackVM::LE_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

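// a && b with short-circuiting: if a is false, jump over the code for b and
// keep the false value; otherwise pop a and evaluate b. The jump offset is
// backpatched once the length of the code for b is known.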
void CodeGenStackVM::VisitExpr_(const AndNode* op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}

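// a || b with short-circuiting: if a is true, jump over the code for b; the
// jump offset is backpatched after b has been emitted.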
void CodeGenStackVM::VisitExpr_(const OrNode* op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}

void CodeGenStackVM::VisitExpr_(const NotNode* op) {
  this->Push(op->a);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitStmt_(const ProducerConsumerNode* op) {
  this->Push(op->body);
}

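// Lower a for loop (min must be zero) into the following skeleton:
//   PUSH_I64 0
//   loop_head: STORE_HEAP vid; LOAD_HEAP vid; <extent>; LT_I64
//              RJUMP_IF_FALSE loop_end; POP; <body>
//              LOAD_HEAP vid; PUSH_I64 1; ADD_I64; RJUMP loop_head
//   loop_end:  POP
// Both jump offsets are backpatched via SetOperand.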
void CodeGenStackVM::VisitStmt_(const ForNode* op) {
  CHECK(is_zero(op->min));
  int vid = this->AllocVarID(op->loop_var.get());
  this->PushOp(StackVM::PUSH_I64, 0);
  int64_t loop_head = this->GetPC();
  this->PushOp(StackVM::STORE_HEAP, vid);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->Push(op->extent);
  this->PushOp(StackVM::LT_I64);
  int64_t label_fjump = this->GetPC();
  int64_t forward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->body);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->PushOp(StackVM::PUSH_I64, 1);
  this->PushOp(StackVM::ADD_I64);
  int64_t label_bjump = this->GetPC();
  int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0);
  int64_t loop_end = this->GetPC();
  this->PushOp(StackVM::POP);
  this->SetOperand(forward_jump, loop_end - label_fjump);
  this->SetOperand(backward_jump, loop_head - label_bjump);
}

void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    this->Push(stmt);
  }
}

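// An Evaluate of the tvm_struct_set intrinsic lowers to TVM_STRUCT_SET; any
// other non-constant expression is evaluated and its result popped.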
void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
  if (is_const(ev->value)) return;
  const CallNode* op = ev->value.as<CallNode>();
  if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
    CHECK_EQ(op->args.size(), 4U);
    this->Push(op->args[0]);
    this->Push(op->args[3]);
    const IntImmNode* index = op->args[1].as<IntImmNode>();
    CHECK(index != nullptr);
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_SET;
    vm_.code.push_back(code);
    code.v_int = index->value;
    vm_.code.push_back(code);
    code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
    vm_.code.push_back(code);
  } else {
    this->Push(ev->value);
    this->PushOp(StackVM::POP);
  }
}

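// if/else lowering: a forward RJUMP_IF_FALSE skips the then-branch, and, when
// an else-branch exists, an unconditional RJUMP at the end of the then-branch
// skips over it. Both offsets are backpatched after the branches are emitted.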
void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
  this->Push(op->condition);
  int64_t label_ejump = this->GetPC();
  int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->then_case);
  if (op->else_case.defined()) {
    int64_t label_then_jump = this->GetPC();
    int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
    int64_t else_begin = this->GetPC();
    this->SetOperand(else_jump, else_begin - label_ejump);
    this->PushOp(StackVM::POP);
    this->Push(op->else_case);
    int64_t if_end = this->GetPC();
    this->SetOperand(then_jump, if_end - label_then_jump);
  } else {
    int64_t if_end = this->GetPC();
    this->SetOperand(else_jump, if_end - label_ejump);
    this->PushOp(StackVM::POP);
  }
}

void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) {
  this->Push(op->value);
  int64_t vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
  this->Push(op->body);
}

void CodeGenStackVM::VisitExpr_(const RampNode* op) {
  LOG(FATAL) << "Ramp is not supported";
}

void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) {
  LOG(FATAL) << "Broadcast is not supported";
}

void CodeGenStackVM::VisitExpr_(const SelectNode* op) {
  this->Push(op->true_value);
  this->Push(op->false_value);
  this->Push(op->condition);
  this->PushOp(StackVM::SELECT);
}

void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) {
  if (const auto* str = op->message.as<StringImmNode>()) {
    int sid = this->GetStrID(str->value);
    this->Push(op->condition);
    this->PushOp(StackVM::ASSERT, sid);
  }
  this->Push(op->body);
}

void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) {
  this->Push(op->body);
}

void CodeGenStackVM::VisitExpr_(const LetNode* op) {
  this->Push(op->value);
  int64_t vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
  this->Push(op->body);
}

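// Compile every lowered function to a StackVM and pack them into a single
// module; the name of the first function is passed as the module entry.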
runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) {
  CHECK_NE(funcs.size(), 0U);
  std::unordered_map<std::string, StackVM> fmap;
  for (LoweredFunc f : funcs) {
    StackVM vm = codegen::CodeGenStackVM().Compile(f);
    CHECK(!fmap.count(f->name))
        << "Function name " << f->name << " already exists in list";
    fmap[f->name] = std::move(vm);
  }
  return runtime::StackVMModuleCreate(fmap, funcs[0]->name);
}

TVM_REGISTER_GLOBAL("codegen.build_stackvm")
.set_body_typed(BuildStackVM);
}  // namespace codegen
}  // namespace tvm