/*!
 *  Copyright (c) 2017 by Contributors
 *  Lower TVM related builtin intrinsics such as packed calls.
 * \file lower_tvm_buildin.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <algorithm>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

inline Expr ConstInt32(size_t index) {
  CHECK_LE(index, static_cast<size_t>(std::numeric_limits<int>::max()));
  return make_const(Int(32), static_cast<int>(index));
}

inline Expr StackAlloca(std::string type, size_t num) {
  Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
  return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
}
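// For example, StackAlloca("shape", 4) builds the intrinsic call
// tvm_stack_alloca("shape", 4), which codegen later expands into a
// stack buffer able to hold four int64 shape entries.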

// Lower builtin intrinsics such as packed calls, and record the peak
// stack sizes they require; this information is needed during codegen.
class BuiltinLower : public IRMutator {
 public:
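  // Build() mutates the statement and then wraps it in the stack
  // allocations it needs, roughly (innermost let listed last):
  //
  //   let stack_tcode = tvm_stack_alloca("arg_tcode", max_arg_stack_) in
  //   let stack_value = tvm_stack_alloca("arg_value", max_arg_stack_) in
  //   let stack_array = tvm_stack_alloca("array", max_array_stack_) in
  //   let stack_shape = tvm_stack_alloca("shape", max_shape_stack_) in
  //   <mutated body>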
  Stmt Build(Stmt stmt) {
    stack_shape_ = Var("stack_shape", Handle());
    stack_array_ = Var("stack_array", Handle());
    stack_value_ = Var("stack_value", Handle());
    stack_tcode_ = Var("stack_tcode", Handle());
    stmt = this->Mutate(stmt);
    if (max_shape_stack_ != 0) {
      stmt = LetStmt::make(
          stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
    }
    if (max_array_stack_ != 0) {
      stmt = LetStmt::make(
          stack_array_, StackAlloca("array", max_array_stack_), stmt);
    }
    if (max_arg_stack_ != 0) {
      stmt = LetStmt::make(
          stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
      stmt = LetStmt::make(
          stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
    }
    return stmt;
  }

  Stmt Mutate(Stmt stmt) final {
    stmt = IRMutator::Mutate(stmt);
    // The running stacks must be fully unwound between top-level statements.
    CHECK_EQ(run_shape_stack_, 0);
    CHECK_EQ(run_array_stack_, 0);
    CHECK_EQ(run_arg_stack_, 0);
    // Prepend the pending preparation statements, back to front, so they
    // execute in emission order before stmt.
    while (!prep_seq_.empty()) {
      stmt = Block::make(prep_seq_.back(), stmt);
      prep_seq_.pop_back();
    }
    return stmt;
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    // Lower the allocation to a device workspace allocation when needed.
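    // For example, a float32 x[n] allocation that cannot stay on the
    // stack becomes, roughly:
    //
    //   // attr [x] storage_alignment = kTempAllocaAlignment
    //   let x = TVMBackendAllocWorkspace(dev_type, dev_id, n * 4, 2, 32)
    //     if (tvm_handle_is_null(x)) tvm_throw_last_error()
    //     <body>
    //   if (TVMBackendFreeWorkspace(dev_type, dev_id, x) != 0)
    //     tvm_throw_last_error()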
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    if (op->new_expr.defined()) return stmt;
    // Get constant allocation bound.
    int64_t dev_type;
    int64_t nbytes = GetVectorBytes(op->type);
    if (device_type_.defined()) {
      if (arith::GetConst(device_type_, &dev_type)) {
        if (dev_type == kDLCPU) {
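          // Small constant-size CPU allocations can stay on the stack.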
          int32_t constant_size = op->constant_allocation_size();
          if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
            return stmt;
          }
        }
      }
    }
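    // Total size in bytes: element bytes times the product of all extents.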
    Expr total_bytes = make_const(op->extents[0].type(), nbytes);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      total_bytes = total_bytes * op->extents[i];
    }
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
                                           intrinsic::tvm_throw_last_error, {},
                                           Call::Intrinsic));

    Stmt body = Block::make(
        IfThenElse::make(Call::make(Bool(1),
                                    intrinsic::tvm_handle_is_null,
                                    {op->buffer_var}, Call::PureIntrinsic),
                         throw_last_error),
        op->body);

    Stmt alloca = LetStmt::make(
        op->buffer_var,
        Call::make(op->buffer_var.type(),
                   "TVMBackendAllocWorkspace",
                   {cast(Int(32), device_type_),
                    cast(Int(32), device_id_),
                    cast(UInt(64), total_bytes),
                    IntImm::make(Int(32), op->type.code()),
                    IntImm::make(Int(32), op->type.bits())},
                   Call::Extern),
        body);

    Expr free_op = Call::make(Int(32),
                              "TVMBackendFreeWorkspace",
                              {cast(Int(32), device_type_),
                               cast(Int(32), device_id_),
                               op->buffer_var},
                              Call::Extern);
    Stmt free_stmt = IfThenElse::make(free_op != make_zero(Int(32)), throw_last_error);
    body = Block::make(alloca, free_stmt);
    body = AttrStmt::make(
        op->buffer_var, attr::storage_alignment,
        make_const(Int(32), runtime::kTempAllocaAlignment),
        body);
    return body;
  }

  // Capture the device context from the enclosing attribute scope and
  // strip the attribute from the lowered IR.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::device_context_id) {
      CHECK(!device_id_.defined());
      device_id_ = op->value;
      return Mutate(op->body);
    } else if (op->attr_key == attr::device_context_type) {
      CHECK(!device_type_.defined());
      device_type_ = op->value;
      return Mutate(op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
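  // Dispatch builtin intrinsic calls to their lowering routines.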
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
      return MakeCallPacked(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
      return MakeShape(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
      return MakeArray(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      return make_zero(op->type);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
  // Lower tvm_stack_make_shape: store the dimensions on the shape stack.
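  // For example, tvm_stack_make_shape(n, m) with stack_begin == s emits
  //
  //   stack_shape[s] = int64(n)
  //   stack_shape[s + 1] = int64(m)
  //
  // into the preparation sequence and evaluates to &stack_shape[s].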
  Expr MakeShape(const Call* op, const Expr& e) {
    size_t stack_begin = run_shape_stack_;
    run_shape_stack_ += op->args.size();
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    for (size_t i = 0; i < op->args.size(); ++i) {
      prep_seq_.emplace_back(
          Store::make(stack_shape_, cast(Int(64), op->args[i]),
                      ConstInt32(stack_begin + i), const_true(1)));
    }
    return AddressOffset(stack_shape_, Int(64), stack_begin);
  }
  // Lower tvm_stack_make_array: fill one DLTensor slot on the array stack.
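  // The six arguments (data, shape, strides, ndim, dtype carrier, byte
  // offset) plus the current device context are written into DLTensor
  // slot idx via TVMStructSet, and the call evaluates to the address of
  // that slot.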
  Expr MakeArray(const Call* op, const Expr& e) {
    size_t idx = run_array_stack_;
    run_array_stack_ += 1;
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
    Expr strides = op->args[2];
    if (!strides.defined() || is_zero(strides)) {
      strides = make_zero(Handle());
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
    Type dtype = op->args[4].type();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
                     make_const(UInt(8), static_cast<int>(dtype.code()))));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
                     make_const(UInt(8), dtype.bits())));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
                     make_const(UInt(16), dtype.lanes())));
    // set byte offset
    int data_bytes = GetVectorBytes(dtype);
    Expr byte_offset = op->args[5];
    if (!is_zero(byte_offset)) {
      byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
                     cast(UInt(64), byte_offset)));
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
                     cast(Int(32), device_id_)));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
                     cast(Int(32), device_type_)));
    return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
  }
  // Lower tvm_call_packed into tvm_call_packed_lowered.
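  // For example, tvm_call_packed("f", x, y) with arg_stack_begin == b
  // becomes, roughly:
  //
  //   stack_value[b] = x;     stack_tcode[b] = <type code of x>
  //   stack_value[b + 1] = y; stack_tcode[b + 1] = <type code of y>
  //   tvm_call_packed_lowered("f", stack_value, stack_tcode, b, b + 2)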
  Expr MakeCallPacked(const Call* op, const Expr& e) {
    size_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    // Mutate the arguments first so that nested make_shape/make_array
    // intrinsics reserve their stack slots and emit their preparation
    // stores before this call packs its arguments.
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
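    // args[0] is the callee name; pack the remaining arguments into the
    // value and type-code stacks.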
    for (size_t i = 1; i < op->args.size(); ++i) {
      Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
      Expr arg = op->args[i];
      Type t = arg.type();
      Type api_type = APIType(t);
      if (t != api_type) {
        arg = Cast::make(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(
          stack_value_, static_cast<int>(arg_stack_begin + i - 1),
          intrinsic::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
      prep_seq_.emplace_back(
          Store::make(stack_tcode_,
                      ConstInt32(arg_tcode),
                      stack_index, const_true(1)));
    }
    // Record the peak stack usage, then restore the running stacks;
    // the space above arg_stack_begin can be reused by later calls.
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    run_arg_stack_ = arg_stack_begin;
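    // The lowered call receives the [begin, end) range it occupies in
    // the value and type-code stacks.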
    Array<Expr> packed_args = {
      op->args[0],
      stack_value_,
      stack_tcode_,
      ConstInt32(arg_stack_begin),
      ConstInt32(arg_stack_begin + op->args.size() - 1)
    };
    return Call::make(
        Int(32), intrinsic::tvm_call_packed_lowered,
        packed_args, Call::Intrinsic);
  }

 private:
  bool IsArrayHandle(const Expr& arg) {
    // Detect the address of a DLTensor produced by MakeArray; such
    // arguments must be passed with the kArrayHandle type code.
    if (const Call* buf = arg.as<Call>()) {
      if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
          buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
        return true;
      }
    }
    return false;
  }

  // The preparation sequence to be emitted before the statement.
  std::vector<Stmt> prep_seq_;
  Expr device_type_;
  Expr device_id_;
  // Var handle for each stack.
  Var stack_shape_;
  Var stack_array_;
  Var stack_tcode_;
  Var stack_value_;
  // Running stack sizes at the current point of traversal.
  uint64_t run_shape_stack_{0};
  uint64_t run_array_stack_{0};
  uint64_t run_arg_stack_{0};
  // Peak stack sizes, used to size the stack allocations in Build().
  uint64_t max_shape_stack_{0};
  uint64_t max_array_stack_{0};
  uint64_t max_arg_stack_{0};
};

/*!
 * \brief Lower TVM builtin intrinsics such as packed calls in the body of f.
 */
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = BuiltinLower().Build(n->body);
  return LoweredFunc(n);
}

}  // namespace ir
}  // namespace tvm