stackvm.h 10.3 KB
Newer Older
1 2
/*!
 *  Copyright (c) 2016 by Contributors
 * \file stackvm.h
 * \brief A simple stack-based virtual machine.
 *
 *  This can be used to interpret host side code
 *  to setup calls into device functions
 *  when only Runtime compilation for device is available(via NVRTC or OpenCL).
 */
10 11
#ifndef TVM_RUNTIME_STACKVM_STACKVM_H_
#define TVM_RUNTIME_STACKVM_STACKVM_H_
12 13

#include <tvm/runtime/c_runtime_api.h>
14
#include <tvm/runtime/packed_func.h>
15
#include <tvm/runtime/module.h>
16 17 18 19
#include <string>
#include <vector>

namespace tvm {
20
namespace runtime {
21

22
using runtime::operator<<;
23
/*!
 * \brief A simple stack-based virtual machine program.
 */
class StackVM {
 public:
  /*!
29
   * \brief Invoke the StackVM program.
30
   * \param args The arguments to the StackVM.
31
   * \param mod_ctx The module context used in running.
32
   */
33
  void Run(const TVMArgs& args, runtime::ModuleNode* mod_ctx) const;
34
  /*!
   * \brief The opcode of stack vm
   *
   *  The numeric value of each opcode is its position in this list,
   *  so enumerators must only ever be appended, never reordered.
   *
   * \note Notation
   *  - sp Stack pointer
   *  - pc Program pointer
   */
  enum OpCode {
    // integer ops
    ADD_I64,
    SUB_I64,
    MUL_I64,
    DIV_I64,
    MOD_I64,
    EQ_I64,
    LT_I64,
    LE_I64,
    // floating ops
    ADD_F64,
    SUB_F64,
    MUL_F64,
    DIV_F64,
    EQ_F64,
    LT_F64,
    LE_F64,
    // Pointer comparison
    EQ_HANDLE,
    /*!
     * \brief Routine to load data from address with const offset.
     * \code
     *  stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int];
     *  pc = pc + 2;
     * \endcode
     */
    ARRAY_LOAD_UINT32,
    ARRAY_LOAD_INT32,
    ARRAY_LOAD_INT64,
    ARRAY_LOAD_FP64,
    ARRAY_LOAD_HANDLE,
    ARRAY_LOAD_TVMVALUE,
    /*!
     * \brief Routine to store data to address with const offset.
     * \code
     *  ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp];
     *  pc = pc + 2;
     *  sp = sp - 2;
     * \endcode
     */
    ARRAY_STORE_UINT32,
    ARRAY_STORE_INT32,
    ARRAY_STORE_INT64,
    ARRAY_STORE_FP64,
    ARRAY_STORE_HANDLE,
    ARRAY_STORE_TVMVALUE,
    // logical ops
    NOT,
    /*!
     * \brief Add address by an offset.
     * \code
     *  stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64);
     *  sp = sp - 1;
     * \endcode
     */
    ADDR_ADD,
    /*!
     * \brief push integer fetched from next pc position into stack
     * \code
     *  stack[sp + 1].v_int64 = code[pc + 1].v_int;
     *  pc = pc + 2;
     *  sp = sp + 1;
     * \endcode
     */
    PUSH_I64,
    /*!
     * \brief push a value given relative index on the stack
     * \code
     *  stack[sp + 1] = stack[sp + code[pc + 1].v_int];
     *  pc = pc + 2;
     *  sp = sp + 1;
     * \endcode
     */
    PUSH_VALUE,
    /*!
     * \brief Load data from heap to top of stack
     * \code
     *  stack[sp + 1] = heap[code[pc + 1].v_int];
     *  pc = pc + 2;
     *  sp = sp + 1;
     * \endcode
     */
    LOAD_HEAP,
    /*!
     * \brief Store data to heap
     * \code
     *  heap[code[pc + 1].v_int] = stack[sp];
     *  sp = sp - 1;
     * \endcode
     */
    STORE_HEAP,
    /*! \brief pop value from top of the stack */
    POP,
    /*!
     * \brief select based on operands.
     * \code
     *  stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]
     *  sp = sp - 2;
     * \endcode
     */
    SELECT,
    /*!
     * \brief Assert condition is true.
     * \code
     *  CHECK(stack[sp]) << str_data[code[pc + 1].v_int];
     *  sp = sp - 1;
     * \endcode
     */
    ASSERT,
    /*!
     * \brief Relative Jump if the condition is true,
     *  Does not change the stack status.
     * \code
     *  if (stack[sp]) {
     *    pc += code[pc + 1].v_int
     *  } else {
     *    pc = pc + 2;
     *  }
     * \endcode
     */
    RJUMP_IF_TRUE,
    /*!
     * \brief Relative Jump if the condition is false,
     *  Does not change the stack status.
     * \code
     *  if (!stack[sp]) {
     *    pc += code[pc + 1].v_int
     *  } else {
     *    pc = pc + 2;
     *  }
     * \endcode
     */
    RJUMP_IF_FALSE,
    /*!
     * \brief Relative jump to a location.
     * \code
     *  pc += code[pc + 1].v_int;
     * \endcode
     */
    RJUMP,
    /*!
     * \brief debug instruction.
     * \code
     *  CHECK_EQ(sp, code[pc + 1].v_int);
     *  pc += 2;
     * \endcode
     */
    ASSERT_SP,
    /*!
     * \brief call an extern packed function
     * \code
     *  value_stack = stack[sp - 1].v_handle;
     *  type_stack = stack[sp - 0].v_handle;
     *  call_fid = code[pc + 1].v_int;
     *  begin = code[pc + 2].v_int;
     *  end = code[pc + 3].v_int;
     *  num_args = end - begin - 1;
     *  f = extern_func[call_fid];
     *  stack[sp - 1] = f(&value_stack[begin:end-1], type_stack[begin:end-1], num_args);
     *  sp = sp - 1;
     *  // The type codes are hidden in the code space.
     *  pc = pc + 4
     * \endcode
     */
    CALL_PACKED_LOWERED,
    // Allocate things on stack
    /*!
     * \brief allocate data from stack.
     * \code
     *  num = code[pc + 1].v_int;
     *  void* addr = &stack[sp];
     *  sp = sp + num;
     *  stack[sp].v_handle = addr;
     *  pc = pc + 1;
     * \endcode
     */
    TVM_STACK_ALLOCA_BY_8BYTE,
    /*!
     * \brief allocate data from device.
     * \code
     *  device_type = stack[sp - 2].v_int64;
     *  device_id = stack[sp - 1].v_int64;
     *  nbytes = stack[sp].v_int64;
     *  stack[sp - 2].v_handle = device_alloca(device_type, device_id, nbytes);
     *  sp = sp - 2;
     *  pc = pc + 1;
     * \endcode
     */
    TVM_DEVICE_ALLOCA,
    /*!
     * \brief free data into device.
     * \code
     *  device_type = stack[sp - 2].v_int64;
     *  device_id = stack[sp - 1].v_int64;
     *  ptr = stack[sp].v_handle;
     *  stack[sp - 2].v_int64 = device_free(device_type, device_id, ptr);
     *  sp = sp - 2;
     *  pc = pc + 1;
     * \endcode
     */
    TVM_DEVICE_FREE,
    /*!
     * \brief throw last error
     */
    TVM_THROW_LAST_ERROR,
    /*!
     * \brief get data from structure.
     * \code
     *  index = code[pc + 1].v_int;
     *  field = code[pc + 2].v_int;
     *  stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field;
     *  pc = pc + 3
     * \endcode
     */
    TVM_STRUCT_GET,
    /*!
     * \brief set data into structure.
     * \code
     *  index = code[pc + 1].v_int;
     *  field = code[pc + 2].v_int;
     *  ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp];
     *  pc = pc + 3
     *  sp = sp - 1
     * \endcode
     */
    TVM_STRUCT_SET
  };
  /*!
   * \brief The code structure
   *
   *  One slot of the instruction stream (see the `code` vector of StackVM).
   *  A slot is read either as an opcode or as an integer operand; the opcode
   *  documentation refers to operands as code[pc + k].v_int.
   */
  union Code {
    // The slot interpreted as an instruction opcode.
    OpCode op_code;
    // The slot interpreted as an integer operand.
    int v_int;
  };
  /*! \brief The state object of StackVM */
  struct State {
    /*! \brief The execution stack */
    std::vector<TVMValue> stack;
    /*! \brief The global heap space */
    std::vector<TVMValue> heap;
    /*! \brief stack pointer  */
    int64_t sp{0};
    /*! \brief program counter */
    int64_t pc{0};
283 284
    /*! \brief The current module context of stackvm */
    runtime::ModuleNode* mod_ctx{nullptr};
285
  };
286 287 288 289 290 291 292 293 294 295 296 297
  /*! \brief Initialize local cache*/
  void InitCache();
  /*!
   * \brief Save stackvm program to an output stream
   * \param strm The output stream
   */
  void Save(dmlc::Stream* strm) const;
  /*!
   * \brief Load stackvm program from output stream
   * \param strm The output stream
   */
  bool Load(dmlc::Stream* strm);
298 299 300 301 302 303 304 305 306
  /*!
   * \brief Print instruction at location pc
   * \param os The ostream
   * \param pc The pc
   * \return the pc to next instruction.
   */
  int64_t PrintCode(std::ostream&os, int64_t pc) const;  // NOLINT(*)
  /*! \brief Get thread local state of the stack VM */
  static State* ThreadLocalState();
307
  // The code below are programs
308 309 310 311
  /*! \brief The instructions */
  std::vector<Code> code;
  /*! \brief constant error messages */
  std::vector<std::string> str_data;
312 313 314
  /*! \brief Extern functions */
  std::vector<std::string> extern_func_name;
  /*! \brief name of each heap id */
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333
  std::vector<std::string> heap_id_name;
  /*! \brief The memory size needed */
  size_t heap_size{0};
  /*! \brief The stack size required */
  size_t stack_size{1024};
  /*!
   * \brief Convert I64 opcode to F64 Ones
   *
   *  Maps the integer arithmetic/comparison opcodes (ADD, SUB, MUL, DIV,
   *  EQ, LT, LE) to their F64 counterparts. MOD_I64 has no F64 counterpart
   *  and is reported through LOG(FATAL), as is any other opcode.
   *
   * \param code The op code.
   * \return the F64 op code.
   */
  static OpCode CodeI64ToF64(OpCode code) {
    switch (code) {
      case ADD_I64: return ADD_F64;
      case SUB_I64: return SUB_F64;
      case MUL_I64: return MUL_F64;
      case DIV_I64: return DIV_F64;
      case EQ_I64: return EQ_F64;
      case LT_I64: return LT_F64;
      case LE_I64: return LE_F64;
      // NOTE(review): LOG(FATAL) is expected not to return; the trailing
      // returns below only give every path a return value.
      case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
      default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
    }
  }
  /*!
   * \brief Get load opcode for type t
   * \param t the type code.
   * \return The load opcode
   */
343 344
  static OpCode GetLoad(TVMType t) {
    CHECK_EQ(t.lanes, 1U);
345
    if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
346
    if (t.code == kDLInt) {
347
      switch (t.bits) {
348 349
        case 32 : return ARRAY_LOAD_INT32;
        case 64 : return ARRAY_LOAD_INT64;
350
      }
351
    } else if (t.code == kDLUInt) {
352
      switch (t.bits) {
353
        case 32 : return ARRAY_LOAD_UINT32;
354
      }
355
    } else if (t.code == kDLFloat) {
356
      switch (t.bits) {
357
        case 64 : return ARRAY_LOAD_FP64;
358 359 360
      }
    }
    LOG(FATAL) << "Cannot load type " << t;
361
    return ARRAY_LOAD_FP64;
362 363 364 365 366 367
  }
  /*!
   * \brief Get store opcode for type t
   * \param t the type code.
   * \return The load opcode
   */
368 369
  static OpCode GetStore(TVMType t) {
    CHECK_EQ(t.lanes, 1U);
370
    if (t.code == kHandle) return ARRAY_STORE_HANDLE;
371
    if (t.code == kDLInt) {
372
      switch (t.bits) {
373 374 375
        case 32 : return ARRAY_STORE_INT32;
        case 64 : return ARRAY_STORE_INT64;
      }
376
    } else if (t.code == kDLUInt) {
377 378 379
      switch (t.bits) {
        case 32 : return ARRAY_STORE_UINT32;
      }
380
    } else if (t.code == kDLFloat) {
381 382
      switch (t.bits) {
        case 64 : return ARRAY_STORE_FP64;
383 384 385
      }
    }
    LOG(FATAL) << "Cannot store type " << t;
386
    return ARRAY_STORE_FP64;
387 388
  }
  friend std::ostream& operator<<(std::ostream& os, const StackVM& vm);  // NOLINT(*)
389 390

 private:
391 392
  //  execute the stack vm with given state
  void Run(State* state) const;
393 394
  // get extern function.
  const PackedFunc& GetExtern(State* s, int fid) const;
395 396
  // cached extern function
  mutable std::vector<PackedFunc> extern_func_cache_;
397 398
};

399
}  // namespace runtime
400
}  // namespace tvm
401 402 403 404 405

// Declare the dmlc `has_saveload` trait as true for StackVM.
// NOTE(review): presumably this lets dmlc serialization route through
// StackVM::Save / StackVM::Load — confirm against dmlc-core.
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::StackVM, true);
}
#endif  // TVM_RUNTIME_STACKVM_STACKVM_H_