/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external miopen utils function
 */
#include "miopen_utils.h"
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <vector>
#include <string>

namespace tvm {
namespace contrib {
namespace miopen {

/*!
 * \brief Translate a numeric MIOpen status code into a readable name.
 *
 * The code is used as an index into a fixed table of eight names.
 * NOTE(review): the table order is assumed to mirror miopenStatus_t
 * (0 == success ... 7 == unknown error) - confirm against the MIOpen header.
 *
 * \param error_code Status value to describe.
 * \return The matching name; codes outside the table (negative or too large)
 *         yield the "StatusUnknownError " entry instead of reading out of bounds.
 */
std::string miopenGetErrorString(int error_code) {
  // Built once on first use rather than on every call.
  static const std::vector<std::string> mio_err{
    "StatusSuccess ",
    "StatusNotInitialized ",
    "StatusInvalidValue ",
    "StatusBadParm ",
    "StatusAllocFailed ",
    "StatusInternalError ",
    "StatusNotImplemented ",
    "StatusUnknownError "};
  if (error_code < 0 ||
      static_cast<size_t>(error_code) >= mio_err.size()) {
    // Unrecognized code: report it as the generic unknown error.
    return mio_err.back();
  }
  return mio_err[static_cast<size_t>(error_code)];
}

// MiopenThreadEntry

/*!
 * \brief Create the MIOpen handle for this entry and bind it to the stream
 *        of the current ROCm thread entry.
 *
 * Also looks up the "device_api.rocm" registry function to obtain the ROCm
 * DeviceAPI pointer, which is shared with conv_entry for workspace
 * allocation.
 */
MIOpenThreadEntry::MIOpenThreadEntry() {
  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
  auto func = runtime::Registry::Get("device_api.rocm");
  // Registry::Get returns nullptr for an unregistered name; fail with a
  // diagnostic instead of dereferencing a null pointer.
  CHECK(func != nullptr)
      << "Cannot find device_api.rocm in the registry";
  void *ret = (*func)();
  rocm_api = static_cast<runtime::DeviceAPI*>(ret);
  MIOPEN_CALL(miopenCreate(&handle));
  MIOPEN_CALL(miopenSetStream(handle, stream));
  conv_entry.rocm_api = rocm_api;
}

/*! \brief Destroy the MIOpen handle created by the constructor. */
MIOpenThreadEntry::~MIOpenThreadEntry() {
  MIOPEN_CALL(miopenDestroy(handle));
}

typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;

/*!
 * \brief Fetch the MIOpenThreadEntry held by the dmlc thread-local store.
 * \return Pointer to the entry returned by MIOpenThreadStore::Get().
 */
MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
  return MIOpenThreadStore::Get();
}

// ConvEntry

/*!
 * \brief Create the convolution descriptor and the filter/input/output
 *        tensor descriptors.
 */
ConvEntry::ConvEntry() {
  MIOPEN_CALL(miopenCreateConvolutionDescriptor(&conv_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&filter_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&input_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&output_desc));
}

/*!
 * \brief Destroy all descriptors created by the constructor and release the
 *        workspace, if one is held.
 */
ConvEntry::~ConvEntry() {
  MIOPEN_CALL(miopenDestroyConvolutionDescriptor(conv_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(filter_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(input_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(output_desc));
  CleanWorkspace();
}

/*!
 * \brief Grow the workspace so that it holds at least wsize bytes.
 *
 * Nothing happens when the current workspace is already large enough.
 * Otherwise the old workspace (if any) is released and a new one of exactly
 * wsize bytes is requested from rocm_api.
 *
 * \param wsize Required workspace size.
 */
void ConvEntry::UpdateWorkspace(const size_t wsize) {
  if (workspace_size >= wsize) {
    return;
  }
  // CleanWorkspace is a no-op for a null workspace, so no extra guard needed.
  CleanWorkspace();
  // Record the size only after the allocation call returns, so the
  // bookkeeping never describes a buffer that has already been freed.
  workspace = rocm_api->AllocWorkspace(ctx, wsize);
  workspace_size = wsize;
}

/*!
 * \brief Release the workspace (if any) and reset the bookkeeping.
 *
 * The pointer is cleared after freeing so that calling this again, or
 * calling UpdateWorkspace or the destructor afterwards, cannot free the
 * same buffer twice.
 */
void ConvEntry::CleanWorkspace() {
  if (workspace != nullptr) {
    rocm_api->FreeWorkspace(ctx, workspace);
    workspace = nullptr;
  }
  workspace_size = 0;
}

}  // namespace miopen
}  // namespace contrib
}  // namespace tvm