stackvm.h 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21
/*!
 *  Copyright (c) 2016 by Contributors
 * \file stackvm.h
 * \brief A simple stack-based virtual machine.
 *
 *  This can be used to interpret host side code
 *  to set up calls into device functions
 *  when only runtime compilation for device is available (via NVRTC or OpenCL).
 */
29 30
#ifndef TVM_RUNTIME_STACKVM_STACKVM_H_
#define TVM_RUNTIME_STACKVM_STACKVM_H_
31 32

#include <tvm/runtime/c_runtime_api.h>
33
#include <tvm/runtime/packed_func.h>
34
#include <tvm/runtime/module.h>
35 36 37 38
#include <string>
#include <vector>

namespace tvm {
39
namespace runtime {
40

41
using runtime::operator<<;
42
/*!
 * \brief A simple stack-based virtual machine program.
 */
class StackVM {
 public:
  /*!
48
   * \brief Invoke the StackVM program.
49
   * \param args The arguments to the StackVM.
50
   * \param mod_ctx The module context used in running.
51
   */
52
  void Run(const TVMArgs& args, runtime::ModuleNode* mod_ctx) const;
53
  /*!
   * \brief The opcode of stack vm
   * \note Notation
   *  - sp Stack pointer
   *  - pc Program pointer
   *
   *  Enumerators rely on implicit numbering, so their values follow
   *  declaration order.
   *  NOTE(review): StackVM programs can be written and read back through
   *  Save/Load; if opcodes are stored by value, reordering or inserting
   *  enumerators would invalidate stored programs - confirm before changing.
   */
  enum OpCode {
    // integer ops
    ADD_I64,
    SUB_I64,
    MUL_I64,
    DIV_I64,
    MOD_I64,
    EQ_I64,
    LT_I64,
    LE_I64,
    // floating ops
    ADD_F64,
    SUB_F64,
    MUL_F64,
    DIV_F64,
    EQ_F64,
    LT_F64,
    LE_F64,
    // Pointer comparison
    EQ_HANDLE,
    /*!
     * \brief Routine to load data from address with const offset.
     * \code
     *  stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int];
     *  pc = pc + 2;
     * \endcode
     */
    ARRAY_LOAD_UINT32,
    ARRAY_LOAD_INT32,
    ARRAY_LOAD_INT64,
    ARRAY_LOAD_FP64,
    ARRAY_LOAD_HANDLE,
    ARRAY_LOAD_TVMVALUE,
    /*!
     * \brief Routine to store data to address with const offset.
     * \code
     *  ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp];
     *  pc = pc + 2;
     *  sp = sp - 2;
     * \endcode
     */
    ARRAY_STORE_UINT32,
    ARRAY_STORE_INT32,
    ARRAY_STORE_INT64,
    ARRAY_STORE_FP64,
    ARRAY_STORE_HANDLE,
    ARRAY_STORE_TVMVALUE,
    // logical ops
    NOT,
    /*!
     * \brief Add address by an offset.
     * \code
     *  stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64);
     *  sp = sp - 1;
     * \endcode
     */
    ADDR_ADD,
    /*!
     * \brief push integer fetched from next pc position into stack
     * \code
     *  stack[sp + 1].v_int64 = code[pc + 1].v_int;
     *  pc = pc + 2;
     *  sp = sp + 1;
     * \endcode
     */
    PUSH_I64,
    /*!
     * \brief push a value given relative index on the stack
     * \code
     *  stack[sp + 1] = stack[sp + code[pc + 1].v_int];
     *  pc = pc + 2;
     *  sp = sp + 1;
     * \endcode
     */
    PUSH_VALUE,
    /*!
     * \brief Load data from heap to top of stack
     * \code
     *  stack[sp + 1] = heap[code[pc + 1].v_int];
     *  pc = pc + 2;
     *  sp = sp + 1;
     * \endcode
     */
    LOAD_HEAP,
    /*!
     * \brief Store data to heap
     * \code
     *  heap[code[pc + 1].v_int] = stack[sp];
     *  sp = sp - 1;
     * \endcode
     */
    STORE_HEAP,
    /*! \brief pop value from top of the stack */
    POP,
    /*!
     * \brief select based on operands.
     * \code
     *  stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]
     *  sp = sp - 2;
     * \endcode
     */
    SELECT,
    /*!
     * \brief Assert condition is true.
     * \code
     *  CHECK(stack[sp]) << str_data[code[pc + 1].v_int];
     *  sp = sp - 1;
     * \endcode
     */
    ASSERT,
    /*!
     * \brief Relative Jump if the condition is true,
     *  Does not change the stack status.
     * \code
     *  if (stack[sp]) {
     *    pc += code[pc + 1].v_int
     *  } else {
     *    pc = pc + 2;
     *  }
     * \endcode
     */
    RJUMP_IF_TRUE,
    /*!
     * \brief Relative Jump if the condition is false,
     *  Does not change the stack status.
     * \code
     *  if (!stack[sp]) {
     *    pc += code[pc + 1].v_int
     *  } else {
     *    pc = pc + 2;
     *  }
     * \endcode
     */
    RJUMP_IF_FALSE,
    /*!
     * \brief Relative jump to a location.
     * \code
     *  pc += code[pc + 1].v_int;
     * \endcode
     */
    RJUMP,
    /*!
     * \brief debug instruction.
     * \code
     *  CHECK_EQ(sp, code[pc + 1].v_int);
     *  pc += 2;
     * \endcode
     */
    ASSERT_SP,
    /*!
     * \brief call an extern packed function
     * \code
     *  value_stack = stack[sp - 1].v_handle;
     *  type_stack = stack[sp - 0].v_handle;
     *  call_fid = code[pc + 1].v_int;
     *  begin = code[pc + 2].v_int;
     *  end = code[pc + 3].v_int;
     *  num_args = end - begin - 1;
     *  f = extern_func[call_fid];
     *  stack[sp - 1] = f(&value_stack[begin:end-1], type_stack[begin:end-1], num_args);
     *  sp = sp - 1;
     *  // The type codes are hidden in the code space.
     *  pc = pc + 4
     * \endcode
     */
    CALL_PACKED_LOWERED,
    // Allocate things on stack
    /*!
     * \brief allocate data from stack.
     * \code
     *  num = code[pc + 1].v_int;
     *  void* addr = &stack[sp];
     *  sp = sp + num;
     *  stack[sp].v_handle = addr;
     *  pc = pc + 1;
     * \endcode
     */
    TVM_STACK_ALLOCA_BY_8BYTE,
    /*!
     * \brief allocate data from device.
     * \code
     *  device_type = stack[sp - 2].v_int64;
     *  device_id = stack[sp - 1].v_int64;
     *  nbytes = stack[sp].v_int64;
     *  stack[sp - 2].v_handle = device_alloca(device_type, device_id, nbytes);
     *  sp = sp - 2;
     *  pc = pc + 1;
     * \endcode
     */
    TVM_DEVICE_ALLOCA,
    /*!
     * \brief free data on device.
     * \code
     *  device_type = stack[sp - 2].v_int64;
     *  device_id = stack[sp - 1].v_int64;
     *  ptr = stack[sp].v_handle;
     *  stack[sp - 2].v_int64 = device_free(device_type, device_id, ptr);
     *  sp = sp - 2;
     *  pc = pc + 1;
     * \endcode
     */
    TVM_DEVICE_FREE,
    /*!
     * \brief throw last error
     */
    TVM_THROW_LAST_ERROR,
    /*!
     * \brief get data from structure.
     * \code
     *  index = code[pc + 1].v_int;
     *  field = code[pc + 2].v_int;
     *  stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field;
     *  pc = pc + 3
     * \endcode
     */
    TVM_STRUCT_GET,
    /*!
     * \brief set data into structure.
     * \code
     *  index = code[pc + 1].v_int;
     *  field = code[pc + 2].v_int;
     *  ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp];
     *  pc = pc + 3
     *  sp = sp - 1
     * \endcode
     */
    TVM_STRUCT_SET
  };
  /*!
   * \brief The code structure
   *
   *  One slot of the program held in the `code` vector. The OpCode
   *  notes read instruction operands through this union as
   *  code[pc + k].v_int.
   */
  union Code {
    /*! \brief The slot viewed as an opcode */
    OpCode op_code;
    /*! \brief The slot viewed as an integer operand */
    int v_int;
  };
  /*! \brief The state object of StackVM */
  struct State {
    /*! \brief The execution stack */
    std::vector<TVMValue> stack;
    /*! \brief The global heap space */
    std::vector<TVMValue> heap;
    /*! \brief stack pointer  */
    int64_t sp{0};
    /*! \brief program counter */
    int64_t pc{0};
302 303
    /*! \brief The current module context of stackvm */
    runtime::ModuleNode* mod_ctx{nullptr};
304
  };
305 306 307 308 309 310 311 312 313 314 315 316
  /*! \brief Initialize local cache*/
  void InitCache();
  /*!
   * \brief Save stackvm program to an output stream
   * \param strm The output stream
   */
  void Save(dmlc::Stream* strm) const;
  /*!
   * \brief Load stackvm program from output stream
   * \param strm The output stream
   */
  bool Load(dmlc::Stream* strm);
317 318 319 320 321 322 323 324 325
  /*!
   * \brief Print instruction at location pc
   * \param os The ostream
   * \param pc The pc
   * \return the pc to next instruction.
   */
  int64_t PrintCode(std::ostream&os, int64_t pc) const;  // NOLINT(*)
  /*! \brief Get thread local state of the stack VM */
  static State* ThreadLocalState();
326
  // The code below are programs
327 328 329 330
  /*! \brief The instructions */
  std::vector<Code> code;
  /*! \brief constant error messages */
  std::vector<std::string> str_data;
331 332 333
  /*! \brief Extern functions */
  std::vector<std::string> extern_func_name;
  /*! \brief name of each heap id */
334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352
  std::vector<std::string> heap_id_name;
  /*! \brief The memory size needed */
  size_t heap_size{0};
  /*! \brief The stack size required */
  size_t stack_size{1024};
  /*!
   * \brief Convert an I64 opcode to the matching F64 one.
   *
   *  ADD/SUB/MUL/DIV/EQ/LT/LE map to their F64 counterparts. MOD_I64 has
   *  no F64 counterpart, and any other opcode is not an I64 arithmetic or
   *  comparison op; both cases are reported through LOG(FATAL).
   *
   * \param code The op code.
   * \return the F64 op code.
   */
  static OpCode CodeI64ToF64(OpCode code) {
    switch (code) {
      case ADD_I64: return ADD_F64;
      case SUB_I64: return SUB_F64;
      case MUL_I64: return MUL_F64;
      case DIV_I64: return DIV_F64;
      case EQ_I64: return EQ_F64;
      case LT_I64: return LT_F64;
      case LE_I64: return LE_F64;
      // The returns after LOG(FATAL) only give every path a value;
      // LOG(FATAL) is expected not to continue past the log statement.
      case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64;
      default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64;
    }
  }
  /*!
   * \brief Get load opcode for type t
   * \param t the type code.
   * \return The load opcode
   */
362 363
  static OpCode GetLoad(TVMType t) {
    CHECK_EQ(t.lanes, 1U);
364
    if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
365
    if (t.code == kDLInt) {
366
      switch (t.bits) {
367 368
        case 32 : return ARRAY_LOAD_INT32;
        case 64 : return ARRAY_LOAD_INT64;
369
      }
370
    } else if (t.code == kDLUInt) {
371
      switch (t.bits) {
372
        case 32 : return ARRAY_LOAD_UINT32;
373
      }
374
    } else if (t.code == kDLFloat) {
375
      switch (t.bits) {
376
        case 64 : return ARRAY_LOAD_FP64;
377 378 379
      }
    }
    LOG(FATAL) << "Cannot load type " << t;
380
    return ARRAY_LOAD_FP64;
381 382 383 384 385 386
  }
  /*!
   * \brief Get store opcode for type t
   * \param t the type code.
   * \return The load opcode
   */
387 388
  static OpCode GetStore(TVMType t) {
    CHECK_EQ(t.lanes, 1U);
389
    if (t.code == kHandle) return ARRAY_STORE_HANDLE;
390
    if (t.code == kDLInt) {
391
      switch (t.bits) {
392 393 394
        case 32 : return ARRAY_STORE_INT32;
        case 64 : return ARRAY_STORE_INT64;
      }
395
    } else if (t.code == kDLUInt) {
396 397 398
      switch (t.bits) {
        case 32 : return ARRAY_STORE_UINT32;
      }
399
    } else if (t.code == kDLFloat) {
400 401
      switch (t.bits) {
        case 64 : return ARRAY_STORE_FP64;
402 403 404
      }
    }
    LOG(FATAL) << "Cannot store type " << t;
405
    return ARRAY_STORE_FP64;
406 407
  }
  friend std::ostream& operator<<(std::ostream& os, const StackVM& vm);  // NOLINT(*)
408 409

 private:
410 411
  //  execute the stack vm with given state
  void Run(State* state) const;
412 413
  // get extern function.
  const PackedFunc& GetExtern(State* s, int fid) const;
414 415
  // cached extern function
  mutable std::vector<PackedFunc> extern_func_cache_;
416 417
};

418
}  // namespace runtime
419
}  // namespace tvm
420 421 422 423 424

namespace dmlc {
// Declare the dmlc `has_saveload` trait as true for StackVM, which
// declares Save/Load members taking a dmlc::Stream*.
// NOTE(review): presumably this lets dmlc stream serialization dispatch
// to those members - confirm against dmlc-core.
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::StackVM, true);
}
#endif  // TVM_RUNTIME_STACKVM_STACKVM_H_