/*!
 *  Copyright (c) 2017 by Contributors
 * \file pack_args.h
 * \brief Utility to pack TVMArgs to other type-erased function calling convention.
 *
 *  Two type erased function signatures are supported.
 *   - cuda_style(void** args, int num_args);
 *      - Pack everything by address
 *   - metal_style(void** buffers, int num_buffers,
 *                 union_32bit args[N], int num_args);
 *      - Pack buffer by address, pack rest parameter into 32bit union buffer.
 */
#ifndef TVM_RUNTIME_PACK_ARGS_H_
#define TVM_RUNTIME_PACK_ARGS_H_

#include <tvm/runtime/c_runtime_api.h>
#include <cstdint>
#include <cstring>
#include <vector>

namespace tvm {
namespace runtime {
/*!
 * \brief argument union type of 32bit.
 * Choose 32 bit because most GPU API do not work well with 64 bit.
 */
union ArgUnion {
  int32_t v_int32;    // signed 32-bit integer view
  uint32_t v_uint32;  // unsigned 32-bit integer view
  float v_float32;    // 32-bit floating point view
};
/*!
 * \brief Create a packed function from void addr types
 *  (CUDA-style convention: every argument is passed by address).
 *
 * \param f The function with signature (TVMArgs args, TVMRetValue* rv, void* void_args)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Create a packed function that only packs the non-buffer arguments
 *  (Metal-style convention: buffers by address, scalars in a 32-bit union array).
 *
 * \param f The function with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Create a packed function from a function that takes all arguments
 *  packed into one contiguous buffer.
 *
 * \param f The function with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
 * \param arg_types The argument type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Extract the number of buffer (handle) arguments from the argument types.
 *  Buffer arguments are required to come first in the argument list.
 * \param arg_types The argument types.
 * \return number of buffer arguments
 */
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types);

// implementations details
namespace detail {
/*!
 * \brief Scratch storage for argument packing.
 *  The primary template holds a fixed stack array of kSize elements
 *  (the runtime size argument is ignored); the kSize == 0 specialization
 *  falls back to heap storage sized at construction time.
 * \tparam T element type
 * \tparam kSize static capacity, or 0 for dynamic allocation
 */
template<typename T, int kSize>
class TempArray {
 public:
  // Capacity is fixed by kSize; the requested size is not used.
  explicit TempArray(int size) {}
  /*! \return pointer to the underlying storage */
  T* data() {
    return stack_;
  }

 private:
  T stack_[kSize];
};
template<typename T>
class TempArray<T, 0> {
 public:
  // Dynamic fallback: allocate exactly `size` elements on the heap.
  explicit TempArray(int size) : heap_(size) {}
  /*! \return pointer to the underlying storage */
  T* data() {
    return heap_.data();
  }

 private:
  std::vector<T> heap_;
};

/*! \brief conversion code used in void arg. */
enum ArgConvertCode {
  INT64_TO_INT64,      // pass 64-bit int through unchanged
  INT64_TO_INT32,      // narrow 64-bit int to int32
  INT64_TO_UINT32,     // narrow 64-bit int to uint32
  FLOAT64_TO_FLOAT32,  // narrow double to float
  FLOAT64_TO_FLOAT64,  // pass double through unchanged
  HANDLE_TO_HANDLE     // pass pointer/handle through unchanged
};

/*!
 * \brief Map an argument type to the conversion code used when packing it.
 *  Only scalar (lanes == 1) types of supported bit widths are accepted;
 *  anything else aborts via CHECK/LOG(FATAL).
 * \param t The argument type.
 * \return The conversion code for this type.
 */
inline ArgConvertCode GetArgConvertCode(TVMType t) {
  CHECK_EQ(t.lanes, 1U)
      << "Cannot pass vector type argument to device function for now";
  if (t.code == kDLInt) {
    if (t.bits == 64U) return INT64_TO_INT64;
    if (t.bits == 32U) return INT64_TO_INT32;
  } else if (t.code == kDLUInt) {
    if (t.bits == 32U) return INT64_TO_UINT32;
  } else if (t.code == kDLFloat) {
    if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
    if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
  } else if (t.code == kHandle) {
    return HANDLE_TO_HANDLE;
  }
  LOG(FATAL) << "Cannot handle " << t << " as device function argument";
  return HANDLE_TO_HANDLE;  // unreachable; silences missing-return warnings
}

/*!
 * \brief Internal helper: build a PackedFunc that converts incoming TVMArgs
 *  into an array of void* addresses (CUDA-style convention) and invokes f.
 *
 *  64-bit values and handles are passed by pointing at the TVMValue directly;
 *  values that need narrowing are materialized into a per-call holder array
 *  so each address points at a properly-converted 32-bit payload.
 *
 * \param f The function invoked as f(args, rv, void** addr).
 * \param codes Per-argument conversion codes.
 * \tparam N Static scratch capacity; 0 selects heap-backed storage.
 * \tparam F the function type
 * \return The wrapped packed function.
 */
template<int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
    TempArray<void*, N> addr_(num_args);
    TempArray<ArgUnion, N> holder_(num_args);
    void** addr = addr_.data();
    ArgUnion* holder = holder_.data();
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64:
        case HANDLE_TO_HANDLE: {
          // No conversion needed: point at the TVMValue in place.
          addr[i] = const_cast<TVMValue*>(&args.values[i]);
          break;
        }
        case INT64_TO_INT32: {
          holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
          addr[i] = &(holder[i]);
          break;
        }
        case INT64_TO_UINT32 : {
          holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
          addr[i] = &(holder[i]);
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
          addr[i] = &(holder[i]);
          break;
        }
      }
    }
    f(args, ret, addr);
  };
  return PackedFunc(ret);
}

/*!
 * \brief Internal helper: build a PackedFunc that packs only the non-buffer
 *  arguments (those starting at index base) into 32-bit union slots and
 *  invokes f (Metal-style convention).
 *
 * \param f The function invoked as f(args, rv, ArgUnion* holder).
 * \param base Index of the first non-buffer argument in args.
 * \param codes Conversion codes for args[base], args[base+1], ...
 * \tparam N Static scratch capacity; 0 selects heap-backed storage.
 * \tparam F the function type
 * \return The wrapped packed function.
 */
template<int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(
    F f, int base, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
    TempArray<ArgUnion, N> holder_(num_args);
    ArgUnion* holder = holder_.data();
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
          // This convention carries only 32-bit scalar payloads.
          LOG(FATAL) << "Do not support 64bit argument to device function"; break;
        }
        case INT64_TO_INT32: {
          holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
          break;
        }
        case INT64_TO_UINT32 : {
          holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
          break;
        }
        case HANDLE_TO_HANDLE: {
          // Buffer (handle) args are passed separately and never reach here.
          LOG(FATAL) << "not reached"; break;
        }
      }
    }
    f(args, ret, holder);
  };
  return PackedFunc(ret);
}

/*!
 * \brief Internal helper: serialize all arguments into one contiguous buffer
 *  of 32-bit words and invoke f with the buffer and its byte size.
 *
 * \param f The function invoked as f(args, rv, void* pack, size_t nbytes).
 * \param codes Per-argument conversion codes.
 * \tparam N Static scratch capacity (in uint64_t slots); 0 selects heap storage.
 * \tparam F the function type
 * \return The wrapped packed function.
 */
template<int N, typename F>
inline PackedFunc PackFuncPackedArg_(
    F f, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
    // Worst case each argument needs one 8-byte slot, so num_args uint64_t
    // slots always suffice; uint64_t also guarantees 8-byte alignment.
    TempArray<uint64_t, N> pack_(num_args);
    int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
    // Write cursor; advances by one or two 32-bit words per argument.
    int32_t* ptr = pack;
    static_assert(sizeof(TVMValue) == 8, "invariant");
    static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case HANDLE_TO_HANDLE: {
          // memcpy because ptr may not be pointer-aligned at this offset.
          std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
          ptr += sizeof(void*) / sizeof(int32_t);
          break;
        }
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
          // Copy the whole 8-byte TVMValue payload (two 32-bit words).
          std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
          ptr += 2;
          break;
        }
        case INT64_TO_INT32: {
          *ptr = static_cast<int32_t>(args.values[i].v_int64);
          ++ptr;
          break;
        }
        case INT64_TO_UINT32 : {
          *reinterpret_cast<uint32_t*>(ptr) =
              static_cast<uint32_t>(args.values[i].v_int64);
          ++ptr;
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          *reinterpret_cast<float*>(ptr) =
              static_cast<float>(args.values[i].v_float64);
          ++ptr;
          break;
        }
        default: {
          LOG(FATAL) << "not reached"; break;
        }
      }
    }
    // nbytes is the number of 32-bit words written, times 4.
    f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
  };
  return PackedFunc(ret);
}
}  // namespace detail

/*!
 * \brief Create a packed function from void addr types.
 * \param f The function with signature (TVMArgs args, TVMRetValue* rv, void* void_args)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) {
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(arg_types.size());
  for (const TVMType& t : arg_types) {
    codes.push_back(detail::GetArgConvertCode(t));
  }
  const size_t num_void_args = arg_types.size();
  // Small argument counts use fixed stack scratch; large ones go to the heap.
  if (num_void_args <= 4) {
    return detail::PackFuncVoidAddr_<4>(f, codes);
  }
  if (num_void_args <= 8) {
    return detail::PackFuncVoidAddr_<8>(f, codes);
  }
  return detail::PackFuncVoidAddr_<0>(f, codes);
}

/*!
 * \brief Extract the number of leading buffer (handle) arguments.
 *  Verifies that no handle argument appears after the first non-handle one.
 * \param arg_types The argument types.
 * \return number of buffer arguments
 */
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types) {
  const size_t total = arg_types.size();
  size_t num_buffer = total;
  // num_buffer is the index of the first non-handle argument (or total).
  for (size_t i = 0; i < total; ++i) {
    if (arg_types[i].code != kHandle) {
      num_buffer = i;
      break;
    }
  }
  // Every argument after that point must be a non-handle scalar.
  for (size_t i = num_buffer; i < total; ++i) {
    CHECK(arg_types[i].code != kHandle)
        << "Device function need to be organized";
  }
  return num_buffer;
}

/*!
 * \brief Create a packed function that only packs the non-buffer arguments.
 * \param f The function with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types) {
  const size_t num_buffer = NumBufferArgs(arg_types);
  const int base = static_cast<int>(num_buffer);
  // Conversion codes only for the scalar (non-buffer) tail of the arguments.
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(arg_types.size() - num_buffer);
  for (size_t i = num_buffer; i < arg_types.size(); ++i) {
    codes.push_back(detail::GetArgConvertCode(arg_types[i]));
  }
  // Up to 4 scalars fit in fixed stack scratch; otherwise use the heap.
  return codes.size() <= 4
      ? detail::PackFuncNonBufferArg_<4>(f, base, codes)
      : detail::PackFuncNonBufferArg_<0>(f, base, codes);
}

/*!
 * \brief Create a packed function from a function taking packed arguments.
 * \param f The function with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
 * \param arg_types The argument type information.
 * \tparam F the function type
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) {
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(arg_types.size());
  for (const TVMType& t : arg_types) {
    codes.push_back(detail::GetArgConvertCode(t));
  }
  // Up to 4 arguments fit in fixed stack scratch; otherwise use the heap.
  return codes.size() <= 4
      ? detail::PackFuncPackedArg_<4>(f, codes)
      : detail::PackFuncPackedArg_<0>(f, codes);
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_PACK_ARGS_H_