/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * Implementation stack VM.
 * \file stackvm.cc
 */
#include <dmlc/thread_local.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/c_backend_api.h>
#include <algorithm>
#include "stackvm.h"

namespace tvm {
namespace runtime {

/*! \brief Per-thread store holding the interpreter state used by StackVM::Run. */
using StackVMStateStore = dmlc::ThreadLocalStore<StackVM::State>;

/*!
 * \brief Fetch the calling thread's interpreter state.
 * \return Pointer obtained from StackVMStateStore::Get() for this thread.
 */
StackVM::State* StackVM::ThreadLocalState() {
  StackVM::State* per_thread_state = StackVMStateStore::Get();
  return per_thread_state;
}

// Execution helpers used inside StackVM::Run. They expect the locals `stack`,
// `sp`, `pc` and the member `code` to be in scope.

// Binary arithmetic: combines stack[sp - 1] and stack[sp] with OP on the
// TVMValue member FIELD, stores the result in stack[sp - 1], pops one slot
// and advances pc past the single-word instruction.
#define STACK_VM_BINOP(OP, FIELD)                                 \
  {                                                               \
    stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \
    sp -= 1; pc += 1;                                             \
  }

// Comparison: like STACK_VM_BINOP, but the boolean result is always written
// to the v_int64 member regardless of the operand member FIELD.
#define STACK_VM_CMPOP(OP, FIELD)                                   \
  {                                                                 \
    stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
    sp -= 1; pc += 1;                                               \
  }

// Indexed load: treats stack[sp].v_handle as a SRC_TYPE array, reads the
// element at the index stored in the operand word code[pc + 1], converts it
// to DST_TYPE and overwrites the top of stack in place (sp is unchanged).
// FIELD is pasted after `stack[sp]`, so it carries its own leading dot
// (e.g. `.v_int64`); it is empty when the whole TVMValue is copied.
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE)                        \
  {                                                                     \
    int index = code[pc + 1].v_int;                                     \
    stack[sp]FIELD = static_cast<DST_TYPE>(                             \
        static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]);             \
    pc += 2;                                                            \
  }

// Indexed store: treats stack[sp - 1].v_handle as a DST_TYPE array and writes
// the value taken from stack[sp] (member FIELD, same convention as
// STACK_VM_LOAD) at the index in code[pc + 1]. Pops both the handle and the
// value, and skips the instruction plus its operand word.
#define STACK_VM_STORE(FIELD, DST_TYPE)                                 \
  {                                                                     \
    int index = code[pc + 1].v_int;                                     \
    static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] =             \
        static_cast<DST_TYPE>(stack[sp]FIELD);                          \
    sp -= 2; pc += 2;                                                   \
  }

// Disassembly helpers used inside the switch of StackVM::PrintCode. Each one
// expands to a `case CODE:` label whose block writes the instruction to `os`
// and returns the pc of the following instruction. Every operand word gets
// its own "[index]" line, so the listing has one line per code word.

// Instruction with no operand words.
#define STACK_VM_PRINT_CODE0(CODE)                            \
  case CODE:  {                                                     \
    os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1;  \
  }

// Instruction with one integer operand word.
#define STACK_VM_PRINT_CODE1(CODE)                                      \
  case CODE:  {                                                         \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
       <<  "[" << pc + 1 << "]" << std::endl;                           \
        return pc + 2;                                                  \
  }

// Instruction with two integer operand words.
#define STACK_VM_PRINT_CODE2(CODE)                                      \
  case CODE:  {                                                         \
    os << "[" << pc << "]\t" << #CODE                                   \
        << " " << code[pc + 1].v_int                                    \
        << " " << code[pc + 2].v_int << "\n"                            \
       <<  "[" << pc + 1 << "]" << std::endl                            \
       <<  "[" << pc + 2 << "]" << std::endl;                           \
        return pc + 3;                                                  \
  }

// Heap load/store: prints the heap slot index and its name from heap_id_name.
#define STACK_VM_PRINT_HEAP_ACCESS(CODE)                                \
  case CODE:  {                                                         \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int      \
       << " " << heap_id_name[code[pc + 1].v_int] << "\n"               \
       <<  "[" << pc + 1 << "]" << std::endl;                           \
        return pc + 2;                                                  \
  }

// Relative jump: prints the offset and the absolute target pc + offset.
#define STACK_VM_PRINT_JUMP(CODE)                                     \
  case CODE:  {                                                        \
    os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \
       << " to " << pc + code[pc + 1].v_int << '\n'                    \
       << "[" << pc + 1 << "]" << std::endl;                         \
    return pc + 2;                                                     \
  }


int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
  switch (code[pc].op_code) {
    // int
    STACK_VM_PRINT_CODE0(ADD_I64);
    STACK_VM_PRINT_CODE0(SUB_I64);
    STACK_VM_PRINT_CODE0(MUL_I64);
    STACK_VM_PRINT_CODE0(MOD_I64);
    STACK_VM_PRINT_CODE0(DIV_I64);
    STACK_VM_PRINT_CODE0(EQ_I64);
    STACK_VM_PRINT_CODE0(LT_I64);
    STACK_VM_PRINT_CODE0(LE_I64);
    // floats
    STACK_VM_PRINT_CODE0(ADD_F64);
    STACK_VM_PRINT_CODE0(SUB_F64);
    STACK_VM_PRINT_CODE0(MUL_F64);
    STACK_VM_PRINT_CODE0(DIV_F64);
    STACK_VM_PRINT_CODE0(EQ_F64);
    STACK_VM_PRINT_CODE0(LT_F64);
    STACK_VM_PRINT_CODE0(LE_F64);
    // handle.
    STACK_VM_PRINT_CODE0(EQ_HANDLE);
    // addressing load
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT32);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT64);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_FP64);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_HANDLE);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_TVMVALUE);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_UINT32);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_INT32);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_INT64);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_FP64);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_HANDLE);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_TVMVALUE);
    STACK_VM_PRINT_CODE0(NOT);
    STACK_VM_PRINT_CODE0(ADDR_ADD);
    // stack ops
    STACK_VM_PRINT_CODE1(PUSH_I64);
    STACK_VM_PRINT_CODE1(PUSH_VALUE);
    STACK_VM_PRINT_CODE0(POP);
    STACK_VM_PRINT_CODE0(SELECT);
    STACK_VM_PRINT_HEAP_ACCESS(STORE_HEAP);
    STACK_VM_PRINT_HEAP_ACCESS(LOAD_HEAP);
    STACK_VM_PRINT_CODE1(ASSERT);
    STACK_VM_PRINT_JUMP(RJUMP_IF_TRUE);
    STACK_VM_PRINT_JUMP(RJUMP_IF_FALSE);
    STACK_VM_PRINT_JUMP(RJUMP);
    STACK_VM_PRINT_CODE1(ASSERT_SP);
    // Intrinsics
    STACK_VM_PRINT_CODE2(TVM_STRUCT_GET);
    STACK_VM_PRINT_CODE2(TVM_STRUCT_SET);
    // Allocate data by 8 bytes.
    STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE);
    STACK_VM_PRINT_CODE0(TVM_DEVICE_ALLOCA);
    STACK_VM_PRINT_CODE0(TVM_DEVICE_FREE);
    STACK_VM_PRINT_CODE0(TVM_THROW_LAST_ERROR);
    // packed function.
    case CALL_PACKED_LOWERED: {
      int call_fid = code[pc + 1].v_int;
      int begin = code[pc + 2].v_int;
      int end = code[pc + 3].v_int;
      os << "[" << pc << "]\tCALL_PACKED_FUNC "
         << " fid=" << call_fid
         << " begin=" << begin
         << " end=" << end;
      os << '\n';
      for (int i = 0; i < 3; ++i) {
        os << "[" << pc + 1 + i << "]" << std::endl;
      }
      return pc + 4;
    }
  }
  LOG(FATAL) << "unknown op code " << code[pc].op_code;
  return 0;
}

/*!
 * \brief Write a human-readable listing of the whole program to \p os.
 *
 * Walks the code words from index 0, letting PrintCode advance the cursor
 * past each instruction and its operand words.
 */
std::ostream& operator<<(std::ostream& os, const StackVM& vm) {  // NOLINT(*)
  const int64_t num_words = static_cast<int64_t>(vm.code.size());
  os << "Program dump: code-size=" << num_words << '\n';
  os << "----------begin-----------------\n";
  for (int64_t cursor = 0; cursor < num_words;) {
    cursor = vm.PrintCode(os, cursor);
  }
  os << "----------end--------------------\n";
  return os;
}

/*!
 * \brief Entry point for a packed call: prepare this thread's state and run.
 * \param args Packed arguments; their values, type codes and count are placed
 *        in heap slots 0, 1 and 2 respectively.
 * \param mod_ctx Module used by GetExtern to resolve extern functions.
 */
void StackVM::Run(const runtime::TVMArgs& args,
                  runtime::ModuleNode* mod_ctx) const {
  State* state = StackVM::ThreadLocalState();
  std::vector<TVMValue>& heap = state->heap;
  if (heap.size() < heap_size) heap.resize(heap_size);
  // Begin at the first instruction with sp reset to 0.
  state->sp = 0;
  state->pc = 0;
  state->mod_ctx = mod_ctx;
  // Hand the packed-call arguments to the bytecode through the heap.
  heap[0].v_handle = (void*)args.values;  // NOLINT(*)
  heap[1].v_handle = (void*)args.type_codes;  // NOLINT(*)
  heap[2].v_int64 = args.num_args;
  this->Run(state);
}

/*!
 * \brief Reset the extern function cache to one empty slot per extern name.
 *
 * Slots are filled on first use by GetExtern.
 */
void StackVM::InitCache() {
  const size_t num_externs = extern_func_name.size();
  extern_func_cache_.assign(num_externs, PackedFunc(nullptr));
}

/*!
 * \brief Serialize the program to \p strm.
 *
 * Sections are written in the order Load reads them: code words, string
 * table, extern function names, heap slot names, heap size, stack size.
 */
void StackVM::Save(dmlc::Stream* strm) const {
  // Write the instructions as plain int32 words rather than Code objects so
  // the serialized form stays endian invariant.
  std::vector<int32_t> words;
  words.reserve(code.size());
  for (const Code& instr : code) {
    words.push_back(instr.v_int);
  }
  strm->Write(words);
  strm->Write(str_data);
  strm->Write(extern_func_name);
  strm->Write(heap_id_name);
  strm->Write(heap_size);
  strm->Write(stack_size);
}

/*!
 * \brief Deserialize a program previously written by Save.
 * \param strm Source stream.
 * \return false as soon as any section fails to read; true otherwise, after
 *         the extern function cache has been reset.
 */
bool StackVM::Load(dmlc::Stream* strm)  {
  // Instructions were stored as int32 words (endian invariant); rebuild
  // the Code entries from them.
  std::vector<int32_t> words;
  if (!strm->Read(&words)) return false;
  code.resize(words.size());
  for (size_t i = 0; i < words.size(); ++i) {
    code[i].v_int = words[i];
  }
  // Remaining sections, in the order Save writes them; && stops at the
  // first failed read.
  const bool sections_ok = strm->Read(&str_data) &&
                           strm->Read(&extern_func_name) &&
                           strm->Read(&heap_id_name) &&
                           strm->Read(&heap_size) &&
                           strm->Read(&stack_size);
  if (!sections_ok) return false;
  this->InitCache();
  return true;
}

/*!
 * \brief Interpret the bytecode, starting from the pc and sp stored in \p s.
 *
 * Execution continues until pc moves past the end of `code`. After every
 * instruction the interpreter checks that sp has not dropped below the floor
 * set by TVM_STACK_ALLOCA_BY_8BYTE and that it stays below the stack
 * capacity. Failures are reported through CHECK / LOG(FATAL), including for
 * opcodes the interpreter does not recognize.
 *
 * \param s Interpreter state. Its stack and heap are grown to stack_size and
 *          heap_size when smaller. s->sp and s->pc are read but not written
 *          back.
 */
void StackVM::Run(State* s) const {
  int64_t sp = s->sp;
  int64_t pc = s->pc;
  // Floor for sp: raised by stack allocas so later pops cannot release the
  // allocated slots.
  int64_t alloca_sp = s->sp;
  std::vector<TVMValue>& stack = s->stack;
  std::vector<TVMValue>& heap = s->heap;
  if (stack.size() < stack_size) {
    stack.resize(stack_size);
  }
  // The 4-slot margin presumably leaves headroom for instructions that write
  // stack[sp + 1] (PUSH_*, LOAD_HEAP) before the overflow check below runs.
  // Subtract in signed arithmetic so a small stack_size cannot wrap around.
  const int64_t stack_cap = static_cast<int64_t>(stack_size) - 4;
  if (heap.size() < heap_size) {
    heap.resize(heap_size);
  }
  const int64_t code_size = static_cast<int64_t>(code.size());
  while (pc < code_size) {
    switch (code[pc].op_code) {
      case ADD_I64: STACK_VM_BINOP(+, v_int64); break;
      case SUB_I64: STACK_VM_BINOP(-, v_int64); break;
      case MUL_I64: STACK_VM_BINOP(*, v_int64); break;
      case DIV_I64: STACK_VM_BINOP(/, v_int64); break;
      case MOD_I64: STACK_VM_BINOP(%, v_int64); break;
      case EQ_I64: STACK_VM_CMPOP(==, v_int64); break;
      case LT_I64: STACK_VM_CMPOP(<, v_int64); break;
      case LE_I64: STACK_VM_CMPOP(<=, v_int64); break;
      case ADD_F64: STACK_VM_BINOP(+, v_float64); break;
      case SUB_F64: STACK_VM_BINOP(-, v_float64); break;
      case MUL_F64: STACK_VM_BINOP(*, v_float64); break;
      case DIV_F64: STACK_VM_BINOP(/, v_float64); break;
      case EQ_F64: STACK_VM_CMPOP(==, v_float64); break;
      case LT_F64: STACK_VM_CMPOP(<, v_float64); break;
      case LE_F64: STACK_VM_CMPOP(<=, v_float64); break;
      case EQ_HANDLE: STACK_VM_CMPOP(==, v_handle); break;
      // addressing
      case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break;
      case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break;
      case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break;
      case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break;
      case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break;
      case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break;
      // store
      case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break;
      case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break;
      case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break;
      case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break;
      case ARRAY_STORE_HANDLE: STACK_VM_STORE(.v_handle, void*); break;
      case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break;
      // add a byte offset (stack top) to a handle (one below the top)
      case ADDR_ADD: {
        stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64;  // NOLINT(*)
        sp = sp - 1;
        pc = pc + 1;
        break;
      }
      case NOT: {
        stack[sp].v_int64 = !stack[sp].v_int64;
        pc += 1;
        break;
      }
      case PUSH_I64: {
        stack[sp + 1].v_int64 = code[pc + 1].v_int;
        sp += 1;
        pc += 2;
        break;
      }
      case PUSH_VALUE: {
        // Duplicate a value at a non-positive offset from the stack top.
        int relpos = code[pc + 1].v_int;
        CHECK_LE(relpos, 0);
        stack[sp + 1] = stack[sp + relpos];
        sp += 1;
        pc += 2;
        break;
      }
      case POP: {
        sp -= 1;
        pc += 1;
        break;
      }
      case SELECT: {
        // Stack holds [..., true_value, false_value, cond]; keep one value.
        stack[sp - 2] = (stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]);
        sp -= 2;
        pc += 1;
        break;
      }
      case LOAD_HEAP: {
        stack[sp + 1] = heap[code[pc + 1].v_int];
        sp += 1;
        pc += 2;
        break;
      }
      case STORE_HEAP: {
        heap[code[pc + 1].v_int] = stack[sp];
        sp -= 1;
        pc += 2;
        break;
      }
      case ASSERT: {
        // The operand indexes the failure message in str_data.
        CHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
        sp -= 1;
        pc += 2;
        break;
      }
      // Conditional jumps leave the condition on the stack.
      case RJUMP_IF_TRUE: {
        if (stack[sp].v_int64) {
          pc += code[pc + 1].v_int;
        } else {
          pc += 2;
        }
        break;
      }
      case RJUMP_IF_FALSE: {
        if (!stack[sp].v_int64) {
          pc += code[pc + 1].v_int;
        } else {
          pc += 2;
        }
        break;
      }
      case RJUMP: {
        pc += code[pc + 1].v_int;
        break;
      }
      case ASSERT_SP: {
        int64_t expected = code[pc + 1].v_int;
        CHECK_EQ(sp, expected)
            << "sp assertion failed, expected="
            << expected << " now=" << sp << ", pc=" << pc;
        pc += 2;
        break;
      }
      case CALL_PACKED_LOWERED: {
        // call packed function: value and type-code arrays are the top two
        // stack entries; operands give the extern id and the [begin, end)
        // argument range within those arrays.
        TVMValue* value_stack = static_cast<TVMValue*>(stack[sp - 1].v_handle);
        int* type_stack = static_cast<int*>(stack[sp].v_handle);
        int call_fid = code[pc + 1].v_int;
        int begin = code[pc + 2].v_int;
        int end = code[pc + 3].v_int;
        int num_args = end - begin;
        static_assert(sizeof(Code) == sizeof(int) &&
                      alignof(Code) == alignof(int), "assumption");
        runtime::TVMRetValue rv;
        GetExtern(s, call_fid).CallPacked(
            runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
        sp = sp - 1;
        stack[sp] = rv.value();
        pc += 4;
        break;
      }
      // intrinsics
      case TVM_STRUCT_GET: {
        using namespace ir;
        int index = code[pc + 1].v_int;
        int kind = code[pc + 2].v_int;
        TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle);
        switch (kind) {
          case intrinsic::kArrData: {
            stack[sp].v_handle = arr[index].data; break;
          }
          case intrinsic::kArrShape: {
            stack[sp].v_handle = arr[index].shape; break;
          }
          case intrinsic::kArrStrides: {
            stack[sp].v_handle = arr[index].strides; break;
          }
          case intrinsic::kArrNDim: {
            stack[sp].v_int64 = arr[index].ndim; break;
          }
          case intrinsic::kArrTypeCode: {
            stack[sp].v_int64 = static_cast<int64_t>(
                arr[index].dtype.code); break;
          }
          case intrinsic::kArrTypeBits: {
            stack[sp].v_int64 = static_cast<int64_t>(
                arr[index].dtype.bits); break;
          }
          case intrinsic::kArrTypeLanes: {
            stack[sp].v_int64 = static_cast<int64_t>(
                arr[index].dtype.lanes); break;
          }
          case intrinsic::kArrByteOffset: {
            stack[sp].v_int64 = static_cast<int64_t>(
                arr[index].byte_offset); break;
          }
          case intrinsic::kArrDeviceId: {
            stack[sp].v_int64 = arr[index].ctx.device_id; break;
          }
          case intrinsic::kArrDeviceType: {
            stack[sp].v_int64 = static_cast<int64_t>(
                arr[index].ctx.device_type); break;
          }
          case intrinsic::kArrAddr: {
            stack[sp].v_handle = arr + index; break;
          }
          case intrinsic::kTVMValueContent: {
            stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
          }
          default: LOG(FATAL) << "unhandled get " << kind;
        }
        pc = pc + 3;
        break;
      }
      case TVM_STRUCT_SET: {
        using namespace ir;
        int index = code[pc + 1].v_int;
        int kind = code[pc + 2].v_int;
        TVMArray* arr = static_cast<TVMArray*>(stack[sp - 1].v_handle);
        switch (kind) {
          case intrinsic::kArrData: {
            arr[index].data = stack[sp].v_handle; break;
          }
          case intrinsic::kArrShape: {
            arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
            break;
          }
          case intrinsic::kArrStrides: {
            arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
            break;
          }
          case intrinsic::kArrNDim: {
            arr[index].ndim = static_cast<int>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kArrTypeCode: {
            arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kArrTypeBits: {
            arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kArrTypeLanes: {
            arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kArrByteOffset: {
            arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kArrDeviceId: {
            arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kArrDeviceType: {
            arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
            break;
          }
          case intrinsic::kTVMValueContent: {
            static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
          }
          default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
        }
        sp -= 2;
        pc += 3;
        break;
      }
      // alloca
      case TVM_STACK_ALLOCA_BY_8BYTE: {
        // Reserve `num` TVMValue slots just above the current top, push their
        // address, and raise the sp floor so the slots cannot be popped.
        static_assert(sizeof(TVMValue) == 8, "invariance");
        int num = code[pc + 1].v_int;
        void* addr = &stack[sp] + 1;
        sp = sp + num + 1;
        alloca_sp = sp - 1;
        stack[sp].v_handle = addr;
        pc = pc + 2;
        break;
      }
      case TVM_DEVICE_ALLOCA: {
        int device_type = static_cast<int>(stack[sp - 4].v_int64);
        int device_id = static_cast<int>(stack[sp - 3].v_int64);
        size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
        int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
        int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes,
                                             dtype_code_hint, dtype_bits_hint);
        stack[sp - 4].v_handle = ptr;
        sp = sp - 4;
        pc = pc + 1;
        break;
      }
      case TVM_DEVICE_FREE: {
        int device_type = static_cast<int>(stack[sp - 2].v_int64);
        int device_id = static_cast<int>(stack[sp - 1].v_int64);
        void* ptr = stack[sp].v_handle;
        int ret = TVMBackendFreeWorkspace(device_type, device_id, ptr);
        stack[sp - 2].v_int64 = ret;
        sp = sp - 2;
        pc = pc + 1;
        break;
      }
      case TVM_THROW_LAST_ERROR: {
        LOG(FATAL) << TVMGetLastError();
        break;
      }
      default: {
        // An unrecognized opcode never advances pc, so without this the loop
        // would spin forever (e.g. on corrupted bytecode from Load).
        LOG(FATAL) << "unknown op code " << code[pc].op_code;
        break;
      }
    }
    CHECK_GE(sp, alloca_sp) << "touch allocated space";
    CHECK_LT(sp, stack_cap) << "Stack overflow";
  }
}

/*!
 * \brief Resolve extern function \p fid, caching the result.
 *
 * On first use the function is looked up by name (extern_func_name[fid])
 * through s->mod_ctx->GetFuncFromEnv and stored in extern_func_cache_.
 *
 * \param s Interpreter state supplying the module context for the lookup.
 * \param fid Index into extern_func_name / extern_func_cache_.
 * \return Reference to the cached PackedFunc.
 */
const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
  CHECK_LT(static_cast<size_t>(fid), extern_func_cache_.size())
      << "extern function id " << fid << " is out of range";
  // The cache is filled lazily without locking; every writer stores the same
  // function. NOTE(review): concurrent unsynchronized writes to the same
  // PackedFunc are formally a data race -- confirm callers never resolve the
  // same fid from several threads at once.
  PackedFunc& f = extern_func_cache_[fid];
  if (f == nullptr) {
    CHECK(s->mod_ctx != nullptr)
        << "No local context is set in stackvm";
    const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
    CHECK(pf != nullptr) << "Cannot find extern function "
                         << extern_func_name[fid] << " from stackvm context";
    f = *pf;
  }
  return f;
}

}  // namespace runtime
}  // namespace tvm