/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external miopen utils function
 */
#include "miopen_utils.h"
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <vector>
#include <string>

namespace tvm {
namespace contrib {
namespace miopen {

std::string miopenGetErrorString(int error_code) {
  const std::vector<std::string> mio_err{
      "StatusSuccess        ", "StatusNotInitialized ", "StatusInvalidValue   ",
      "StatusBadParm        ", "StatusAllocFailed    ", "StatusInternalError  ",
      "StatusNotImplemented ", "StatusUnknownError   "};
  return mio_err[error_code];
}

// MiopenThreadEntry
MIOpenThreadEntry::MIOpenThreadEntry() {
  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
  auto func = runtime::Registry::Get("device_api.rocm");
  void *ret = (*func)();
  rocm_api = static_cast<runtime::DeviceAPI*>(ret);
  MIOPEN_CALL(miopenCreate(&handle));
  MIOPEN_CALL(miopenSetStream(handle, stream));
  conv_entry.rocm_api = rocm_api;
}

MIOpenThreadEntry::~MIOpenThreadEntry() {
  MIOPEN_CALL(miopenDestroy(handle));
}

typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;

MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
  return MIOpenThreadStore::Get();
}

// ConvEntry

ConvEntry::ConvEntry() {
  MIOPEN_CALL(miopenCreateConvolutionDescriptor(&conv_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&filter_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&input_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&output_desc));
}

ConvEntry::~ConvEntry() {
  MIOPEN_CALL(miopenDestroyConvolutionDescriptor(conv_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(filter_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(input_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(output_desc));
  CleanWorkspace();
}

void ConvEntry::UpdateWorkspace(const size_t wsize) {
  if (workspace_size < wsize) {
    if (workspace != nullptr) {
      CleanWorkspace();
    }
    workspace_size = wsize;
    workspace = rocm_api->AllocWorkspace(ctx, workspace_size);
  }
}

void ConvEntry::CleanWorkspace() {
  if (workspace) rocm_api->FreeWorkspace(ctx, workspace);
  workspace_size = 0;
}

}  // namespace miopen
}  // namespace contrib
}  // namespace tvm