/*!
 *  Copyright (c) 2018 by Contributors
 * \file ir_builder.h
 * \brief Utility for building SPIRV code block
 */
#ifndef TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#define TVM_CODEGEN_SPIRV_IR_BUILDER_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/ir.h>

#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <spirv.hpp>

namespace tvm {
namespace codegen {
namespace spirv {

/*! \brief A SPIR-V type as tracked by the builder */
struct SType {
  /*! \brief SPIR-V result id that names this type (0 by default) */
  uint32_t id = 0;
  /*! \brief The TVM type this SPIR-V type corresponds to */
  tvm::Type type;
  /*! \brief Id of the content type when this is a pointer/struct-array type */
  uint32_t element_type_id = 0;
  /*!
   * \brief Storage class of a pointer type; defaults to spv::StorageClassMax
   *  (presumably meaning "not a pointer" — verify against users).
   */
  spv::StorageClass storage_class = spv::StorageClassMax;
};

/*! \brief Flags describing what kind of entity a Value refers to */
enum ValueKind {
  /*! \brief An ordinary SSA value (the default for Value::flag) */
  kNormal = 0,
  kConstant = 1,
  kVectorPtr = 2,
  kStructArrayPtr = 3,
  kPushConstantPtr = 4,
  kFunction = 5,
  /*! \brief An imported extended instruction set */
  kExtInst = 6
};

/*! \brief A SPIR-V value (result id) together with its type */
struct Value {
  /*! \brief SPIR-V result id of the value (0 by default) */
  uint32_t id = 0;
  /*! \brief The type of the value */
  SType stype;
  /*! \brief Extra information about what the value refers to */
  ValueKind flag = kNormal;
};

/*! \brief A SPIR-V label naming a block */
struct Label {
  /*! \brief SPIR-V result id of the label (0 by default) */
  uint32_t id = 0;
};

/*!
 * \brief Handle to a SPIR-V instruction that has been committed to a
 *  code segment; it lets callers patch the instruction's words later.
 */
class Instr {
 public:
  /*! \return Number of 32-bit words occupied by the instruction */
  uint32_t WordCount() const { return word_count_; }
  /*!
   * \brief Mutable access to one word of the instruction.
   * \param idx Word offset within the instruction (0 is the opcode word).
   * \return Reference to that word inside the backing segment.
   */
  uint32_t& operator[](uint32_t idx) {
    CHECK_LT(idx, word_count_);
    std::vector<uint32_t>& words = *data_;
    return words[begin_ + idx];
  }

 private:
  friend class InstrBuilder;
  /*!
   * \brief Segment that stores this instruction.
   *  A pointer to the vector (rather than to its elements) is kept
   *  because the vector may reallocate as more words are appended.
   */
  std::vector<uint32_t>* data_ = nullptr;
  /*! \brief Offset of the instruction's first word within *data_ */
  uint32_t begin_ = 0;
  /*! \brief Word count of the instruction */
  uint32_t word_count_ = 0;
};

/*! \brief A phi value together with a handle to its OpPhi instruction */
struct PhiValue : public Value {
  /*! \brief Handle to the OpPhi instruction, used to fill incoming pairs */
  Instr instr;
  /*!
   * \brief Fill in the index-th (value, parent label) pair of the phi.
   * \param index Which incoming pair to set.
   * \param value The incoming value; must have the same type as the phi.
   * \param parent Label of the block the value flows in from.
   */
  void SetIncoming(uint32_t index,
                   const Value& value,
                   const Label& parent) {
    CHECK_EQ(this->stype.id, value.stype.id);
    // Incoming pairs follow the opcode, result-type and result-id words.
    const uint32_t slot = 3 + 2 * index;
    instr[slot] = value.id;
    instr[slot + 1] = parent.id;
  }
};

/*!
 * \brief Helper class to build SPIRV instruction.
 *
 * \code
 *
 *  std::vector<uint32_t> func_seg_vec_;
 *  InstrBuilder ib;
 *
 *  // construct and append to the end of func_seg_vec_;
 *  ib.Begin(spv::OpIAdd)
 *    .Add(result).Add(v1).Add(v2)
 *    .Commit(&func_seg_vec_);
 *
 * \endcode
 */
class InstrBuilder {
 public:
  /*!
   * \brief Begin construction of instruction.
   * \param op The op code
   * \return reference to self.
   */
  InstrBuilder& Begin(spv::Op op) {  // NOLINT(*);
    // finish previous build
    CHECK_EQ(data_.size(), 0U);
    op_ = op;
    // Placeholder for the opcode/word-count word, filled in by Commit.
    data_.push_back(0);
    return *this;
  }
  /*!
   * \brief Add v to end of instruction.
   * \param v The value to be appended to the instruction.
   * \return reference to self.
   */
  InstrBuilder& Add(const Value& v) {
    data_.push_back(v.id);
    return *this;
  }
  /*!
   * \brief Add v to end of instruction.
   * \param v The type to be appended to the instruction.
   * \return reference to self.
   */
  InstrBuilder& Add(const SType& v) {
    data_.push_back(v.id);
    return *this;
  }
  /*!
   * \brief Add v to end of instruction.
   * \param v The label to be appended to the instruction.
   * \return reference to self.
   */
  InstrBuilder& Add(const Label& v) {
    data_.push_back(v.id);
    return *this;
  }
  /*!
   * \brief Add a word to end of instruction.
   * \param v The value to be added.
   * \return reference to self.
   */
  InstrBuilder& Add(const uint32_t& v) {
    data_.push_back(v);
    return *this;
  }
  /*!
   * \brief Add a string literal to the end of the instruction.
   *
   *  SPIR-V literal strings are nul-terminated, padded with zero bytes to
   *  a word boundary, and packed little-endian (first octet in the
   *  lowest-order byte of the word). Packing is done with shifts so the
   *  encoding does not depend on the host byte order.
   *
   * \param v The string literal to be appended.
   * \return reference to self.
   */
  InstrBuilder& Add(const std::string& v) {
    const uint32_t kWordSize = sizeof(uint32_t);
    // +kWordSize (not +kWordSize-1) reserves room for the nul terminator.
    uint32_t nwords =
        (static_cast<uint32_t>(v.length()) + kWordSize) / kWordSize;
    size_t begin = data_.size();
    data_.resize(begin + nwords, 0U);
    for (size_t i = 0; i < v.length(); ++i) {
      uint32_t byte = static_cast<uint32_t>(static_cast<unsigned char>(v[i]));
      data_[begin + i / kWordSize] |= byte << (8 * (i % kWordSize));
    }
    return *this;
  }
  /*!
   * \brief add sequence of values to instruction
   * \param args The instruction sequence
   * \return reference to self.
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  InstrBuilder& AddSeq(Args&& ...args) {
    AddSeqHelper helper;
    helper.builder = this;
    runtime::detail::for_each(helper, std::forward<Args>(args)...);
    return *this;
  }
  /*!
   * \brief Finish build, commit the current
   *   instruction to the end of seg.
   *
   * \param seg The code segment to commit to
   * \return The result instruction.
   */
  Instr Commit(std::vector<uint32_t>* seg) {
    CHECK_NE(data_.size(), 0U) << "Commit called without a matching Begin";
    Instr ret;
    ret.data_ = seg;
    ret.begin_ = static_cast<uint32_t>(seg->size());
    ret.word_count_ = static_cast<uint32_t>(data_.size());
    // The word count occupies the high 16 bits of the first word.
    CHECK_LE(ret.word_count_, 0xFFFFU)
        << "SPIR-V instruction exceeds the maximum word count";
    data_[0] = op_ | (ret.word_count_ << spv::WordCountShift);
    seg->insert(seg->end(), data_.begin(), data_.end());
    data_.clear();
    return ret;
  }

 private:
  // current op code.
  spv::Op op_;
  // The internal data to store code
  std::vector<uint32_t> data_;
  // helper class to support variadic arguments
  struct AddSeqHelper {
    // The reference to builder
    InstrBuilder* builder;
    // invoke function
    template<typename T>
    void operator()(size_t, const T& v) const {
      builder->Add(v);
    }
  };
};

/*!
 * \brief Builder to build up a single SPIR-V module
 *
 *  This is a thin wrapper to build SPIRV binary.
 *  SPIRV adopts structured control-flow.
 *  Code is built by always appending instructions to the end of one of
 *  several per-section segments (header, entry points, execution modes,
 *  debug, annotations, globals, functions). Finalize() concatenates the
 *  segments in that order. The Instr handles returned by MakeInst let
 *  earlier instructions (e.g. phi nodes) be revisited and patched later.
 *
 *  This IRBuilder does not introduce the concept of a BasicBlock;
 *  instead instructions are appended to the end of each segment.
 */
class IRBuilder {
 public:
  /*! \brief Initialize header */
  void InitHeader();
  /*! \brief Initialize the predefined contents */
  void InitPreDefs();
  /*!
   * \brief Import additional extension libraries.
   * \param name The name of the library.
   * \return The value (flagged kExtInst) referring to the imported
   *  extended instruction set.
   */
  Value ExtInstImport(const std::string& name) {
    Value val = NewValue(SType(), kExtInst);
    ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&header_);
    return val;
  }
  /*!
   * \brief Get the final binary built from the builder.
   *
   *  Writes the id bound into the header and concatenates the segments
   *  in module layout order.
   *
   * \note The header must already be initialized (presumably by
   *  InitHeader()) so that it contains the bound word.
   * \return The finalized binary instruction.
   */
  std::vector<uint32_t> Finalize() {
    // Word 3 of the module header is the id bound; id_counter_ is always
    // one past the largest id handed out so far.
    const size_t kBoundLoc = 3;
    CHECK_GT(header_.size(), kBoundLoc)
        << "InitHeader must be called before Finalize";
    header_[kBoundLoc] = id_counter_;
    std::vector<uint32_t> data;
    data.reserve(header_.size() + entry_.size() + exec_mode_.size() +
                 debug_.size() + decorate_.size() + global_.size() +
                 function_.size());
    data.insert(data.end(), header_.begin(), header_.end());
    data.insert(data.end(), entry_.begin(), entry_.end());
    data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
    data.insert(data.end(), debug_.begin(), debug_.end());
    data.insert(data.end(), decorate_.begin(), decorate_.end());
    data.insert(data.end(), global_.begin(), global_.end());
    data.insert(data.end(), function_.begin(), function_.end());
    return data;
  }
  /*!
   * \brief Create new label
   * \return The created new label
   */
  Label NewLabel() {
    Label label;
    label.id = id_counter_++;
    return label;
  }
  /*!
   * \brief Start a new block with given label
   * \param label The label we use.
   */
  void StartLabel(Label label) {
    MakeInst(spv::OpLabel, label);
    curr_label_ = label;
  }
  /*! \return The current label */
  Label CurrentLabel() const {
    return curr_label_;
  }
  /*!
   * \brief Add code to debug segment.
   * \param op The operator
   * \param args The instruction sequence
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  void Debug(spv::Op op, Args&& ...args) {
    ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&debug_);
  }
  /*!
   * \brief Add Execution mode to a function.
   * \param func The function value
   * \param args The instruction sequence
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  void ExecutionMode(Value func, Args&& ...args) {
    ib_.Begin(spv::OpExecutionMode).AddSeq(
        func, std::forward<Args>(args)...).Commit(&exec_mode_);
  }
  /*!
   * \brief Add code to decorate segment.
   * \param op The operator
   * \param args The instruction sequence
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  void Decorate(spv::Op op, Args&& ...args) {
    ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&decorate_);
  }
  /*!
   * \brief Add code to global segment.
   *
   *  Global declarations must come after the annotation (decorate)
   *  section in the final module, so they go to global_, not decorate_.
   *
   * \param op The operator
   * \param args The instruction sequence
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  void DeclareGlobal(spv::Op op, Args&& ...args) {
    ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&global_);
  }
  /*!
   * \brief Make a new instruction and append it to end of function segment.
   *
   * \param op The operator
   * \param args The instruction sequence
   * \return Handle to the committed instruction.
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  Instr MakeInst(spv::Op op, Args&& ...args) {
    return ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&function_);
  }
  /*!
   * \brief Make a new SSA value,
   *
   * \param op The operator.
   * \param out_type The result type.
   * \param args The instruction sequence
   * \return The result SSA value.
   * \tparams Args The positional arguments
   */
  template<typename... Args>
  Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) {
    Value val = NewValue(out_type, kNormal);
    MakeInst(op, out_type, val, std::forward<Args>(args)...);
    return val;
  }
  /*!
   * \brief Make a phi value.
   *
   * \param out_type The output data type.
   * \param num_incoming number of incoming blocks.
   * \return The result Phi value.
   */
  PhiValue MakePhi(const SType& out_type, uint32_t num_incoming);
  /*!
   * \brief Create a GLSL450 call
   *
   * \param ret_type The result type.
   * \param inst_id The instance id of the function.
   * \param args The arguments
   * \return The result value.
   */
  Value CallGLSL450(const SType& ret_type,
                    uint32_t inst_id,
                    const std::vector<Value>& args);
  /*!
   * \brief Build vector by concatenating components
   *
   * \param vec The vector components
   * \return The concatenated vector value.
   */
  Value Concat(const std::vector<Value>& vec);
  /*!
   * \brief Get the spirv type for a given tvm data type.
   * \param dtype The data type.
   * \return The corresponding spirv type.
   */
  SType GetSType(const tvm::Type& dtype);
  /*!
   * \brief Get the pointer type that points to value_type
   * \param value_type The pointee type.
   * \param storage_class The storage class
   * \return The corresponding spirv type.
   */
  SType GetPointerType(const SType& value_type,
                       spv::StorageClass storage_class);
  /*!
   * \brief Get a struct{ value_type[num_elems] } type.
   * \param value_type the content value type.
   * \param num_elems number of elements in array
   *   num_elems = 0 means runtime array with BufferBlock Decoration
   *
   * \return The corresponding spirv type.
   */
  SType GetStructArrayType(const SType& value_type,
                           uint32_t num_elems);
  /*!
   * \brief Get a struct array access with a given index.
   * \param ptr_type The pointer type.
   * \param buffer The buffer ptr to struct array
   * \param index The array index.
   * \return The value representing the accessed element.
   */
  Value StructArrayAccess(const SType& ptr_type,
                          Value buffer,
                          Value index);
  /*!
   * \brief Create a cast that cast value to dst_type
   * \param dst_type The target type.
   * \param value the source value.
   * \return The result value
   */
  Value Cast(const SType& dst_type, Value value);
  /*!
   * \brief Create a const integer.
   * \param dtype The content data type.
   * \param value The data value.
   * \return The constant value.
   */
  Value IntImm(const SType& dtype, int64_t value);
  /*!
   * \brief Create a const unsigned integer.
   * \param dtype The content data type.
   * \param value The data value.
   * \return The constant value.
   */
  Value UIntImm(const SType& dtype, uint64_t value);
  /*!
   * \brief Create a const float.
   * \param dtype The content data type.
   * \param value The data value.
   * \return The constant value.
   */
  Value FloatImm(const SType& dtype, double value);
  /*!
   * \brief Declare buffer argument of function
   *
   * \param value_type The content type of the buffer.
   * \param descriptor_set The descriptor set we want to use.
   * \param binding The binding location in descriptor set.
   * \return The value representing the buffer argument.
   */
  Value BufferArgument(const SType& value_type,
                       uint32_t descriptor_set,
                       uint32_t binding);
  /*!
   * \brief Declare POD arguments through push constants.
   *
   * \note Only call this function once!
   * \param value_types The values in the push constant
   * \return The value representing the declared push constant.
   */
  Value DeclarePushConstant(const std::vector<SType>& value_types);
  /*!
   * \brief Get i-th push constant
   * \param ptr_push_const The push constant value (as returned by
   *  DeclarePushConstant).
   * \param v_type The value type
   * \param index The push constant index
   * \return the value of push constant
   */
  Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
  /*!
   * \brief Declare a new function
   * \return The created function ID.
   */
  Value NewFunction();
  /*!
   * \brief Declare the entry point for a kernel function. This should be
   * invoked after building the function so the builder is aware of which
   * variables to declare as part of the function's interface.
   * \param func The previously declared function.
   * \param name Name of the entry point.
   */
  void CommitKernelFunction(const Value& func, const std::string& name);
  /*!
   * \brief Start function scope.
   * \param func function to be started.
   */
  void StartFunction(const Value& func);
  /*!
   * \brief Set the local size of the function
   * \param func function of interest
   * \param local_size The local workgroup_size
   */
  void SetLocalSize(const Value& func, uint32_t local_size[3]);
  /*!
   * \brief Allocate space
   * \param value_type The content value type
   * \param num_elems Number of elements to allocate.
   * \param storage_class The storage class we want to store to.
   * \return The value representing the allocated storage.
   */
  Value Allocate(const SType& value_type,
                 uint32_t num_elems,
                 spv::StorageClass storage_class);
  /*!
   * \brief Get the i-th workgroup id.
   * \param dim_index The dimension index.
   * \return The value representing the workgroup id.
   */
  Value GetWorkgroupID(uint32_t dim_index);
  /*!
   * \brief Get the i-th local id.
   * \param dim_index The dimension index.
   * \return The value representing the local id.
   */
  Value GetLocalID(uint32_t dim_index);
  // Expressions
  Value Add(Value a, Value b);
  Value Sub(Value a, Value b);
  Value Mul(Value a, Value b);
  Value Div(Value a, Value b);
  Value Mod(Value a, Value b);
  Value EQ(Value a, Value b);
  Value NE(Value a, Value b);
  Value LT(Value a, Value b);
  Value LE(Value a, Value b);
  Value GT(Value a, Value b);
  Value GE(Value a, Value b);
  Value Select(Value cond, Value a, Value b);

 private:
  /*!
   * \brief Create new value with a fresh id.
   * \param stype The type of the value.
   * \param flag The kind flag of the value.
   * \return The created new value
   */
  Value NewValue(const SType& stype, ValueKind flag) {
    Value val;
    val.id = id_counter_++;
    val.stype = stype;
    val.flag = flag;
    return val;
  }
  // get constant given value encoded in uint64_t
  Value GetConst_(const SType& dtype, const uint64_t* pvalue);
  // declare type
  SType DeclareType(const Type& dtype);
  /*! \brief internal instruction builder  */
  InstrBuilder ib_;
  /*! \brief Current label */
  Label curr_label_;
  /*! \brief The next id to hand out (one past the current maximum id) */
  uint32_t id_counter_{1};
  /*! \brief glsl 450 extension */
  Value ext_glsl450_;
  /*! \brief Cached commonly used types: bool, int32, uint32, fp32, void, void func */
  SType t_bool_, t_int32_, t_uint32_, t_fp32_, t_void_, t_void_func_;
  /*! \brief quick cache for a constant i32 (presumably zero, per the name; verify in InitPreDefs) */
  Value const_i32_zero_;
  /*! \brief cache value for workgroup_id, local_id */
  Value workgroup_id_, local_id_;
  /*! \brief The push constant value, if one has been declared */
  Value push_const_;
  /*! \brief map from type code to the type */
  std::unordered_map<uint32_t, SType> pod_type_tbl_;
  /*! \brief map from value to array type */
  std::map<std::pair<uint32_t, uint32_t>, SType> struct_array_type_tbl_;
  /*! \brief map from value to its pointer type */
  std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
  /*! \brief map from constant int to its value */
  std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
  /*! \brief Header segment, include import */
  std::vector<uint32_t> header_;
  /*! \brief entry point segment */
  std::vector<uint32_t> entry_;
  /*! \brief Execution mode segment */
  std::vector<uint32_t> exec_mode_;
  /*! \brief Debug segment */
  std::vector<uint32_t> debug_;
  /*! \brief Annotation segment */
  std::vector<uint32_t> decorate_;
  /*! \brief Global segment: types, constants, variables */
  std::vector<uint32_t> global_;
  /*! \brief Function segment */
  std::vector<uint32_t> function_;
};

}  // namespace spirv
}  // namespace codegen
}  // namespace tvm

#endif  // TVM_CODEGEN_SPIRV_IR_BUILDER_H_