vm.h 25 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/runtime/vm.h
 * \brief A virtual machine for executing Relay programs.
 */
#ifndef TVM_RUNTIME_VM_H_
#define TVM_RUNTIME_VM_H_

#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {
namespace vm {

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
/*! \brief An object representing a closure. */
class ClosureObj : public Object {
 public:
  /*! \brief The index into the VM function table. */
  size_t func_index;
  /*! \brief The free variables of the closure. */
  std::vector<ObjectRef> free_vars;

  static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
  static constexpr const char* _type_key = "vm.Closure";
  TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
};

/*!
 * \brief Reference (handle) type for a ClosureObj.
 */
class Closure : public ObjectRef {
 public:
  /*!
   * \brief Create a closure.
   * \param function_index Index of the closure's function in the VM function table.
   * \param captured_vars The free variables captured by the closure.
   */
  Closure(size_t function_index, std::vector<ObjectRef> captured_vars);

  TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};

/*! \brief Magic number identifying a serialized NDArray list file. */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

/*! \brief A register name: identifies a register in the VM's register file. */
using RegName = int64_t;

/*! \brief An alias for the integer type used ubiquitously
 * in the VM (indices, counts, offsets).
 */
using Index = int64_t;

/*! \brief An enumeration of Relay's opcodes.
 *
 * The opcode is used to implement instruction
 * as a tagged union. The explicit numeric values are part of the
 * instruction encoding, so existing entries must not be renumbered.
 */
enum class Opcode {
  Move = 0U,
  Ret = 1U,
  Invoke = 2U,
  InvokeClosure = 3U,
  InvokePacked = 4U,
  AllocTensor = 5U,
  AllocTensorReg = 6U,
  AllocADT = 7U,
  AllocClosure = 8U,
  GetField = 9U,
  If = 10U,
  LoadConst = 11U,
  Goto = 12U,
  GetTag = 13U,
  LoadConsti = 14U,
  Fatal = 15U,
  AllocStorage = 16U,
};

/*! \brief A single virtual machine instruction.
 *
 * The representation of the instruction is as
 * a tagged union.
 *
 * The first field represents which instruction,
 * and by extension which field of the union
 * is active.
 */
struct Instruction {
  /*! \brief The instruction opcode. */
  Opcode op;

  /*! \brief The destination register. */
  RegName dst;

  union {
    struct /* AllocTensor Operands */ {
114 115
      /*! \brief The storage to allocate from. */
      RegName storage;
116 117 118 119 120 121 122 123
      /*! \brief The number of dimensions. */
      uint32_t ndim;
      /*! \brief The shape of tensor. */
      int64_t* shape;
      /*! \brief The datatype of tensor to be allocated. */
      DLDataType dtype;
    } alloc_tensor;
    struct /* AllocTensorReg Operands */ {
124 125
      /*! \brief The storage to allocate from. */
      RegName storage;
126 127 128 129
      /*! \brief The register to read the shape out of. */
      RegName shape_register;
      /*! \brief The datatype of tensor to be allocated. */
      DLDataType dtype;
130
    } alloc_tensor_reg;
131 132 133 134
    struct /* InvokeClosure Operands */ {
      /*! \brief The register containing the closure. */
      RegName closure;
      /*! \brief The number of arguments to the closure. */
135
      Index num_closure_args;
136 137 138 139 140 141 142 143 144 145 146
      /*! \brief The closure arguments as an array. */
      RegName* closure_args;
    };
    struct /* Return Operands */ {
      /*! \brief The register to return. */
      RegName result;
    };
    struct /* Move Operands */ {
      /*! \brief The source register for a move operation. */
      RegName from;
    };
147
    struct /* InvokePacked Operands */ {
148 149 150 151 152 153 154 155 156 157
      /*! \brief The index into the packed function table. */
      Index packed_index;
      /*! \brief The arity of the packed function. */
      Index arity;
      /*! \brief The number of outputs produced by the packed function. */
      Index output_size;
      /*! \brief The arguments to pass to the packed function. */
      RegName* packed_args;
    };
    struct /* If Operands */ {
158 159 160 161
      /*! \brief The register containing the test value. */
      RegName test;
      /*! \brief The register containing the target value. */
      RegName target;
162 163 164 165
      /*! \brief The program counter offset for the true branch. */
      Index true_offset;
      /*! \brief The program counter offset for the false branch. */
      Index false_offset;
166
    } if_op;
167 168 169 170 171 172 173 174
    struct /* Invoke Operands */ {
      /*! \brief The function to call. */
      Index func_index;
      /*! \brief The number of arguments to the function. */
      Index num_args;
      /*! \brief The registers containing the arguments. */
      RegName* invoke_args_registers;
    };
175
    struct /* LoadConst Operands */ {
176 177 178
      /* \brief The index into the constant pool. */
      Index const_index;
    };
179 180
    struct /* LoadConsti Operands */ {
      /* \brief The index into the constant pool. */
181
      Index val;
182
    } load_consti;
183 184 185 186 187 188 189 190 191 192
    struct /* Jump Operands */ {
      /*! \brief The jump offset. */
      Index pc_offset;
    };
    struct /* Proj Operands */ {
      /*! \brief The register to project from. */
      RegName object;
      /*! \brief The field to read out. */
      Index field_index;
    };
193 194 195 196
    struct /* GetTag Operands */ {
      /*! \brief The register to project from. */
      RegName object;
    } get_tag;
197
    struct /* AllocADT Operands */ {
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
      /*! \brief The datatype's constructor tag. */
      Index constructor_tag;
      /*! \brief The number of fields to store in the datatype. */
      Index num_fields;
      /*! \brief The fields as an array. */
      RegName* datatype_fields;
    };
    struct /* AllocClosure Operands */ {
      /*! \brief The index into the function table. */
      Index clo_index;
      /*! \brief The number of free variables to capture. */
      Index num_freevar;
      /*! \brief The free variables as an array. */
      RegName* free_vars;
    };
213 214 215 216 217 218 219 220
    struct /* AllocStorage Operands */ {
      /*! \brief The size of the allocation. */
      RegName allocation_size;
      /*! \brief The alignment of the allocation. */
      RegName alignment;
      /*! \brief The hint of the dtype. */
      DLDataType dtype_hint;
    } alloc_storage;
221 222
  };

223 224 225 226 227
  /*!
   * \brief Construct a return instruction.
   * \param return_reg The register containing the return value.
   * \return The return instruction.
   */
228
  static Instruction Ret(RegName return_reg);
229 230 231 232
  /*!
   * \brief Construct a fatal instruction.
   * \return The fatal instruction.
   */
233
  static Instruction Fatal();
234 235 236 237 238 239 240
  /*!
   * \brief Construct a invoke packed instruction.
   * \param packed_index The index of the packed function.
   * \param arity The arity of the function.
   * \param output_size The number of outputs of the packed function.
   * \param args The argument registers.
   * \return The invoke packed instruction.
241 242 243
   */
  static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
                                  const std::vector<RegName>& args);
244 245 246 247 248 249 250
  /*!
   * \brief Construct an allocate tensor instruction with constant shape.
   * \param storage The storage to allocate out of.
   * \param shape The shape of the tensor.
   * \param dtype The dtype of the tensor.
   * \param dst The destination register.
   * \return The allocate tensor instruction.
251
   */
252 253
  static Instruction AllocTensor(RegName storage,
                                 const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
254 255 256 257 258 259 260
  /*!
   * \brief Construct an allocate tensor instruction with register.
   * \param storage The storage to allocate out of.
   * \param shape_register The register containing the shape.
   * \param dtype The dtype of the tensor.
   * \param dst The destination register.
   * \return The allocate tensor instruction.
261
   */
262 263
  static Instruction AllocTensorReg(RegName storage,
                                    RegName shape_register, DLDataType dtype, RegName dst);
264 265 266 267 268 269 270
  /*!
   * \brief Construct an allocate datatype instruction.
   * \param tag The datatype tag.
   * \param num_fields The number of fields for the datatype.
   * \param fields The registers containing the fields.
   * \param dst The register name of the destination.
   * \return The allocate instruction tensor.
271
   */
272
  static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
273
                              RegName dst);
274 275 276 277 278 279 280
  /*!
   * \brief Construct an allocate closure instruction.
   * \param func_index The index of the function table.
   * \param num_freevar The number of free variables.
   * \param free_vars The registers of the free variables.
   * \param dst The destination register.
   * \return The allocate closure instruction.
281 282 283
   */
  static Instruction AllocClosure(Index func_index, Index num_freevar,
                                  const std::vector<RegName>& free_vars, RegName dst);
284 285 286 287 288 289
  /*!
   * \brief Construct a get field instruction.
   * \param object_reg The register containing the object to project from.
   * \param field_index The field to read out of the object.
   * \param dst The destination register.
   * \return The get field instruction.
290 291
   */
  static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
292 293 294 295 296
  /*!
   * \brief Construct a get_tag instruction.
   * \param object_reg The register containing the object to project from.
   * \param dst The destination register.
   * \return The get_tag instruction.
297 298
   */
  static Instruction GetTag(RegName object_reg, RegName dst);
299 300 301 302 303 304 305
  /*!
   * \brief Construct an if instruction.
   * \param test The register containing the test value.
   * \param target The register containing the target value.
   * \param true_branch The offset to the true branch.
   * \param false_branch The offset to the false branch.
   * \return The if instruction.
306
   */
307
  static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
308 309 310 311
  /*!
   * \brief Construct a goto instruction.
   * \param pc_offset The offset from the current pc.
   * \return The goto instruction.
312 313
   */
  static Instruction Goto(Index pc_offset);
314 315 316 317 318 319
  /*!
   * \brief Construct an invoke instruction.
   * \param func_index The index of the function to invoke.
   * \param args The registers containing the arguments.
   * \param dst The destination register.
   * \return The invoke instruction.
320 321
   */
  static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst);
322 323 324 325 326 327
  /*!
   * \brief Construct an invoke closure instruction.
   * \param closure The register of the closure to invoke.
   * \param args The registers containing the arguments.
   * \param dst The destination register.
   * \return The invoke closure instruction.
328 329
   */
  static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst);
330 331 332 333 334
  /*!
   * \brief Construct a load constant instruction.
   * \param const_index The index of the constant.
   * \param dst The destination register.
   * \return The load constant instruction.
335 336
   */
  static Instruction LoadConst(Index const_index, RegName dst);
337 338 339 340 341
  /*!
   * \brief Construct a load_constanti instruction.
   * \param val The interger constant value.
   * \param dst The destination register.
   * \return The load_constanti instruction.
342
   */
343
  static Instruction LoadConsti(Index val, RegName dst);
344 345 346 347 348
  /*!
   * \brief Construct a move instruction.
   * \param src The source register.
   * \param dst The destination register.
   * \return The move instruction.
349 350 351
   */
  static Instruction Move(RegName src, RegName dst);

352 353 354 355 356 357 358
  /*!
   * \brief Allocate a storage block.
   * \param size The size of the allocation.
   * \param alignment The allocation's alignment.
   * \param dtype_hint The data type hint for the allocator.
   * \param dst The destination to place the storage.
   * \return The alloc storage instruction.
359 360 361 362
   */
  static Instruction AllocStorage(RegName size, RegName alignment,
                                  DLDataType dtype_hint, RegName dst);

363 364
  Instruction();
  Instruction(const Instruction& instr);
365
  Instruction& operator=(const Instruction& instr);
366 367 368 369 370
  ~Instruction();

  friend std::ostream& operator<<(std::ostream& os, const Instruction&);
};

371 372
/*!
 * \brief A representation of a Relay function in the VM.
373 374 375 376 377 378 379
 *
 * Contains metadata about the compiled function, as
 * well as the compiled VM instructions.
 */
struct VMFunction {
  /*! \brief The function's name. */
  std::string name;
380 381
  /*! \brief The function parameter names. */
  std::vector<std::string> params;
382 383 384 385 386
  /*! \brief The instructions representing the function. */
  std::vector<Instruction> instructions;
  /*! \brief The size of the frame for this function */
  Index register_file_size;

387
  VMFunction(const std::string& name, std::vector<std::string> params,
388 389 390 391 392 393 394 395 396 397 398 399
             const std::vector<Instruction>& instructions,
             Index register_file_size)
      : name(name),
        params(params),
        instructions(instructions),
        register_file_size(register_file_size) {}

  VMFunction() {}

  friend std::ostream& operator<<(std::ostream& os, const VMFunction&);
};

400 401
/*!
 * \brief A representation of a stack frame.
402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
 *
 * A stack frame is a record containing the information needed
 * to restore the caller's virtual machine state after returning
 * from a function call.
 */
struct VMFrame {
  /*! \brief The return program counter. */
  Index pc;
  /*! \brief The index into the function table, points to the caller. */
  Index func_index;
  /*! \brief The number of arguments. */
  Index args;
  /*! \brief A pointer into the caller function's instructions. */
  const Instruction* code;

  /*! \brief Statically allocated space for objects */
418
  std::vector<ObjectRef> register_file;
419 420 421 422 423 424 425 426 427 428 429 430 431

  /*! \brief Register in caller's frame to put return value */
  RegName caller_return_register;

  VMFrame(Index pc, Index func_index, Index args, const Instruction* code, Index register_file_size)
      : pc(pc),
        func_index(func_index),
        args(args),
        code(code),
        register_file(register_file_size),
        caller_return_register(0) {}
};

432 433
/*!
 * \brief The executable emitted by the VM compiler.
434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454
 *
 * The executable contains information (e.g. data in different memory regions)
 * to run in a virtual machine.
 *
 *  - Global section, containing all globals.
 *  - Constant section, storing the constant pool.
 *  - Primitive name section, containing the function name of the primitive ops
 *  used by the virtual machine.
 *  - Code section, handling the VM functions and bytecode.
 */
class Executable : public ModuleNode {
 public:
  /*!
   * \brief Get a PackedFunc from an executable module.
   *
   * \param name the name of the function.
   * \param sptr_to_self The shared_ptr that points to this module node.
   *
   * \return PackedFunc or nullptr when it is not available.
   */
  PackedFunc GetFunction(const std::string& name,
455
                         const ObjectPtr<Object>& sptr_to_self) final;
456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508

  /*!
   * \brief Serialize the executable into global section, constant section, and
   * code section.
   *
   * \return The binary representation of the VM.
   */
  TVMByteArray Save();

  /*!
   * \brief Load the saved VM executable.
   *
   * \param code The bytecode in string.
   * \param lib The compiled runtime library.
   *
   * \return exe The constructed executable.
   */
  static runtime::Module Load(const std::string& code, const runtime::Module lib);

  /*!
   * \brief Get the serialized form of the `functions`. This is
   * essentially bytecode serialization.
   *
   * \return The serialized vm bytecode.
   *
   * \note The bytecode is in the following format:
   *   func_name reg_file_size num_instructions
   *   param1 param2 ... paramM
   *   instruction1
   *   instruction2
   *   ...
   *   instructionN
   *
   * Each instruction is printed in the following format:
   *   opcode num_fields field1 ... fieldX # The text format.
   *
   * Serializing an `Instruction` requires us to deal with the bytecode. Each line
   * of the instructions could be serialized as the following format:
   *   hash, opcode, f1, f2, ..., fX, field with variable length
   *   1. hash: the hash of the instruction. This number will be used to help us
   * validate if an instruction is well-formed during deserialization.
   *   2. opcode: the opcode code of the instruction.
   *   3. f1, f2, ..., fX. These fields together represent the fixed fields in
   * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
   * example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
   *   4. The rest of the line indicates the field with variable length, e.g.,
   * the shape of a tensor, the args used by an `InvokPacked` instruction, etc.

   * The field starting from # is only used for debugging. The serialized code
   * doesn't contain it, therefore the deserializer doens't need to handle it.
   */
  std::string GetBytecode() const;

509
  /*!
510 511 512 513 514
   * \brief Print the detailed statistics of the given code, i.e. number of
   * globls and constants, etc.
   */
  std::string Stats() const;

515 516
  /*!
   * \brief Get the `lib` module in an executable. Users have the flexibility to call
517 518 519 520 521 522
   * `export_library` from the frontend to save the library to disk.
   *
   * \return The runtime module that contains the hardwre dependent code.
   */
  runtime::Module GetLib() const { return lib; }

523 524 525 526 527 528 529 530 531 532 533 534 535 536 537
  /*!
   * \brief Get the arity of the VM Fucntion.
   * \param func Function name.
   * \return The number of parameters.
   */
  int GetFunctionArity(std::string func) const;

  /*!
   * \brief Get the parameter name given the function name and parameter index.
   * \param func Function name.
   * \param index Parameter index.
   * \return The parameter name.
   */
  std::string GetFunctionParameterName(std::string func, uint32_t index) const;

538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618
  virtual ~Executable() {}

  const char* type_key() const final {
    return "VMExecutable";
  }

  /*! \brief The runtime module/library that contains both the host and also the device
   * code when executing on non-CPU devices. */
  runtime::Module lib;
  /*! \brief The global constant pool. */
  std::vector<ObjectRef> constants;
  /*! \brief A map from globals (as strings) to their index in the function map. */
  std::unordered_map<std::string, Index> global_map;
  /*! \brief A mapping from the packed function (as string) to the index that
   * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
   */
  std::unordered_map<std::string, Index> primitive_map;
  /*! \brief The virtual machine's function table. */
  std::vector<VMFunction> functions;

 private:
  /*!
   * \brief Save the globals.
   *
   * \param strm The input stream.
   */
  void SaveGlobalSection(dmlc::Stream* strm);

  /*!
   * \brief Save the constant pool.
   *
   * \param strm The input stream.
   */
  void SaveConstantSection(dmlc::Stream* strm);

  /*!
   * \brief Save primitive op names.
   *
   *  \param strm The input stream.
   */
  void SavePrimitiveOpNames(dmlc::Stream* strm);

  /*!
   * \brief Save the vm functions.
   *
   * \param strm The input stream.
   */
  void SaveCodeSection(dmlc::Stream* strm);

  /*!
   * \brief Load the globals.
   *
   * \param strm The input stream.
   */
  void LoadGlobalSection(dmlc::Stream* strm);

  /*!
   * \brief Load the constant pool.
   *
   * \param strm The input stream.
   */
  void LoadConstantSection(dmlc::Stream* strm);

  /*!
   * \brief Load primitive op names.
   *
   * \param strm The input stream.
   */
  void LoadPrimitiveOpNames(dmlc::Stream* strm);

  /*!
   * \brief Load the vm functions.
   *
   * \param strm The input stream.
   */
  void LoadCodeSection(dmlc::Stream* strm);

  /*! \brief The serialized bytecode. */
  std::string code_;
};

619 620
/*!
 * \brief The virtual machine.
621 622
 *
 * The virtual machine contains all the current execution state,
623
 * as well as the executable.
624 625 626
 *
 * The goal is to have a single self-contained object,
 * enabling one to easily pass around VMs, execute them on
627
 * multiple threads, or serialize them to disk or over the
628 629
 * wire.
 */
630 631
class VirtualMachine : public runtime::ModuleNode {
 public:
632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
  /*!
   * \brief Get a PackedFunc from module.
   *
   *  The PackedFunc may not be fully initialized,
   *  there might still be first time running overhead when
   *  executing the function on certain devices.
   *  For benchmarking, use prepare to eliminate
   *
   * \param name the name of the function.
   * \param sptr_to_self The shared_ptr that points to this module node.
   *
   * \return PackedFunc(nullptr) when it is not available.
   *
   * \note The function will always remain valid.
   *   If the function needs resource from the module(e.g. late linking),
   *   it should capture sptr_to_self.
   */
  virtual PackedFunc GetFunction(const std::string& name,
650
                                 const ObjectPtr<Object>& sptr_to_self);
651 652

  virtual ~VirtualMachine() {}
653 654 655 656 657

  const char* type_key() const final {
    return "VirtualMachine";
  }

658
  VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {}
659

660 661 662
  /*!
   * \brief load the executable for the virtual machine.
   * \param exec The executable.
663
   */
664
  virtual void LoadExecutable(const Executable* exec);
665 666

 protected:
667
  /*! \brief The virtual machine's packed function table. */
668
  std::vector<PackedFunc> packed_funcs_;
669
  /*! \brief The current stack of call frames. */
670
  std::vector<VMFrame> frames_;
671
  /*! \brief The fuction table index of the current function. */
672
  Index func_index_;
673
  /*! \brief The current pointer to the code section. */
674
  const Instruction* code_;
675
  /*! \brief The virtual machine PC. */
676
  Index pc_;
677
  /*! \brief The special return register. */
678
  ObjectRef return_register_;
679
  /*! \brief The executable the VM will operate on. */
680 681 682
  const Executable* exec_;
  /*! \brief The function name to inputs mapping. */
  std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
683
  /*! \brief The set of TVM contexts the VM is currently executing on. */
684
  std::vector<TVMContext> ctxs_;
685 686 687

  /*! \brief Push a call frame on to the call stack. */
  void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
688 689 690 691

  /*!
   * \brief Pop a frame off the call stack.
   * \return The number of frames left.
692 693 694
   */
  Index PopFrame();

695 696 697 698
  /*!
   * \brief Write to a VM register.
   * \param reg The register to write to.
   * \param obj The object to write to.
699
   */
700
  inline void WriteRegister(RegName reg, const ObjectRef& obj);
701

702 703 704 705
  /*!
   * \brief Read a VM register.
   * \param reg The register to read from.
   * \return The read object.
706
   */
707
  inline ObjectRef ReadRegister(RegName reg) const;
708

709 710 711 712
  /*!
   * \brief Read a VM register and cast it to int32_t
   * \param reg The register to read from.
   * \return The read scalar.
713 714 715
   */
  int32_t LoadScalarInt(RegName reg) const;

716 717
  /*!
   * \brief Invoke a VM function.
718 719 720 721
   * \param func The function.
   * \param args The arguments to the function.
   * \return The object representing the result.
   */
722
  ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& args);
723 724

  // TODO(@jroesch): I really would like this to be a global variable.
725 726
  /*!
   * \brief Invoke a VM function by name.
727 728 729 730
   * \param name The function's name.
   * \param args The arguments to the function.
   * \return The object representing the result.
   */
731
  ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
732

733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752
  /*!
   * \brief Invoke a PackedFunction
   *
   * \param packed_index The offset of the PackedFunction in all functions.
   * \param func The PackedFunction to be invoked.
   * \param arg_count The number of arguments to the PackedFunction.
   * \param output_size The number of outputs of the PackedFunction.
   * \param args Arguments to the PackedFunction.
   *
   * \note The return value will be stored in the last output_size slots of args.
   */
  virtual void InvokePacked(Index packed_index,
                            const PackedFunc& func,
                            Index arg_count,
                            Index output_size,
                            const std::vector<ObjectRef>& args);

  /*!
   * \brief Initialize the virtual machine for a set of contexts.
   * \param contexts The set of TVM contexts.
753 754
   */
  void Init(const std::vector<TVMContext>& contexts);
755

756
  /*! \brief Run VM dispatch loop. */
757
  void RunLoop();
758

759
  /*! \brief Get device context for params. */
760 761
  TVMContext GetParamsContext() const;

762
 private:
763 764
  /*!
   * \brief Invoke a global setting up the VM state to execute.
765 766 767
   *
   * This does not begin execution of the VM.
   */
768
  void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
769

770 771 772 773 774
  /*!
   * \brief The constant pool for runtime. It caches the device dependent
   * object to avoid rellocation of constants during inference.
   */
  std::vector<ObjectRef> const_pool_;
775 776 777 778 779 780 781
};

}  // namespace vm
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_VM_H_