/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_stack_vm.cc
 */
#include <tvm/packed_func_ext.h>
#include <limits>
#include "./codegen_stack_vm.h"

namespace tvm {
namespace codegen {

using namespace ir;

PackedFunc BuildStackVM(
    LoweredFunc func,
    const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs) {
  StackVM vm = codegen::CodeGenStackVM().Compile(func, device_funcs);
  auto f = [vm](TVMArgs args, TVMRetValue* rv) {
    vm(args);
  };
  return PackedFunc(f);
}

CodeGenStackVM::FType& CodeGenStackVM::vtable() {  // NOLINT(*)
  static FType inst; return inst;
}

StackVM CodeGenStackVM::Compile(
    LoweredFunc f,
    const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs) {
  for (size_t i = 0; i < f->args.size(); ++i) {
    Var v = f->args[i];
    int vid = AllocVarID(v.get());
    CHECK_EQ(static_cast<size_t>(vid), i);
  }
  // setup device function map
  for (const auto& kv : device_funcs) {
    int fid = static_cast<int>(vm_.packed_func.size());
    vm_.packed_func.push_back(kv.second);
    device_fun_idmap_[kv.first->name] = fid;
  }
  this->Push(f->body);
  return std::move(vm_);
}

void CodeGenStackVM::Push(const Stmt& n) {
  static const FType& f = vtable();
  f(n, this);
  if (debug_) {
    this->PushOp(StackVM::ASSERT_SP, 0);
  }
}

void CodeGenStackVM::Push(const Expr& n) {
  static const FType& f = vtable();
  f(n, this);
}

void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
}

void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
  CHECK(operand >= std::numeric_limits<int>::min() &&
        operand <= std::numeric_limits<int>::max());
  vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}

int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
  int64_t pc = static_cast<int64_t>(vm_.code.size());
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
  code.v_int = operand;
  vm_.code.push_back(code);
  return pc + 1;
}

int CodeGenStackVM::GetStrID(const std::string& key) {
  auto it = str_idmap_.find(key);
  if (it != str_idmap_.end()) return it->second;
  int sid = static_cast<int>(vm_.str_data.size());
  vm_.str_data.push_back(key);
  str_idmap_[key] = sid;
  return sid;
}

int CodeGenStackVM::AllocVarID(const Variable* v) {
  CHECK(!var_idmap_.count(v));
  int vid = static_cast<int>(vm_.heap_size);
  CHECK_EQ(vm_.heap_size, var_idmap_.size());
  vm_.heap_id_name.push_back(v->name_hint);
  ++vm_.heap_size;
  var_idmap_[v] = vid;
  return vid;
}

void CodeGenStackVM::PushCallPacked(
    int fid, const std::vector<int>& arg_type_codes) {
  StackVM::Code code;
  // CALL_PACKED_FUNC
  code.op_code = StackVM::CALL_PACKED_FUNC;
  vm_.code.push_back(code);
  // num_args
  code.v_int = static_cast<int>(arg_type_codes.size());
  vm_.code.push_back(code);
  // fid
  code.v_int = fid;
  vm_.code.push_back(code);
  // type codes.
  for (int tcode : arg_type_codes) {
    code.v_int = tcode;
    vm_.code.push_back(code);
  }
}

int CodeGenStackVM::GetVarID(const Variable* v) const {
  auto it = var_idmap_.find(v);
  CHECK(it != var_idmap_.end())
      << "Find undefined Variable " << v->name_hint;
  return it->second;
}

void CodeGenStackVM::Push_(const ir::Load* op) {
  this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
  if (op->type == UInt(32) && op->index.as<IntImm>()) {
    this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value);
  } else {
    this->Push(op->index);
    this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->PushOp(StackVM::GetLoad(Type2TVMType(op->type)));
  }
}
void CodeGenStackVM::Push_(const ir::Store* op) {
  this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get()));
  this->Push(op->index);
  this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
  this->PushOp(StackVM::MUL_I64);
  this->PushOp(StackVM::ADDR_ADD);
  this->Push(op->value);
  this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type())));
}

void CodeGenStackVM::Push_(const ir::Allocate* op) {
  CHECK(!is_zero(op->condition));
  int vid = AllocVarID(op->buffer_var.get());
  if (op->new_expr.defined()) {
    // Prefer global static allocation for the program
    CHECK_EQ(op->free_function, "nop");
    this->Push(op->new_expr);
    this->PushOp(StackVM::STORE_HEAP, vid);
  } else {
    LOG(FATAL) << "Dynamic allocation not supported";
  }
}

void CodeGenStackVM::Push_(const ir::Call* op) {
  if (op->is_intrinsic(Call::address_of)) {
    const Load *l = op->args[0].as<Load>();
    CHECK(op->args.size() == 1 && l);
    this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
    this->Push(l->index);
    this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
  } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
    CHECK_EQ(op->args.size(), 3U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    if (op->type.is_handle()) {
      this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE);
    } else if (op->type.is_float()) {
      this->PushOp(StackVM::TVM_LOAD_ARG_FP64);
    } else if (op->type.is_int() || op->type.is_uint()) {
      this->PushOp(StackVM::TVM_LOAD_ARG_INT64);
    } else {
      LOG(FATAL) << "donot know how to handle type" << op->type;
    }
  } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
    CHECK_EQ(op->args.size(), 2U);
    this->Push(op->args[0]);
    switch (op->args[1].as<IntImm>()->value) {
      case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break;
      case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break;
      case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break;
      case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break;
      case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
      case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
      case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
      default: LOG(FATAL) << "unknown field code";
    }
  } else if (op->is_intrinsic(intrinsic::tvm_call_global)) {
    CHECK_GE(op->args.size(), 1U);
    const StringImm* s = op->args[0].as<StringImm>();
    CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
    for (size_t i = 1; i < op->args.size(); ++i) {
      this->Push(op->args[i]);
    }
    // find the fuction id.
    const std::string& func_name = s->value;
    auto it = global_fun_idmap_.find(func_name);
    int fid;
    if (it != global_fun_idmap_.end()) {
      fid = it->second;
    } else {
      fid = static_cast<int>(vm_.packed_func.size());
      PackedFunc f = PackedFunc::GetGlobal(func_name);
      vm_.packed_func.push_back(f);
      global_fun_idmap_[func_name] = fid;
    }
    // get the argument type code.
    std::vector<int> arg_type_codes;
    for (size_t i = 1; i < op->args.size(); ++i) {
      Type t = op->args[i].type();
      int code = t.code();
      int lanes = t.lanes();
      CHECK_EQ(lanes, 1);
      arg_type_codes.push_back(code);
    }
    this->PushCallPacked(fid, arg_type_codes);
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    CHECK_EQ(op->args.size(), 1U);
    this->Push(op->args[0]);
    this->PushOp(StackVM::PUSH_I64, 0);
    this->PushOp(StackVM::EQ_I64);
  } else if (op->is_intrinsic(intrinsic::tvm_call_device)) {
    std::string func_name = op->args[0].as<StringImm>()->value;
    auto it = device_fun_idmap_.find(func_name);
    CHECK(it != device_fun_idmap_.end())
        << "Cannot find device function " << func_name;
    const int fid = it->second;
    std::vector<int> arg_type_codes;
    for (size_t i = 1; i < op->args.size(); ++i) {
      this->Push(op->args[i]);
      Type t = op->args[i].type();
      int lanes = t.lanes();
      CHECK_EQ(lanes, 1);
      arg_type_codes.push_back(t.code());
    }
    this->PushCallPacked(fid, arg_type_codes);
  } else {
    this->HandleUnknownCall(op);
  }
}

void CodeGenStackVM::HandleUnknownCall(const ir::Call* op) {
  LOG(FATAL) << "donot know how to handle call " << op->name;
}

inline void PushBinary(StackVM::OpCode op_int64,
                       const Expr& a,
                       const Expr& b,
                       CodeGenStackVM* p) {
  p->Push(a);
  p->Push(b);
  Type t = a.type();
  if (t.is_int()) {
    p->PushOp(op_int64);
  } else if (t.is_uint()) {
    if (t.bits() <= 32) {
      p->PushOp(op_int64);
    } else {
      LOG(FATAL) << "Cannot handle uint64_t in StackVM";
    }
  } else {
    p->PushOp(StackVM::CodeI64ToF64(op_int64));
  }
}




inline void PushCast(Type dst,
                     Type src,
                     CodeGenStackVM* p) {
  if (dst.is_int()) {
    if (src.is_int()) return;
    if (src.is_uint() && src.bits() <= 32) return;
  } else if (dst.is_uint() && dst.bits() <= 32) {
    if (src.is_int()) return;
    if (src.is_uint() && src.bits() <= 32) return;
  } else if (dst.is_float()) {
    if (src.is_float()) return;
  }
  LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
}

TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<StringImm>([](const StringImm *op, CodeGenStackVM *p) {
    int sid = p->GetStrID(op->value);
    p->PushOp(StackVM::PUSH_I64, sid);
  })
.set_dispatch<IntImm>([](const IntImm *op, CodeGenStackVM *p) {
    CHECK(op->value >= std::numeric_limits<int>::min() &&
          op->value <= std::numeric_limits<int>::max())
        << "Int constant exceed bound";
    p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
  })
.set_dispatch<UIntImm>([](const UIntImm *op, CodeGenStackVM *p) {
    CHECK(op->value <= std::numeric_limits<int>::max())
        << "Int constant exceed bound";
    p->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
  })
.set_dispatch<FloatImm>([](const FloatImm *op, CodeGenStackVM *p) {
    LOG(FATAL) << "Float Imm is not supported";
  });

TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Variable>([](const Variable *op, CodeGenStackVM* p) {
    int vid = p->GetVarID(op);
    p->PushOp(StackVM::LOAD_HEAP, vid);
  })
.set_dispatch<Cast>([](const Cast *op, CodeGenStackVM* p) {
    p->Push(op->value);
    PushCast(op->type, op->value.type(), p);
  })
.set_dispatch<Add>([](const Add *op, CodeGenStackVM* p) {
    PushBinary(StackVM::ADD_I64, op->a, op->b, p);
  })
.set_dispatch<Sub>([](const Sub *op, CodeGenStackVM* p) {
    PushBinary(StackVM::SUB_I64, op->a, op->b, p);
  })
.set_dispatch<Mul>([](const Mul *op, CodeGenStackVM* p) {
    PushBinary(StackVM::MUL_I64, op->a, op->b, p);
  })
.set_dispatch<Div>([](const Div *op, CodeGenStackVM* p) {
    PushBinary(StackVM::DIV_I64, op->a, op->b, p);
  })
.set_dispatch<Mod>([](const Mod *op, CodeGenStackVM* p) {
    PushBinary(StackVM::MOD_I64, op->a, op->b, p);
  })
.set_dispatch<Min>([](const Min *op, CodeGenStackVM* p) {
    p->Push(op->a);
    p->Push(op->b);
    p->PushOp(StackVM::PUSH_VALUE, -1);
    p->PushOp(StackVM::PUSH_VALUE, -1);
    p->PushOp(StackVM::LT_I64);
    p->PushOp(StackVM::SELECT);
  })
.set_dispatch<Max>([](const Max *op, CodeGenStackVM* p) {
    p->Push(op->a);
    p->Push(op->b);
    p->PushOp(StackVM::PUSH_VALUE, 0);
    p->PushOp(StackVM::PUSH_VALUE, -2);
    p->PushOp(StackVM::LT_I64);
    p->PushOp(StackVM::SELECT);
  })
.set_dispatch<EQ>([](const EQ *op, CodeGenStackVM* p) {
    PushBinary(StackVM::EQ_I64, op->a, op->b, p);
  })
.set_dispatch<LE>([](const LE *op, CodeGenStackVM* p) {
    PushBinary(StackVM::LE_I64, op->a, op->b, p);
  })
.set_dispatch<NE>([](const NE *op, CodeGenStackVM* p) {
    PushBinary(StackVM::EQ_I64, op->a, op->b, p);
    p->PushOp(StackVM::NOT);
  })
.set_dispatch<LT>([](const LT *op, CodeGenStackVM* p) {
    PushBinary(StackVM::LT_I64, op->a, op->b, p);
  })
.set_dispatch<GE>([](const GE *op, CodeGenStackVM* p) {
    PushBinary(StackVM::LT_I64, op->a, op->b, p);
    p->PushOp(StackVM::NOT);
  })
.set_dispatch<GT>([](const GT *op, CodeGenStackVM* p) {
    PushBinary(StackVM::LE_I64, op->a, op->b, p);
    p->PushOp(StackVM::NOT);
  })
.set_dispatch<And>([](const And *op, CodeGenStackVM* p) {
    p->Push(op->a);
    int64_t pc_jump = p->GetPC();
    int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
    p->PushOp(StackVM::POP);
    p->Push(op->b);
    int64_t diff = p->GetPC() - pc_jump;
    p->SetOperand(opr_index, diff);
})
.set_dispatch<Or>([](const Or *op, CodeGenStackVM* p) {
    p->Push(op->a);
    int64_t pc_jump = p->GetPC();
    int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_TRUE, 0);
    p->Push(op->b);
    int64_t diff = p->GetPC() - pc_jump;
    p->SetOperand(opr_index, diff);
})
.set_dispatch<Not>([](const Not* op, CodeGenStackVM* p) {
    p->PushOp(StackVM::NOT);
  });


TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenStackVM* p) {
    p->Push(op->body);
  })
.set_dispatch<For>([](const For *op, CodeGenStackVM* p) {
    CHECK(is_zero(op->min));
    int vid = p->AllocVarID(op->loop_var.get());
    p->PushOp(StackVM::PUSH_I64, 0);
    int64_t loop_head = p->GetPC();
    p->PushOp(StackVM::STORE_HEAP, vid);
    p->PushOp(StackVM::LOAD_HEAP, vid);
    p->Push(op->extent);
    p->PushOp(StackVM::LT_I64);
    int64_t label_fjump = p->GetPC();
    int64_t foward_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
    p->PushOp(StackVM::POP);
    p->Push(op->body);
    p->PushOp(StackVM::LOAD_HEAP, vid);
    p->PushOp(StackVM::PUSH_I64, 1);
    p->PushOp(StackVM::ADD_I64);
    int64_t label_bjump = p->GetPC();
    int64_t backward_jump = p->PushOp(StackVM::RJUMP, 0);
    int64_t loop_end = p->GetPC();
    p->PushOp(StackVM::POP);
    p->SetOperand(foward_jump, loop_end - label_fjump);
    p->SetOperand(backward_jump, loop_head - label_bjump);
  })
.set_dispatch<Block>([](const Block *op, CodeGenStackVM* p) {
    p->Push(op->first);
    if (op->rest.defined()) p->Push(op->rest);
  })
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenStackVM* p) {
    if (is_const(op->value)) return;
    p->Push(op->value);
    p->PushOp(StackVM::POP);
  })
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenStackVM* p) {
    p->Push(op->condition);
    int64_t label_ejump = p->GetPC();
    int64_t else_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0);
    p->PushOp(StackVM::POP);
    p->Push(op->then_case);
    if (op->else_case.defined()) {
      int64_t label_then_jump = p->GetPC();
      int64_t then_jump = p->PushOp(StackVM::RJUMP, 0);
      int64_t else_begin = p->GetPC();
      p->SetOperand(else_jump, else_begin - label_ejump);
      p->PushOp(StackVM::POP);
      p->Push(op->else_case);
      int64_t if_end = p->GetPC();
      p->SetOperand(then_jump, if_end - label_then_jump);
    } else {
      int64_t if_end = p->GetPC();
      p->SetOperand(else_jump, if_end - label_ejump);
      p->PushOp(StackVM::POP);
    }
  })
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) {
    p->Push(op->value);
    int64_t vid = p->AllocVarID(op->var.get());
    p->PushOp(StackVM::STORE_HEAP, vid);
    p->Push(op->body);
  })
.set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) {
    LOG(FATAL) << "Ramp is not supported";
  })
.set_dispatch<Broadcast>([](const Broadcast *op, CodeGenStackVM* p) {
    LOG(FATAL) << "Broadcast is not supported";
  })
.set_dispatch<Select>([](const Select *op, CodeGenStackVM* p) {
    p->Push(op->true_value);
    p->Push(op->false_value);
    p->Push(op->condition);
    p->PushOp(StackVM::SELECT);
  })
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenStackVM* p) {
    if (op->message.as<StringImm>()) {
      int sid = p->GetStrID(op->message.as<StringImm>()->value);
      p->Push(op->condition);
      p->PushOp(StackVM::ASSERT, sid);
    }
  })
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenStackVM* p) {
    p->Push(op->body);
  })
.set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) {
    p->Push(op->value);
    int64_t vid = p->AllocVarID(op->var.get());
    p->PushOp(StackVM::STORE_HEAP, vid);
    p->Push(op->body);
  })
.set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) {
    p->Push_(op);
  })
.set_dispatch<Store>([](const Store *op, CodeGenStackVM* p) {
    p->Push_(op);
  })
.set_dispatch<Allocate>([](const Allocate *op, CodeGenStackVM* p) {
    p->Push_(op);
  })
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
    p->Push_(op);
  });


}  // namespace codegen
}  // namespace tvm