/*!
 *  Copyright (c) 2017 by Contributors
 * \file Utility functions for using the external MIOpen library.
 */
#include "miopen_utils.h"
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <vector>
#include <string>

namespace tvm {
namespace contrib {
namespace miopen {

/*!
 * \brief Map a MIOpen status code to a printable name.
 * \param error_code Status code returned by a MIOpen call, as an int.
 * \return The (padded) status name for codes 0..7; for any other value a
 *         message containing the raw code, instead of indexing out of range.
 */
std::string miopenGetErrorString(int error_code) {
  // Names for status codes 0..7 (presumably mirroring miopenStatus_t's
  // numbering -- confirm when MIOpen adds new statuses). Trailing padding is
  // kept byte-for-byte so existing log output is unchanged.
  static const char* const kErrorStrings[] = {
      "StatusSuccess        ", "StatusNotInitialized ", "StatusInvalidValue   ",
      "StatusBadParm        ", "StatusAllocFailed    ", "StatusInternalError  ",
      "StatusNotImplemented ", "StatusUnknownError   "};
  const int kNumErrorStrings =
      static_cast<int>(sizeof(kErrorStrings) / sizeof(kErrorStrings[0]));
  // Guard against negative or newer/unknown codes: indexing past the table
  // would be undefined behavior.
  if (error_code < 0 || error_code >= kNumErrorStrings) {
    return "Unknown MIOpen error code " + std::to_string(error_code);
  }
  return kErrorStrings[error_code];
}

// MIOpenThreadEntry
/*!
 * \brief Build this thread's MIOpen state.
 *
 * Looks up the ROCm DeviceAPI through the "device_api.rocm" global
 * function, creates a MIOpen handle bound to the current thread's ROCm
 * stream, and gives the convolution entry the same DeviceAPI so it can
 * allocate workspace.
 */
MIOpenThreadEntry::MIOpenThreadEntry() {
  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
  auto func = runtime::Registry::Get("device_api.rocm");
  // Registry::Get yields nullptr when nothing is registered under the name;
  // fail with a clear message instead of dereferencing a null pointer.
  CHECK(func != nullptr) << "device_api.rocm is not registered";
  void* ret = (*func)();
  rocm_api = static_cast<runtime::DeviceAPI*>(ret);
  MIOPEN_CALL(miopenCreate(&handle));
  MIOPEN_CALL(miopenSetStream(handle, stream));
  conv_entry.rocm_api = rocm_api;
}

/*! \brief Destroy the MIOpen handle that the constructor created. */
MIOpenThreadEntry::~MIOpenThreadEntry() {
  MIOPEN_CALL(miopenDestroy(this->handle));
}

/*! \brief Thread-local store holding one MIOpenThreadEntry per thread. */
using MIOpenThreadStore = dmlc::ThreadLocalStore<MIOpenThreadEntry>;

/*! \return The MIOpenThreadEntry belonging to the calling thread. */
MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
  MIOpenThreadEntry* entry = MIOpenThreadStore::Get();
  return entry;
}

// ConvEntry

/*!
 * \brief Create the descriptors used by MIOpen convolution calls:
 *        one convolution descriptor plus filter, input and output tensor
 *        descriptors. All four are released in ~ConvEntry().
 */
ConvEntry::ConvEntry() {
  MIOPEN_CALL(miopenCreateConvolutionDescriptor(&this->conv_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&this->filter_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&this->input_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&this->output_desc));
}

/*!
 * \brief Destroy the four descriptors created by the constructor, then
 *        release any workspace buffer still held.
 */
ConvEntry::~ConvEntry() {
  MIOPEN_CALL(miopenDestroyConvolutionDescriptor(this->conv_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(this->filter_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(this->input_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(this->output_desc));
  this->CleanWorkspace();
}

/*!
 * \brief Ensure the workspace buffer holds at least \p wsize bytes.
 *
 * The buffer only ever grows. When it is too small, the old buffer is
 * released before a new one is requested from the ROCm DeviceAPI.
 *
 * The pointer is cleared before allocating, and workspace_size is recorded
 * only after allocation succeeds. If AllocWorkspace throws, the entry is
 * therefore left empty rather than holding a freed pointer that the
 * destructor would free a second time.
 * \param wsize Required size, as passed to DeviceAPI::AllocWorkspace.
 */
void ConvEntry::UpdateWorkspace(const size_t wsize) {
  if (workspace_size >= wsize) return;
  CleanWorkspace();  // frees the old buffer (if any) and zeroes workspace_size
  workspace = nullptr;  // never keep a freed pointer across the allocation
  workspace = rocm_api->AllocWorkspace(ctx, wsize);
  workspace_size = wsize;
}

/*!
 * \brief Release the workspace buffer, if one is held, and reset its size.
 *
 * The pointer is reset to nullptr after freeing, so calling this more than
 * once is safe. For example, an explicit CleanWorkspace() followed by
 * ~ConvEntry() would otherwise free the same buffer twice.
 */
void ConvEntry::CleanWorkspace() {
  if (workspace != nullptr) {
    rocm_api->FreeWorkspace(ctx, workspace);
    workspace = nullptr;
  }
  workspace_size = 0;
}

}  // namespace miopen
}  // namespace contrib
}  // namespace tvm