vm.cc 34.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/runtime/vm/vm.cc
 * \brief The Relay virtual machine.
 */

25
#include <dmlc/memory_io.h>
26
#include <tvm/support/logging.h>
27
#include <tvm/runtime/container.h>
28
#include <tvm/runtime/vm.h>
29 30
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
31

32
#include <algorithm>
33 34 35 36 37 38
#include <chrono>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

39 40
#include "memory_manager.h"
#include "naive_allocator.h"
41 42 43 44 45 46 47

using namespace tvm::runtime;

namespace tvm {
namespace runtime {
namespace vm {

48 49 50 51 52 53
/*!
 * \brief Create a closure referring to VM function `func_index`.
 * \param func_index Index of the VM function the closure invokes.
 * \param free_vars Captured values; taken by value and moved into the object.
 */
VMClosure::VMClosure(size_t func_index, std::vector<ObjectRef> free_vars) {
  auto closure = make_object<VMClosureObj>();
  closure->free_vars = std::move(free_vars);
  closure->func_index = func_index;
  // Transfer ownership of the freshly built object to this reference.
  data_ = std::move(closure);
}
54

55
/*!
 * \brief Allocate a Storage object whose buffer comes from the allocator registered for `ctx`.
 * \param size Number of bytes requested from the allocator.
 * \param alignment Alignment forwarded to the allocator.
 * \param dtype_hint Data type hint forwarded to the allocator.
 * \param ctx Context whose allocator is looked up via MemoryManager::Global().
 * \return A Storage wrapping the new StorageObj.
 */
inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) {
  // We could put cache in here, from ctx to storage allocator.
  auto allocator = MemoryManager::Global()->GetAllocator(ctx);
  DCHECK(allocator != nullptr)
    << "allocator must not null";
  auto storage = SimpleObjAllocator().make_object<StorageObj>();
  storage->buffer = allocator->Alloc(size, alignment, dtype_hint);
  return Storage(storage);
}

65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
/*!
 * \brief Construct an instruction with no operands.
 *
 * The opcode starts as Opcode::Fatal, which owns no heap arrays. A
 * default-constructed instruction can therefore be safely copied, assigned
 * over, or destroyed even if no real opcode is ever stored into it. Previously
 * `op` was left indeterminate while the destructor and copy constructor switch
 * on it. The factory functions overwrite `op` (and `dst`) as needed.
 */
Instruction::Instruction() : op(Opcode::Fatal), dst() {}

/*!
 * \brief Deep-copy `size` elements starting at `src` into a new heap array.
 * \return A pointer allocated with new[]; the caller owns it and must release it with delete[].
 */
template <typename T>
static T* Duplicate(T* src, Index size) {
  T* copy = new T[size];
  std::copy_n(src, size, copy);
  return copy;
}

/*!
 * \brief Copy-construct an instruction.
 *
 * Only the union member that is active for `instr.op` is copied. Array
 * operands are deep-copied with Duplicate, so the new instruction owns its own
 * buffer.
 * \throws std::runtime_error if `instr.op` is not a known opcode.
 */
Instruction::Instruction(const Instruction& instr) {
  // Every opcode carries an opcode tag and a destination register.
  this->op = instr.op;
  this->dst = instr.dst;

  switch (instr.op) {
    case Opcode::Fatal:
      break;
    case Opcode::Move:
      this->from = instr.from;
      break;
    case Opcode::Ret:
      this->result = instr.result;
      break;
    case Opcode::Goto:
      this->pc_offset = instr.pc_offset;
      break;
    case Opcode::If:
      this->if_op = instr.if_op;
      break;
    case Opcode::LoadConst:
      this->const_index = instr.const_index;
      break;
    case Opcode::LoadConsti:
      this->load_consti = instr.load_consti;
      break;
    case Opcode::GetField:
      this->object = instr.object;
      this->field_index = instr.field_index;
      break;
    case Opcode::GetTag:
      this->get_tag = instr.get_tag;
      break;
    case Opcode::AllocStorage:
      this->alloc_storage = instr.alloc_storage;
      break;
    case Opcode::AllocTensorReg:
      this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
      this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
      this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
      break;
    // The remaining opcodes own an array operand that must be deep-copied.
    case Opcode::AllocTensor:
      this->alloc_tensor.storage = instr.alloc_tensor.storage;
      this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
      this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
      this->alloc_tensor.shape =
          Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
      break;
    case Opcode::AllocADT:
      this->constructor_tag = instr.constructor_tag;
      this->num_fields = instr.num_fields;
      this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
      break;
    case Opcode::AllocClosure:
      this->clo_index = instr.clo_index;
      this->num_freevar = instr.num_freevar;
      this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
      break;
    case Opcode::InvokePacked:
      this->packed_index = instr.packed_index;
      this->arity = instr.arity;
      this->output_size = instr.output_size;
      this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
      break;
    case Opcode::InvokeClosure:
      this->closure = instr.closure;
      this->num_closure_args = instr.num_closure_args;
      this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
      break;
    case Opcode::Invoke:
      this->func_index = instr.func_index;
      this->num_args = instr.num_args;
      this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
      break;
    default: {
      std::ostringstream message;
      message << "Invalid instruction " << static_cast<int>(instr.op);
      throw std::runtime_error(message.str());
    }
  }
}

154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
/*!
 * \brief Release an array that was allocated with new[] (for example by Duplicate).
 * \param t Pointer to the first element, or nullptr.
 *
 * The array form of delete is required here. Releasing a new[] allocation with
 * scalar `delete` is undefined behaviour. `delete[]` on nullptr is a no-op, so
 * no explicit null check is needed.
 */
template<typename T>
static inline void FreeIf(T* t) {
  delete[] t;
}

Instruction& Instruction::operator=(const Instruction& instr) {
  this->op = instr.op;
  this->dst = instr.dst;

  switch (instr.op) {
    case Opcode::Move:
      this->from = instr.from;
      return *this;
169 170 171 172
    case Opcode::Fatal:
      return *this;
    case Opcode::LoadConsti:
      this->load_consti = instr.load_consti;
173 174 175 176 177
      return *this;
    case Opcode::Ret:
      this->result = instr.result;
      return *this;
    case Opcode::AllocTensor:
178
      this->alloc_tensor.storage = this->alloc_tensor.storage;
179 180 181 182 183 184
      this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
      this->alloc_tensor.shape = Duplicate<int64_t>(instr.alloc_tensor.shape,
                                                    instr.alloc_tensor.ndim);
      this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
      return *this;
    case Opcode::AllocTensorReg:
185
      this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
186 187
      this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
      this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
188
      return *this;
189
    case Opcode::AllocADT:
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
      this->constructor_tag = instr.constructor_tag;
      this->num_fields = instr.num_fields;
      FreeIf(this->datatype_fields);
      this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
      return *this;
    case Opcode::AllocClosure:
      this->clo_index = instr.clo_index;
      this->num_freevar = instr.num_freevar;
      FreeIf(this->free_vars);
      this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
      return *this;
    case Opcode::InvokePacked:
      this->packed_index = instr.packed_index;
      this->arity = instr.arity;
      this->output_size = instr.output_size;
      FreeIf(this->packed_args);
      this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
      return *this;
    case Opcode::InvokeClosure:
      this->closure = instr.closure;
210
      this->num_closure_args = instr.num_closure_args;
211
      FreeIf(this->closure_args);
212
      this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
213 214 215 216 217 218 219 220
      return *this;
    case Opcode::Invoke:
      this->func_index = instr.func_index;
      this->num_args = instr.num_args;
      FreeIf(this->invoke_args_registers);
      this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
      return *this;
    case Opcode::If:
221
      this->if_op = instr.if_op;
222 223 224 225 226 227 228 229
      return *this;
    case Opcode::LoadConst:
      this->const_index = instr.const_index;
      return *this;
    case Opcode::GetField:
      this->object = instr.object;
      this->field_index = instr.field_index;
      return *this;
230 231 232
    case Opcode::GetTag:
      this->get_tag = instr.get_tag;
      return *this;
233 234 235
    case Opcode::Goto:
      this->pc_offset = instr.pc_offset;
      return *this;
236 237 238
    case Opcode::AllocStorage:
      this->alloc_storage = instr.alloc_storage;
      return *this;
239 240 241 242 243 244 245
    default:
      std::ostringstream out;
      out << "Invalid instruction " << static_cast<int>(instr.op);
      throw std::runtime_error(out.str());
  }
}

246 247 248 249
/*!
 * \brief Destroy an instruction, freeing the array operand owned by its opcode.
 *
 * Array operands are allocated with new[] (by the factory functions and
 * Duplicate), so they must be released with delete[]. Scalar `delete` is
 * undefined behaviour for them. An unknown opcode is reported with LOG(FATAL).
 */
Instruction::~Instruction() {
  switch (this->op) {
    case Opcode::Move:
    case Opcode::Ret:
    case Opcode::AllocTensorReg:
    case Opcode::If:
    case Opcode::LoadConst:
    case Opcode::GetField:
    case Opcode::GetTag:
    case Opcode::Goto:
    case Opcode::LoadConsti:
    case Opcode::AllocStorage:
    case Opcode::Fatal:
      // These opcodes store only plain values; nothing to free.
      return;
    case Opcode::AllocTensor:
      delete[] this->alloc_tensor.shape;
      return;
    case Opcode::AllocADT:
      delete[] this->datatype_fields;
      return;
    case Opcode::AllocClosure:
      delete[] this->free_vars;
      return;
    case Opcode::InvokePacked:
      delete[] this->packed_args;
      return;
    case Opcode::InvokeClosure:
      delete[] this->closure_args;
      return;
    case Opcode::Invoke:
      delete[] this->invoke_args_registers;
      return;
    default:
      LOG(FATAL) << "Invalid instruction " << static_cast<int>(this->op);
  }
}

/*!
 * \brief Build a `ret` instruction that returns the value held in register `result`.
 */
Instruction Instruction::Ret(RegName result) {
  Instruction ret_instr;
  ret_instr.op = Opcode::Ret;
  ret_instr.result = result;
  return ret_instr;
}

291 292 293 294 295 296
/*!
 * \brief Build a `fatal` instruction; it carries no operands.
 */
Instruction Instruction::Fatal() {
  Instruction fatal_instr;
  fatal_instr.op = Opcode::Fatal;
  return fatal_instr;
}

297 298 299
/*!
 * \brief Build an `invoke_packed` instruction.
 * \param packed_index Index of the packed function to call.
 * \param arity Number of argument registers; the first `arity` entries of `args` are copied.
 * \param output_size Number of trailing registers in `args` treated as outputs.
 * \param args Argument registers; must hold at least `arity` entries.
 */
Instruction Instruction::InvokePacked(Index packed_index,
                                      Index arity,
                                      Index output_size,
                                      const std::vector<RegName>& args) {
  Instruction packed;
  packed.op = Opcode::InvokePacked;
  packed.packed_index = packed_index;
  packed.arity = arity;
  packed.output_size = output_size;
  // The instruction owns its own copy of the register list (freed in ~Instruction).
  packed.packed_args = new RegName[arity];
  std::copy(args.begin(), args.begin() + arity, packed.packed_args);
  return packed;
}

313 314 315 316
/*!
 * \brief Build an `alloc_tensor` instruction with a statically known shape.
 * \param storage Register holding the backing storage.
 * \param shape Tensor shape; copied into an array owned by the instruction.
 * \param dtype Element type of the tensor.
 * \param dst Destination register.
 */
Instruction Instruction::AllocTensor(
  RegName storage,
  const std::vector<int64_t>& shape,
  DLDataType dtype, Index dst) {
  Instruction alloc;
  alloc.op = Opcode::AllocTensor;
  alloc.dst = dst;
  alloc.alloc_tensor.storage = storage;
  alloc.alloc_tensor.dtype = dtype;
  alloc.alloc_tensor.ndim = shape.size();
  alloc.alloc_tensor.shape = new int64_t[shape.size()];
  std::copy(shape.begin(), shape.end(), alloc.alloc_tensor.shape);
  return alloc;
}

330 331 332 333
/*!
 * \brief Build an `alloc_tensor_reg` instruction whose shape is read from a register.
 * \param storage Register holding the backing storage.
 * \param shape_register Register holding the shape tensor.
 * \param dtype Element type of the tensor.
 * \param dst Destination register.
 */
Instruction Instruction::AllocTensorReg(
  RegName storage,
  RegName shape_register,
  DLDataType dtype, Index dst) {
  Instruction alloc;
  alloc.op = Opcode::AllocTensorReg;
  alloc.dst = dst;
  auto& operands = alloc.alloc_tensor_reg;
  operands.storage = storage;
  operands.shape_register = shape_register;
  operands.dtype = dtype;
  return alloc;
}

343 344
/*!
 * \brief Build an `alloc_storage` instruction.
 * \param size Register holding the allocation size.
 * \param alignment Requested alignment.
 * \param dtype_hint Data type hint for the allocator.
 * \param dst Destination register.
 */
Instruction Instruction::AllocStorage(RegName size,
                                      Index alignment,
                                      DLDataType dtype_hint,
                                      Index dst) {
  Instruction alloc;
  alloc.op = Opcode::AllocStorage;
  alloc.dst = dst;
  auto& operands = alloc.alloc_storage;
  operands.allocation_size = size;
  operands.alignment = alignment;
  operands.dtype_hint = dtype_hint;
  return alloc;
}

356
/*!
 * \brief Build an `alloc_data` (ADT) instruction.
 * \param tag Constructor tag.
 * \param num_fields Number of field registers copied from `datatype_fields`.
 * \param datatype_fields Field registers; must hold at least `num_fields` entries.
 * \param dst Destination register.
 */
Instruction Instruction::AllocADT(Index tag, Index num_fields,
                                  const std::vector<RegName>& datatype_fields, Index dst) {
  Instruction adt;
  adt.op = Opcode::AllocADT;
  adt.dst = dst;
  adt.constructor_tag = tag;
  adt.num_fields = num_fields;
  adt.datatype_fields = new RegName[num_fields];
  std::copy(datatype_fields.begin(), datatype_fields.begin() + num_fields,
            adt.datatype_fields);
  return adt;
}

/*!
 * \brief Build an `alloc_closure` instruction.
 * \param func_index Index of the VM function the closure calls.
 * \param free_vars Number of captured registers copied from `free_var_register`.
 * \param free_var_register Captured registers; must hold at least `free_vars` entries.
 * \param dst Destination register.
 */
Instruction Instruction::AllocClosure(Index func_index, Index free_vars,
                                      const std::vector<RegName>& free_var_register, Index dst) {
  Instruction clo;
  clo.op = Opcode::AllocClosure;
  clo.dst = dst;
  clo.clo_index = func_index;
  clo.num_freevar = free_vars;
  clo.free_vars = new RegName[clo.num_freevar];
  std::copy(free_var_register.begin(), free_var_register.begin() + clo.num_freevar,
            clo.free_vars);
  return clo;
}

/*!
 * \brief Build a `get_field` instruction reading field `field_index` of the ADT in `object`.
 */
Instruction Instruction::GetField(RegName object, Index field_index, RegName dst) {
  Instruction get;
  get.op = Opcode::GetField;
  get.object = object;
  get.field_index = field_index;
  get.dst = dst;
  return get;
}

393
/*!
 * \brief Build a `get_tag` instruction reading the constructor tag of the ADT in `object`.
 */
Instruction Instruction::GetTag(RegName object, RegName dst) {
  Instruction get;
  get.op = Opcode::GetTag;
  get.get_tag.object = object;
  get.dst = dst;
  return get;
}

401
/*!
 * \brief Build an `if` instruction comparing registers `test` and `target`.
 *
 * The two branch offsets are stored unchanged. `dst` is not set for this opcode.
 */
Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) {
  Instruction branch;
  branch.op = Opcode::If;
  auto& cond = branch.if_op;
  cond.test = test;
  cond.target = target;
  cond.true_offset = true_branch;
  cond.false_offset = false_branch;
  return branch;
}

/*!
 * \brief Build a `goto` instruction carrying the jump offset `pc_offset`.
 */
Instruction Instruction::Goto(Index pc_offset) {
  Instruction jump;
  jump.op = Opcode::Goto;
  jump.pc_offset = pc_offset;
  return jump;
}

/*!
 * \brief Build an `invoke` instruction calling VM function `func_index`.
 * \param args_registers Argument registers; all are copied into an owned array.
 * \param dst Destination register.
 */
Instruction Instruction::Invoke(Index func_index, const std::vector<RegName>& args_registers,
                                RegName dst) {
  Instruction call;
  call.op = Opcode::Invoke;
  call.dst = dst;
  call.func_index = func_index;
  call.num_args = args_registers.size();
  call.invoke_args_registers = new RegName[call.num_args];
  std::copy(args_registers.begin(), args_registers.end(), call.invoke_args_registers);
  return call;
}

/*!
 * \brief Build an `invoke_closure` instruction calling the closure held in register `closure`.
 * \param args Argument registers; all are copied into an owned array.
 * \param dst Destination register.
 */
Instruction Instruction::InvokeClosure(RegName closure, const std::vector<RegName>& args,
                                       RegName dst) {
  Instruction call;
  call.op = Opcode::InvokeClosure;
  call.dst = dst;
  call.closure = closure;
  call.num_closure_args = args.size();
  call.closure_args = new RegName[args.size()];
  std::copy(args.begin(), args.end(), call.closure_args);
  return call;
}

/*!
 * \brief Build a `load_const` instruction loading constant-pool entry `const_index` into `dst`.
 */
Instruction Instruction::LoadConst(Index const_index, RegName dst) {
  Instruction load;
  load.op = Opcode::LoadConst;
  load.const_index = const_index;
  load.dst = dst;
  return load;
}

454
/*!
 * \brief Build a `load_consti` instruction loading the immediate integer `val` into `dst`.
 */
Instruction Instruction::LoadConsti(Index val, RegName dst) {
  Instruction load;
  load.op = Opcode::LoadConsti;
  load.load_consti.val = val;
  load.dst = dst;
  return load;
}

462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482
/*!
 * \brief Build a `move` instruction copying register `src` into register `dst`.
 */
Instruction Instruction::Move(RegName src, RegName dst) {
  Instruction mov;
  mov.op = Opcode::Move;
  mov.from = src;
  mov.dst = dst;
  return mov;
}

void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
  switch (dtype.code) {
    case kDLInt:
      os << "int";
      break;
    case kDLUInt:
      os << "uint";
      break;
    case kDLFloat:
      os << "float";
      break;
  }

483 484 485
  os << int(dtype.bits);
  if (dtype.lanes != 1) {
    os << "x" << dtype.lanes;
486 487 488
  }
}

489
/*!
 * \brief Join `cnt` elements of `items`, starting at index `offset`, with `delim`.
 * \return The joined string, or "" when `cnt` is 0. A negative `cnt` yields
 *         just `items[offset]`.
 */
template<typename T>
std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
  std::ostringstream joined;
  if (cnt != 0) {
    joined << items[offset];
    for (int k = offset + 1; k < offset + cnt; ++k) {
      joined << delim << items[k];
    }
  }
  return joined.str();
}

502 503 504
void InstructionPrint(std::ostream& os, const Instruction& instr) {
  switch (instr.op) {
    case Opcode::Move: {
505
      os << "move $" << instr.dst << " $" << instr.from;
506 507 508
      break;
    }
    case Opcode::Ret: {
509
      os << "ret $" << instr.result;
510 511
      break;
    }
512 513 514 515
    case Opcode::Fatal: {
      os << "fatal";
      break;
    }
516
    case Opcode::InvokePacked: {
517 518 519
      os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $"
         << StrJoin<RegName>(instr.packed_args, 0,
                             instr.arity - instr.output_size, ", $")
520 521
         << ", out: $"
         << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
522
                             instr.output_size, ", $")
523
         << ")";
524 525 526
      break;
    }
    case Opcode::AllocTensor: {
527 528
      os << "alloc_tensor $" << instr.dst << " $"
         << instr.alloc_tensor.storage << " ["
529 530
         << StrJoin<int64_t>(instr.alloc_tensor.shape, 0,
                             instr.alloc_tensor.ndim)
531 532 533 534 535 536
         << "] ";
      DLDatatypePrint(os, instr.alloc_tensor.dtype);
      break;
    }
    case Opcode::AllocTensorReg: {
      os << "alloc_tensor_reg $" << instr.dst << " $"
537
         << instr.alloc_tensor_reg.storage << " $"
538 539
         << instr.alloc_tensor_reg.shape_register << " ";
      DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
540 541
      break;
    }
542
    case Opcode::AllocADT: {
543
      os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
544
         << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
545 546 547
      break;
    }
    case Opcode::AllocClosure: {
548 549
      os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
         << "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
550
         << ")";
551 552 553
      break;
    }
    case Opcode::If: {
554
      os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " "
555
         << instr.if_op.true_offset << " " << instr.if_op.false_offset;
556 557 558
      break;
    }
    case Opcode::Invoke: {
559 560
      os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
         << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
561
         << ")";
562 563 564
      break;
    }
    case Opcode::InvokeClosure: {
565
      os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
566
         << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
567
         << ")";
568 569 570
      break;
    }
    case Opcode::LoadConst: {
571
      os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
572 573
      break;
    }
574
    case Opcode::LoadConsti: {
575
      os << "load_consti $" << instr.dst << " " << instr.load_consti.val;
576 577
      break;
    }
578
    case Opcode::GetField: {
579
      os << "get_field $" << instr.dst << " $" << instr.object << "["
580
         << instr.field_index << "]";
581 582
      break;
    }
583
    case Opcode::GetTag: {
584
      os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
585 586
      break;
    }
587
    case Opcode::Goto: {
588
      os << "goto " << instr.pc_offset;
589 590
      break;
    }
591
    case Opcode::AllocStorage: {
592 593 594
      os << "alloc_storage $" <<
        instr.dst << " $" <<
        instr.alloc_storage.allocation_size << " $" <<
595
        instr.alloc_storage.alignment << " " <<
596
        DLDataType2String(instr.alloc_storage.dtype_hint);
597 598
      break;
    }
599 600 601 602 603 604 605 606 607 608 609 610 611 612
    default:
      LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
      break;
  }
}

// Stream an instruction using its InstructionPrint text form; returns `os` for chaining.
std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
  InstructionPrint(os, instr);
  return os;
}

/*!
 * \brief Print a VM function as its name header followed by one numbered line per instruction.
 */
void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) {
  os << vm_func.name << ": " << std::endl;
  size_t index = 0;
  for (const auto& instruction : vm_func.instructions) {
    os << index++ << ": " << instruction << ";" << std::endl;
  }
}

// Stream a VM function using VMFunctionPrint; returns `os` for chaining.
std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
  VMFunctionPrint(os, vm_func);
  return os;
}

622 623 624 625 626
/*!
 * \brief Ensure an NDArray lives on `ctx`, copying it there if necessary.
 * \param src Any object; only NDArrays are ever copied.
 * \param ctx Target device context.
 * \return A copy on `ctx` if `src` is an NDArray on a different device;
 *         otherwise `src` itself. An undefined `src` is returned unchanged
 *         instead of being dereferenced.
 *
 * The device id is compared as well as the device type. An array on, e.g.,
 * gpu(1) must be copied before it is used on gpu(0).
 */
inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
  if (src.defined() && src->IsInstance<NDArray::ContainerType>()) {
    auto nd_array = Downcast<NDArray>(src);
    if (nd_array->ctx.device_type != ctx.device_type ||
        nd_array->ctx.device_id != ctx.device_id) {
      return nd_array.CopyTo(ctx);
    }
  }
  return src;
}

632
/*!
 * \brief Expose the VM's entry points as packed functions.
 *
 * - "invoke": `(func_name)` runs a global function. It uses inputs previously
 *   stored by "set_input" when the function has parameters.
 * - "init": `(device_type0, device_id0, ...)` sets the execution contexts.
 * - "set_input": `(func_name, args...)` stores the arguments for `func_name`.
 *   Each argument is copied to the parameter context first.
 * Any other name is reported with LOG(FATAL).
 */
PackedFunc VirtualMachine::GetFunction(const std::string& name,
                                       const ObjectPtr<Object>& sptr_to_self) {
  if (name == "invoke") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK(exec_) << "The executable is not created yet.";
      std::string func_name = args[0];
      auto git = exec_->global_map.find(func_name);
      CHECK(git != exec_->global_map.end())
        << "Cannot find function " << func_name << " in the executable";
      // Bind by reference: copying the VMFunction would duplicate its whole
      // instruction list on every call.
      const auto& func = exec_->functions[git->second];
      if (func.params.empty()) {
        *rv = Invoke(func, {});
      } else {
        auto it = inputs_.find(func_name);
        CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name;
        const std::vector<ObjectRef>& func_args = it->second;
        *rv = Invoke(func, func_args);
      }
    });
  } else if (name == "init") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK_EQ(args.size() % 2, 0);
      std::vector<TVMContext> contexts;
      contexts.reserve(args.size() / 2);
      for (int i = 0; i < args.size() / 2; ++i) {
        TVMContext ctx;
        int device_type = args[i * 2];
        ctx.device_type = DLDeviceType(device_type);
        ctx.device_id = args[i * 2 + 1];
        contexts.push_back(ctx);
      }
      this->Init(contexts);
    });
  } else if (name == "set_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK(exec_) << "The executable is not created yet.";
      std::string func_name = args[0];
      auto gvit = exec_->global_map.find(func_name);
      CHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
      auto func_index = gvit->second;
      const auto& vm_func = exec_->functions[func_index];
      const auto& param_names = vm_func.params;
      // TODO(icemelon9): For heterogeneous execution, get input device information
      // GetParamsContext() yields ctxs_[0] but first CHECKs that "init" ran,
      // instead of indexing an empty vector.
      TVMContext ctx = GetParamsContext();
      CHECK_EQ(args.size() - 1, param_names.size()) <<
          "The number of provided parameters doesn't match the number of arguments";
      std::vector<ObjectRef> func_args(param_names.size());
      for (int i = 1; i < args.size(); ++i) {
        func_args[i - 1] = CopyTo(args[i], ctx);
      }
      inputs_.erase(func_name);
      inputs_.emplace(func_name, std::move(func_args));
    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
  }
}

691
/*!
 * \brief Context on which parameters are placed.
 *
 * CHECKs that at least one context has been initialized. It then returns the
 * first context whose device type matches that of `ctxs_[0]`, falling back to
 * `ctxs_[0]` itself.
 */
TVMContext VirtualMachine::GetParamsContext() const {
  CHECK(!ctxs_.empty()) << "Context has not been initialized yet.";

  // TODO(wweic): For heterogeneous execution, get device information from the bytecode.
  // Until then, the first context's device type serves as the fallback.
  const int fallback_type = static_cast<int>(ctxs_[0].device_type);
  auto match = std::find_if(ctxs_.begin(), ctxs_.end(), [fallback_type](const TVMContext& c) {
    return static_cast<int>(c.device_type) == fallback_type;
  });
  if (match == ctxs_.end()) {
    return ctxs_[0];
  }
  return *match;
}

705
/*!
 * \brief Push a new call frame.
 *
 * The frame records `ret_pc` and the current `func_index_` / `code_`, so that
 * PopFrame can restore them. It is sized by `vm_func.register_file_size`.
 */
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
  frames_.emplace_back(ret_pc, func_index_, arg_count, code_, vm_func.register_file_size);
}

/*!
 * \brief Pop the top call frame, restoring the function index, code pointer and pc it recorded.
 * \return The call-stack depth *before* the pop.
 */
Index VirtualMachine::PopFrame() {
  CHECK_GT(frames_.size(), 0);
  const VMFrame& top = frames_.back();
  code_ = top.code;
  func_index_ = top.func_index;
  pc_ = top.pc;
  // Capture the depth before pop_back invalidates `top`.
  const auto depth_before_pop = frames_.size();
  frames_.pop_back();
  return depth_before_pop;
}

721
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) {
722
  DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
723

724
  PushFrame(func.params.size(), this->pc_ + 1, func);
725 726 727
  for (size_t i = 0; i < args.size(); ++i) {
    WriteRegister(i, args[i]);
  }
728
  DLOG(INFO) << "func.params= " << func.params.size();
729

730 731
  code_ = func.instructions.data();
  pc_ = 0;
732 733
}

734
/*!
 * \brief Run `func` with `args` to completion and return the value left in the return register.
 */
ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
  DLOG(INFO) << "Executing Function: " << std::endl << func;

  InvokeGlobal(func, args);
  RunLoop();

  // TODO(wweic) ctx could be obtained from the ctxs list.
  auto allocator = MemoryManager::Global()->GetAllocator(ctxs_[0]);
  DLOG(INFO) << "Memory used: " << allocator->UsedMemory() << " B";
  return return_register_;
}

745
/*!
 * \brief Look up the global function `name` in the loaded executable and run it with `args`.
 */
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
  CHECK(exec_) << "The executable has not been created yet.";
  const auto entry = exec_->global_map.find(name);
  CHECK(entry != exec_->global_map.end())
    << "Cannot find function " << name << " in the executable";
  // Named so it no longer shadows the member `func_index_`.
  const auto global_index = entry->second;
  DLOG(INFO) << "Invoke Global " << name << " at index " << global_index;
  return Invoke(exec_->functions[global_index], args);
}

755 756
/*!
 * \brief Call packed function `func` on the first `arg_count` entries of `args`.
 *
 * ADT arguments are flattened so that each field becomes its own packed
 * argument. Every resulting element must be an NDArray (Downcast enforces
 * this). `packed_index` and `output_size` are not used here.
 */
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
                                  Index arg_count, Index output_size,
                                  const std::vector<ObjectRef>& args) {
  // First pass: count the flattened arguments.
  size_t arity = 0;
  for (Index i = 0; i < arg_count; i++) {
    const auto* adt = args[i].as<ADTObj>();
    if (adt == nullptr) {
      ++arity;
    } else {
      arity += adt->size;
    }
  }

  std::vector<TVMValue> values(arity);
  std::vector<int> codes(arity);
  runtime::TVMArgsSetter setter(values.data(), codes.data());

  // Second pass: fill the argument slots in order. Named NDArray locals are
  // passed (not temporaries) so the setter sees an lvalue, as before.
  int slot = 0;
  for (Index i = 0; i < arg_count; i++) {
    const auto* adt = args[i].as<ADTObj>();
    if (adt == nullptr) {
      auto tensor = Downcast<NDArray>(args[i]);
      setter(slot++, tensor);
      continue;
    }
    for (size_t field = 0; field < adt->size; ++field) {
      auto element = (*adt)[field];
      auto tensor = Downcast<NDArray>(element);
      setter(slot++, tensor);
    }
  }

  TVMRetValue rv;
  func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
}

788 789
void VirtualMachine::LoadExecutable(const Executable* exec) {
  CHECK(exec) << "The executable is not created yet.";
790
  exec_ = exec;
791

792
  runtime::Module lib = exec_->lib;
793
  // Get the list of packed functions.
794
  CHECK(exec->primitive_map.empty() || lib.operator->())
795 796
      << "runtime module should have been built for primitive functions"
      << "\n";
797
  for (const auto& it : exec_->primitive_map) {
798 799
    const auto& packed_name = it.first;
    auto packed_index = static_cast<size_t>(it.second);
800 801
    if (packed_funcs_.size() <= packed_index) {
      packed_funcs_.resize(packed_index + 1);
802
    }
Zhi committed
803 804 805
    tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
    CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
    packed_funcs_[packed_index] = pf;
806 807
  }
}
808

809 810

void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
811
  ctxs_ = ctxs;
812 813
}

814
/*!
 * \brief Store `val` into register `r` of the current (top) frame.
 */
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
  auto& registers = frames_.back().register_file;
  registers[r] = val;
}

818
/*!
 * \brief Return (by value) the object held in register `r` of the current (top) frame.
 */
inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
  const auto& registers = frames_.back().register_file;
  return registers[r];
}

822 823 824
/*!
 * \brief Read the first element of the NDArray in register `r` as an integer.
 *
 * The array is copied to CPU and its first element is read with the width
 * given by its dtype.
 * - Unsigned dtypes (kDLUInt) are zero-extended. Previously uint8 200 read as -56.
 * - Other dtypes are sign-extended.
 * - Widths above 32 bits are read as a full 64-bit value and then narrowed. The
 *   old int32 read only happened to yield the low half on little-endian hosts.
 * \return The value narrowed to int32_t.
 */
inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
  const auto& obj = ReadRegister(r);
  auto nd_array = Downcast<NDArray>(obj);
  NDArray array = nd_array.CopyTo({kDLCPU, 0});

  const DLDataType dtype = array->dtype;
  const bool is_unsigned = dtype.code == kDLUInt;
  const void* data = array->data;
  int64_t value = 0;
  if (dtype.bits <= 8) {
    value = is_unsigned ? static_cast<int64_t>(*static_cast<const uint8_t*>(data))
                        : static_cast<int64_t>(*static_cast<const int8_t*>(data));
  } else if (dtype.bits <= 16) {
    value = is_unsigned ? static_cast<int64_t>(*static_cast<const uint16_t*>(data))
                        : static_cast<int64_t>(*static_cast<const int16_t*>(data));
  } else if (dtype.bits <= 32) {
    value = is_unsigned ? static_cast<int64_t>(*static_cast<const uint32_t*>(data))
                        : static_cast<int64_t>(*static_cast<const int32_t*>(data));
  } else {
    value = is_unsigned ? static_cast<int64_t>(*static_cast<const uint64_t*>(data))
                        : *static_cast<const int64_t*>(data);
  }
  return static_cast<int32_t>(value);
}

838
void VirtualMachine::RunLoop() {
839 840 841 842
  CHECK(this->exec_);
  CHECK(this->code_);
  pc_ = 0;
  Index frame_start = frames_.size();
843 844
  while (true) {
  main_loop:
845 846
    auto const& instr = code_[this->pc_];
    DLOG(INFO) << "Executing(" << pc_ << "): " << instr;
847 848 849 850 851 852
#if USE_RELAY_DEBUG
    InstructionPrint(std::cout, instr);
#endif  // USE_RELAY_DEBUG

    switch (instr.op) {
      case Opcode::Move: {
853
        ObjectRef from_obj;
854
        from_obj = ReadRegister(instr.from);
855
        WriteRegister(instr.dst, from_obj);
856
        pc_++;
857 858
        goto main_loop;
      }
859 860 861
      case Opcode::Fatal: {
        throw std::runtime_error("VM encountered fatal error");
      }
862
      case Opcode::LoadConst: {
863
        auto constant_obj = exec_->constants[instr.const_index];
864 865 866 867 868 869 870 871 872
        // We cache the allocated object in the constant pool. To measure, the
        // first iteration will set the pool up. The other iterations will
        // directly reuse the allocated objects.
        if (const_pool_.size() <= static_cast<size_t>(instr.const_index)) {
          const_pool_.resize(instr.const_index + 1);
        }

        if (!const_pool_[instr.const_index].defined()) {
          // TODO(wweic) ctx could be obtained from the ctxs list.
873
          const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs_[0]);
874 875
        }
        WriteRegister(instr.dst, const_pool_[instr.const_index]);
876
        pc_++;
877 878
        goto main_loop;
      }
879
      case Opcode::LoadConsti: {
880 881
        auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
        reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
882
        WriteRegister(instr.dst, tensor);
883
        pc_++;
884 885
        goto main_loop;
      }
886
      case Opcode::Invoke: {
887
        std::vector<ObjectRef> args;
888 889 890
        for (Index i = 0; i < instr.num_args; ++i) {
          args.push_back(ReadRegister(instr.invoke_args_registers[i]));
        }
891 892
        InvokeGlobal(exec_->functions[instr.func_index], args);
        frames_.back().caller_return_register = instr.dst;
893 894 895
        goto main_loop;
      }
      case Opcode::InvokePacked: {
896 897
        DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity;
        const auto& func = packed_funcs_[instr.packed_index];
898
        const auto& arity = instr.arity;
899
        std::vector<ObjectRef> args;
900
        for (Index i = 0; i < arity; ++i) {
901 902 903 904
          DLOG(INFO) <<
            "arg" << i << " $" << instr.packed_args[i];
          auto arg = ReadRegister(instr.packed_args[i]);
          args.push_back(arg);
905
        }
906 907 908

        // We no longer need to write the registers back, we write directly
        // through the registers mutably.
909
        InvokePacked(instr.packed_index, func, arity, instr.output_size, args);
910
        pc_++;
911 912 913 914
        goto main_loop;
      }
      case Opcode::InvokeClosure: {
        auto object = ReadRegister(instr.closure);
915
        const auto* closure = object.as<VMClosureObj>();
916 917

        std::vector<ObjectRef> args;
918 919 920
        for (auto free_var : closure->free_vars) {
          args.push_back(free_var);
        }
921
        for (Index i = 0; i < instr.num_closure_args; ++i) {
922 923
          args.push_back(ReadRegister(instr.closure_args[i]));
        }
924 925
        InvokeGlobal(exec_->functions[closure->func_index], args);
        frames_.back().caller_return_register = instr.dst;
926 927 928 929
        goto main_loop;
      }
      case Opcode::GetField: {
        auto object = ReadRegister(instr.object);
930 931
        const auto& tuple = Downcast<ADT>(object);
        auto field = tuple[instr.field_index];
932
        WriteRegister(instr.dst, field);
933
        pc_++;
934 935
        goto main_loop;
      }
936 937
      case Opcode::GetTag: {
        auto object = ReadRegister(instr.get_tag.object);
938 939
        const auto& adt = Downcast<ADT>(object);
        auto tag = adt.tag();
940 941
        auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
        reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
942
        WriteRegister(instr.dst, tag_tensor);
943
        pc_++;
944 945
        goto main_loop;
      }
946
      case Opcode::Goto: {
947
        pc_ += instr.pc_offset;
948 949 950
        goto main_loop;
      }
      case Opcode::If: {
951 952
        int32_t test_val = LoadScalarInt(instr.if_op.test);
        int32_t target_val = LoadScalarInt(instr.if_op.target);
953

954 955
        if (test_val == target_val) {
          CHECK_NE(instr.if_op.true_offset, 0);
956
          pc_ += instr.if_op.true_offset;
957
        } else {
958
          CHECK_NE(instr.if_op.false_offset, 0);
959
          pc_ += instr.if_op.false_offset;
960 961 962 963 964
        }

        goto main_loop;
      }
      case Opcode::AllocTensor: {
965
        auto shape = std::vector<int64_t>(instr.alloc_tensor.ndim);
966

Li committed
967
        for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
968 969
          shape[i] = instr.alloc_tensor.shape[i];
        }
970 971 972

        auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
        auto storage = Downcast<Storage>(storage_obj);
973
        auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
974

975
        WriteRegister(instr.dst, obj);
976
        pc_++;
977 978 979
        goto main_loop;
      }
      case Opcode::AllocTensorReg: {
980 981 982
        DLContext cpu_ctx;
        cpu_ctx.device_type = kDLCPU;
        cpu_ctx.device_id = 0;
983
        auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
984 985
        const auto shape_arr = Downcast<NDArray>(shape_tensor_obj);
        NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx);
986 987 988 989
        const DLTensor* dl_tensor = shape_tensor.operator->();
        CHECK_EQ(dl_tensor->dtype.code, 0u);
        CHECK_LE(dl_tensor->dtype.bits, 64);
        int64_t* dims = reinterpret_cast<int64_t*>(dl_tensor->data);
990
        auto num_dims = shape_tensor->shape[0];
991
        auto shape = std::vector<int64_t>(num_dims);
992
        shape.assign(dims, dims + num_dims);
993 994 995

        auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
        auto storage = Downcast<Storage>(storage_obj);
996
        auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
997

998
        WriteRegister(instr.dst, obj);
999
        pc_++;
1000 1001
        goto main_loop;
      }
1002
      case Opcode::AllocADT: {
1003
        std::vector<ObjectRef> fields;
1004 1005 1006
        for (Index i = 0; i < instr.num_fields; ++i) {
          fields.push_back(ReadRegister(instr.datatype_fields[i]));
        }
1007
        ObjectRef obj = ADT(instr.constructor_tag, fields);
1008
        WriteRegister(instr.dst, obj);
1009
        pc_++;
1010 1011 1012
        goto main_loop;
      }
      case Opcode::AllocClosure: {
1013
        std::vector<ObjectRef> free_vars;
1014 1015 1016
        for (Index i = 0; i < instr.num_freevar; i++) {
          free_vars.push_back(ReadRegister(instr.free_vars[i]));
        }
1017
        WriteRegister(instr.dst, VMClosure(instr.func_index, free_vars));
1018
        pc_++;
1019 1020
        goto main_loop;
      }
1021 1022 1023 1024 1025 1026 1027
      case Opcode::AllocStorage: {
        auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
        auto alignment = LoadScalarInt(instr.alloc_storage.alignment);

        DLOG(INFO) <<
          "AllocStorage: allocation_size=" << size <<
          "alignment=" << alignment <<
1028
          "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint);
1029

1030
        auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]);
1031
        WriteRegister(instr.dst, storage);
1032
        pc_++;
1033 1034
        goto main_loop;
      }
1035 1036 1037 1038
      case Opcode::Ret: {
        // If we have hit the point from which we started
        // running, we should return to the caller breaking
        // the dispatch loop.
1039 1040
        return_register_ = ReadRegister(instr.result);
        auto caller_return_register = frames_.back().caller_return_register;
1041 1042 1043 1044 1045

        if (PopFrame() == frame_start) {
          return;
          // Otherwise we are just returning from a local call.
        } else {
1046
          WriteRegister(caller_return_register, return_register_);
1047 1048 1049 1050 1051 1052 1053
          goto main_loop;
        }
      }
    }
  }
}

1054
/*!
 * \brief Create a VirtualMachine, load \p exec into it, and wrap it as a runtime::Module.
 *
 * \param exec The executable to load.
 * \return A module backed by the newly created VirtualMachine.
 */
runtime::Module CreateVirtualMachine(const Executable* exec) {
  ObjectPtr<VirtualMachine> machine = make_object<VirtualMachine>();
  machine->LoadExecutable(exec);
  return runtime::Module(machine);
}

1060
// Global "runtime._VirtualMachine": takes a runtime::Module that wraps an
// Executable and returns a module backed by a new VirtualMachine for it.
TVM_REGISTER_GLOBAL("runtime._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  runtime::Module exec_module = args[0];
  // Only modules whose node is an Executable are accepted.
  const Executable* executable = dynamic_cast<const Executable*>(exec_module.operator->());
  CHECK(executable) << "The virtual machine executable has not been defined yet.";
  *rv = CreateVirtualMachine(executable);
});

1068 1069 1070
}  // namespace vm
}  // namespace runtime
}  // namespace tvm