vm.h 14.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/runtime/vm.h
 * \brief A virtual machine for executing Relay programs.
 */
#ifndef TVM_RUNTIME_VM_H_
#define TVM_RUNTIME_VM_H_

#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace runtime {
namespace vm {

/*!
 * \brief General-purpose integer type of the VM, used for indices into
 *  the function, constant and packed-function tables, counts, and
 *  program-counter offsets.
 */
using Index = int64_t;

/*! \brief Identifies a virtual register within a frame's register file. */
using RegName = int64_t;

/*! \brief An enumeration of Relay's opcodes.
 *
 * The opcode is the tag of the tagged union used to represent
 * an Instruction: it selects which group of operands is active.
 * The numeric values are given explicitly so that each opcode
 * keeps a stable integer encoding.
 */
enum class Opcode {
  Move = 0U,
  Ret = 1U,
  Invoke = 2U,
  InvokeClosure = 3U,
  InvokePacked = 4U,
  AllocTensor = 5U,
  AllocTensorReg = 6U,
  AllocDatatype = 7U,
  AllocClosure = 8U,
  GetField = 9U,
  If = 10U,
  Select = 11U,
  LoadConst = 12U,
  Goto = 13U
};

/*! \brief A single virtual machine instruction.
 *
 * The representation of the instruction is as
 * a tagged union.
 *
 * The first field represents which instruction,
 * and by extension which field of the union
 * is active.
 */
struct Instruction {
  /*! \brief The instruction opcode. */
  Opcode op;

  /*! \brief The destination register. */
  RegName dst;

  union {
    struct /* AllocTensor Operands */ {
87 88 89 90 91 92 93 94
      /*! \brief The number of dimensions. */
      uint32_t ndim;
      /*! \brief The shape of tensor. */
      int64_t* shape;
      /*! \brief The datatype of tensor to be allocated. */
      DLDataType dtype;
    } alloc_tensor;
    struct /* AllocTensorReg Operands */ {
95 96 97 98
      /*! \brief The register to read the shape out of. */
      RegName shape_register;
      /*! \brief The datatype of tensor to be allocated. */
      DLDataType dtype;
99
    } alloc_tensor_reg;
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
    struct /* InvokeClosure Operands */ {
      /*! \brief The register containing the closure. */
      RegName closure;
      /*! \brief The number of arguments to the closure. */
      Index closure_args_num;
      /*! \brief The closure arguments as an array. */
      RegName* closure_args;
    };
    struct /* Return Operands */ {
      /*! \brief The register to return. */
      RegName result;
    };
    struct /* Move Operands */ {
      /*! \brief The source register for a move operation. */
      RegName from;
    };
    struct /* Packed Operands */ {
      /*! \brief The index into the packed function table. */
      Index packed_index;
      /*! \brief The arity of the packed function. */
      Index arity;
      /*! \brief The number of outputs produced by the packed function. */
      Index output_size;
      /*! \brief The arguments to pass to the packed function. */
      RegName* packed_args;
    };
    struct /* Select Operands */ {
      /*! \brief The condition of select. */
      RegName select_cond;
      /*! \brief The true branch. */
      RegName select_op1;
      /*! \brief The false branch. */
      RegName select_op2;
    };
    struct /* If Operands */ {
      /*! \brief The register containing the condition value. */
      RegName if_cond;
      /*! \brief The program counter offset for the true branch. */
      Index true_offset;
      /*! \brief The program counter offset for the false branch. */
      Index false_offset;
    };
    struct /* Invoke Operands */ {
      /*! \brief The function to call. */
      Index func_index;
      /*! \brief The number of arguments to the function. */
      Index num_args;
      /*! \brief The registers containing the arguments. */
      RegName* invoke_args_registers;
    };
    struct /* Const Operands */ {
      /* \brief The index into the constant pool. */
      Index const_index;
    };
    struct /* Jump Operands */ {
      /*! \brief The jump offset. */
      Index pc_offset;
    };
    struct /* Proj Operands */ {
      /*! \brief The register to project from. */
      RegName object;
      /*! \brief The field to read out. */
      Index field_index;
    };
    struct /* AllocDatatype Operands */ {
      /*! \brief The datatype's constructor tag. */
      Index constructor_tag;
      /*! \brief The number of fields to store in the datatype. */
      Index num_fields;
      /*! \brief The fields as an array. */
      RegName* datatype_fields;
    };
    struct /* AllocClosure Operands */ {
      /*! \brief The index into the function table. */
      Index clo_index;
      /*! \brief The number of free variables to capture. */
      Index num_freevar;
      /*! \brief The free variables as an array. */
      RegName* free_vars;
    };
  };

  /*! \brief Construct a select instruction.
   *  \param cond The condition register.
   *  \param op1 The true register.
   *  \param op2 The false register.
   *  \param dst The destination register.
   *  \return The select instruction.
   */
  static Instruction Select(RegName cond, RegName op1, RegName op2, RegName dst);
  /*! \brief Construct a return instruction.
   *  \param return_reg The register containing the return value.
   *  \return The return instruction.
   * */
  static Instruction Ret(RegName return_reg);
  /*! \brief Construct a invoke packed instruction.
   *  \param packed_index The index of the packed function.
   *  \param arity The arity of the function.
   *  \param output_size The number of outputs of the packed function.
   *  \param args The argument registers.
   *  \return The invoke packed instruction.
   */
  static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
                                  const std::vector<RegName>& args);
204 205 206 207 208 209 210 211
  /*! \brief Construct an allocate tensor instruction with constant shape.
   *  \param shape The shape of the tensor.
   *  \param dtype The dtype of the tensor.
   *  \param dst The destination register.
   *  \return The allocate tensor instruction.
   */
  static Instruction AllocTensor(std::vector<int64_t> shape, DLDataType dtype, RegName dst);
  /*! \brief Construct an allocate tensor instruction with register.
212 213 214 215 216
   *  \param shape_register The register containing the shape.
   *  \param dtype The dtype of the tensor.
   *  \param dst The destination register.
   *  \return The allocate tensor instruction.
   */
217
  static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst);
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
  /*! \brief Construct an allocate datatype instruction.
   *  \param tag The datatype tag.
   *  \param num_fields The number of fields for the datatype.
   *  \param fields The registers containing the fields.
   *  \param dst The register name of the destination.
   *  \return The allocate instruction tensor.
   */
  static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector<RegName>& fields,
                                   RegName dst);
  /*! \brief Construct an allocate closure instruction.
   *  \param func_index The index of the function table.
   *  \param num_freevar The number of free variables.
   *  \param free_vars The registers of the free variables.
   *  \param dst The destination register.
   *  \return The allocate closure instruction.
   */
  static Instruction AllocClosure(Index func_index, Index num_freevar,
                                  const std::vector<RegName>& free_vars, RegName dst);
  /*! \brief Construct a get field instruction.
   *  \param object_reg The register containing the object to project from.
   *  \param field_index The field to read out of the object.
   *  \param dst The destination register.
   *  \return The get field instruction.
   */
  static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
  /*! \brief Construct an if instruction.
   *  \param cond_reg The register containing the condition.
   *  \param true_branch The offset to the true branch.
   *  \param false_branch The offset to the false branch.
   *  \return The if instruction.
   */
  static Instruction If(RegName cond_reg, Index true_branch, Index false_branch);
  /*! \brief Construct a goto instruction.
   *  \param pc_offset The offset from the current pc.
   *  \return The goto instruction.
   */
  static Instruction Goto(Index pc_offset);
  /*! \brief Construct an invoke instruction.
   *  \param func_index The index of the function to invoke.
   *  \param args The registers containing the arguments.
   *  \param dst The destination register.
   *  \return The invoke instruction.
   */
  static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst);
  /*! \brief Construct an invoke closure instruction.
   *  \param closure The register of the closure to invoke.
   *  \param args The registers containing the arguments.
   *  \param dst The destination register.
   *  \return The invoke closure instruction.
   */
  static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst);
  /*! \brief Construct a load constant instruction.
   *  \param const_index The index of the constant.
   *  \param dst The destination register.
   *  \return The load constant instruction.
   */
  static Instruction LoadConst(Index const_index, RegName dst);
  /*! \brief Construct a move instruction.
   *  \param src The source register.
   *  \param dst The destination register.
   *  \return The move instruction.
   */
  static Instruction Move(RegName src, RegName dst);

  Instruction();
  Instruction(const Instruction& instr);
284
  Instruction& operator=(const Instruction& instr);
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
  ~Instruction();

  friend std::ostream& operator<<(std::ostream& os, const Instruction&);
};

/*! \brief A representation of a Relay function in the VM.
 *
 * Contains metadata about the compiled function, as
 * well as the compiled VM instructions.
 */
struct VMFunction {
  /*! \brief The function's name. */
  std::string name;
  /*! \brief The number of function parameters. */
  Index params{0};
  /*! \brief The instructions representing the function. */
  std::vector<Instruction> instructions;
  /*! \brief The size of the frame (register file) for this function. */
  Index register_file_size{0};

  /*!
   * \brief Construct a VM function.
   * \param name The function's name.
   * \param params The number of function parameters.
   * \param instructions The instructions making up the body (copied).
   * \param register_file_size The size of the frame for this function.
   */
  VMFunction(const std::string& name, Index params,
             const std::vector<Instruction>& instructions,
             Index register_file_size)
      : name(name),
        params(params),
        instructions(instructions),
        register_file_size(register_file_size) {}

  /*!
   * \brief Construct an empty function.
   *
   * The numeric fields come from their default member initializers (zero),
   * so a default-constructed VMFunction never holds indeterminate values.
   */
  VMFunction() {}

  friend std::ostream& operator<<(std::ostream& os, const VMFunction&);
};

/*! \brief One entry of the VM call stack.
 *
 * Holds what is needed to resume the caller once the callee
 * returns (return pc, caller function index, caller code pointer),
 * together with the register storage used by the frame.
 */
struct VMFrame {
  /*! \brief Program counter at which execution resumes on return. */
  Index pc;
  /*! \brief Function-table index of the caller. */
  Index func_index;
  /*! \brief How many arguments were passed to the call. */
  Index args;
  /*! \brief Pointer into the caller function's instruction array. */
  const Instruction* code;

  /*! \brief Register storage for this frame, sized to `register_file_size` on construction. */
  std::vector<Object> register_file;

  /*! \brief Register of the caller's frame that receives the return value (0 unless set). */
  RegName caller_return_register = 0;

  VMFrame(Index pc, Index func_index, Index args, const Instruction* code, Index register_file_size)
      : pc(pc),
        func_index(func_index),
        args(args),
        code(code),
        register_file(register_file_size) {}
};

/*! \brief The virtual machine.
 *
 * Bundles every piece of execution state: the function table, the
 * packed (compiled operator) table, the global constant pool, the call
 * stack, and the current code pointer / program counter.
 *
 * The goal is to have a single self-contained object, enabling one to
 * easily pass VMs around, execute them on multiple threads, or
 * serialize them to disk or over the wire.
 */
struct VirtualMachine {
  /*! \brief Table of packed functions (compiled operators). */
  std::vector<PackedFunc> packed_funcs;
  /*! \brief Table of VM functions. */
  std::vector<VMFunction> functions;
  /*! \brief Active call frames, innermost last. */
  std::vector<VMFrame> frames;
  /*! \brief Global pool of constants. */
  std::vector<Object> constants;
  /*! \brief Function-table index of the function currently executing. */
  Index func_index;
  /*! \brief Pointer to the instructions currently executing. */
  const Instruction* code;
  /*! \brief Program counter of the VM. */
  Index pc;

  /*! \brief Dedicated register holding a return value. */
  Object return_register;

  /*! \brief TVM contexts the VM currently executes on. */
  std::vector<TVMContext> ctxs;

  /*! \brief Create an empty VM: no functions, no frames, `code` null, pc 0. */
  VirtualMachine() : func_index(0), code(nullptr), pc(0) {}

  /*! \brief Push a new call frame.
   *  \param arg_count The number of arguments passed.
   *  \param ret_pc The program counter to resume at when the frame is popped.
   *  \param vm_func The function being entered.
   */
  void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
  /*! \brief Remove the innermost call frame.
   *  \return The number of frames remaining.
   */
  Index PopFrame();

  /*! \brief Store an object into a VM register.
   *  \param reg Destination register.
   *  \param obj Object to store.
   *  \note Declared inline: its definition must be visible wherever it is called.
   */
  inline void WriteRegister(RegName reg, const Object& obj);

  /*! \brief Load an object from a VM register.
   *  \param reg Source register.
   *  \return The object held in the register.
   *  \note Declared inline: its definition must be visible wherever it is called.
   */
  inline Object ReadRegister(RegName reg) const;

  /*! \brief Run a VM function.
   * \param func Function to run.
   * \param args Arguments passed to it.
   * \return Object holding the result.
   */
  Object Invoke(const VMFunction& func, const std::vector<Object>& args);

  // TODO(@jroesch): I really would like this to be a global variable.
  /*! \brief Run a VM function looked up by name.
   * \param name Name of the function.
   * \param args Arguments passed to it.
   * \return Object holding the result.
   */
  Object Invoke(const std::string& name, const std::vector<Object>& args);

  /*! \brief Prepare the VM to run on the given contexts.
   *  \param contexts The TVM contexts to use.
   */
  void Init(const std::vector<TVMContext>& contexts);
  /*! \brief Execute starting from the current state (`code`, `pc`).
   *  NOTE(review): the stopping condition is defined in the implementation.
   */
  void Run();

  /*! \brief Maps each global's name to its index in `functions`. */
  std::unordered_map<std::string, Index> global_map_;

 private:
  /*! \brief Set up VM state so that `func` is ready to execute with `args`.
   *
   * Execution itself is not started here.
   */
  void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
};

}  // namespace vm
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_VM_H_