Commit 90809381 by Aditya Atluri Committed by Tianqi Chen

[RUNTIME] v2: runtime support for rocm (#386)

* v2: runtime support for rocm

* fixed coding space errors

* removed kROCM from c_runtime_api.h
parent 6f54656f
...@@ -87,12 +87,12 @@ else ...@@ -87,12 +87,12 @@ else
endif endif
ifdef ROCM_PATH ifdef ROCM_PATH
CFLAGS += -I$(ROCM_PATH)/hip/include CFLAGS += -I$(ROCM_PATH)/include
LDFLAGS += -L$(ROCM_PATH)/hip/lib LDFLAGS += -L$(ROCM_PATH)/lib
endif endif
ifeq ($(USE_ROCM), 1) ifeq ($(USE_ROCM), 1)
CFLAGS += -DTVM_ROCM_RUNTIME=1 CFLAGS += -DTVM_ROCM_RUNTIME=1 -D__HIP_PLATFORM_HCC__=1
LDFLAGS += -lhip_hcc LDFLAGS += -lhip_hcc
RUNTIME_DEP += $(ROCM_OBJ) RUNTIME_DEP += $(ROCM_OBJ)
else else
......
...@@ -58,7 +58,7 @@ typedef enum { ...@@ -58,7 +58,7 @@ typedef enum {
/*! \brief Metal buffer. */ /*! \brief Metal buffer. */
kMetal = 8, kMetal = 8,
/*! \brief Simulated on board RAM */ /*! \brief Simulated on board RAM */
kVPI = 9 kVPI = 9,
} TVMDeviceExtType; } TVMDeviceExtType;
/*! /*!
......
...@@ -35,6 +35,9 @@ USE_CUDA = 0 ...@@ -35,6 +35,9 @@ USE_CUDA = 0
# if you have already add them to environment variable. # if you have already add them to environment variable.
# CUDA_PATH = /usr/local/cuda # CUDA_PATH = /usr/local/cuda
# ROCM
USE_ROCM = 0
# whether enable OpenCL during compile # whether enable OpenCL during compile
USE_OPENCL = 0 USE_OPENCL = 0
......
...@@ -16,7 +16,7 @@ from . import node ...@@ -16,7 +16,7 @@ from . import node
from . import ir_builder from . import ir_builder
from . import ndarray as nd from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function from ._ffi.function import Function
......
...@@ -97,7 +97,8 @@ class TVMContext(ctypes.Structure): ...@@ -97,7 +97,8 @@ class TVMContext(ctypes.Structure):
2 : 'gpu', 2 : 'gpu',
4 : 'opencl', 4 : 'opencl',
8 : 'metal', 8 : 'metal',
9 : 'vpi' 9 : 'vpi',
10: 'rocm'
} }
STR2MASK = { STR2MASK = {
'cpu': 1, 'cpu': 1,
...@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure): ...@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
'cl': 4, 'cl': 4,
'opencl': 4, 'opencl': 4,
'metal': 8, 'metal': 8,
'vpi': 9 'vpi': 9,
'rocm': 10
} }
def __init__(self, device_type, device_id): def __init__(self, device_type, device_id):
super(TVMContext, self).__init__() super(TVMContext, self).__init__()
......
...@@ -56,6 +56,21 @@ def gpu(dev_id=0): ...@@ -56,6 +56,21 @@ def gpu(dev_id=0):
""" """
return TVMContext(2, dev_id) return TVMContext(2, dev_id)
def rocm(dev_id=0):
"""Construct a ROCM device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(10, dev_id)
def opencl(dev_id=0): def opencl(dev_id=0):
"""Construct a OpenCL device """Construct a OpenCL device
......
...@@ -30,6 +30,7 @@ inline std::string DeviceName(int type) { ...@@ -30,6 +30,7 @@ inline std::string DeviceName(int type) {
case kOpenCL: return "opencl"; case kOpenCL: return "opencl";
case kMetal: return "metal"; case kMetal: return "metal";
case kVPI: return "vpi"; case kVPI: return "vpi";
case kROCM: return "rocm";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
} }
} }
......
/*!
* Copyright (c) 2017 by Contributors
* \file rocm_common.h
* \brief Common utilities for ROCM
*/
#ifndef TVM_RUNTIME_ROCM_ROCM_COMMON_H_
#define TVM_RUNTIME_ROCM_ROCM_COMMON_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <string>
#if TVM_ROCM_RUNTIME
#include <hip/hip_runtime_api.h>
#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
#define ROCM_DRIVER_CALL(x) \
{ \
hipError_t result = x; \
if (result != hipSuccess && result != hipErrorDeinitialized) { \
LOG(FATAL) \
<< "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \
} \
}
#define ROCM_CALL(func) \
{ \
hipError_t e = (func); \
CHECK(e == hipSuccess) \
<< "ROCM HIP: " << hipGetErrorString(e); \
}
/*! \brief Thread local workspace */
class ROCMThreadEntry {
public:
/*! \brief The hip stream */
hipStream_t stream{nullptr};
/*! \brief thread local pool*/
WorkspacePool pool;
/*! \brief constructor */
ROCMThreadEntry();
// get the threadlocal workspace
static ROCMThreadEntry* ThreadLocal();
};
} // namespace runtime
} // namespace tvm
#endif // TVM_ROCM_RUNTIME
#endif // TVM_RUNTIME_ROCM_ROCM_COMMON_H_
/*!
* Copyright (c) 2017 by Contributors
* \file rocm_device_api.cc
* \brief GPU specific API
*/
#include <tvm/runtime/config.h>
#include <tvm/runtime/device_api.h>
#if TVM_ROCM_RUNTIME
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <hsa/hsa.h>
#include "./rocm_common.h"
namespace tvm {
namespace runtime {
class ROCMDeviceAPI final : public DeviceAPI {
public:
void SetDevice(TVMContext ctx) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
int value = 0;
switch (kind) {
case kExist: {
if (hsa_init() == HSA_STATUS_SUCCESS) {
int dev;
ROCM_CALL(hipGetDeviceCount(&dev));
value = dev > ctx.device_id ? 1 : 0;
hsa_shut_down();
} else {
value = 0;
}
break;
}
case kMaxThreadsPerBlock: {
value = 1024;
break;
}
case kWarpSize: {
value = 64;
break;
}
}
*rv = value;
}
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U)
<< "ROCM space is aligned at 256 bytes";
void *ret;
ROCM_CALL(hipMalloc(&ret, size));
return ret;
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
ROCM_CALL(hipFree(ptr));
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;
if (ctx_from.device_type == kROCM && ctx_to.device_type == kROCM) {
ROCM_CALL(hipSetDevice(ctx_from.device_id));
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
} else {
hipMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, hip_stream);
}
} else if (ctx_from.device_type == kROCM && ctx_to.device_type == kCPU) {
ROCM_CALL(hipSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
} else if (ctx_from.device_type == kCPU && ctx_to.device_type == kROCM) {
ROCM_CALL(hipSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
} else {
LOG(FATAL) << "expect copy from/to GPU or between GPU";
}
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
ROCM_CALL(hipStreamSynchronize(static_cast<hipStream_t>(stream)));
}
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
ROCMThreadEntry::ThreadLocal()
->stream = static_cast<hipStream_t>(stream);
}
void* AllocWorkspace(TVMContext ctx, size_t size) final {
return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void FreeWorkspace(TVMContext ctx, void* data) final {
ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
static const std::shared_ptr<ROCMDeviceAPI>& Global() {
static std::shared_ptr<ROCMDeviceAPI> inst =
std::make_shared<ROCMDeviceAPI>();
return inst;
}
private:
static void GPUCopy(const void* from,
void* to,
size_t size,
hipMemcpyKind kind,
hipStream_t stream) {
if (stream != 0) {
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
} else {
ROCM_CALL(hipMemcpy(to, from, size, kind));
}
}
};
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
ROCMThreadEntry::ROCMThreadEntry()
: pool(kGPU, ROCMDeviceAPI::Global()) {
}
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
return ROCMThreadStore::Get();
}
TVM_REGISTER_GLOBAL("device_api.rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_ROCM_RUNTIME
/*!
* Copyright (c) 2017 by Contributors
* \file rocm_module.cc
*/
#include "./rocm_module.h"
#if TVM_ROCM_RUNTIME
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <vector>
#include <array>
#include <string>
#include <mutex>
#include "./rocm_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
namespace tvm {
namespace runtime {
// Module to support thread-safe multi-GPU execution.
// hipModule_t is a per-GPU module
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class ROCMModuleNode : public runtime::ModuleNode {
public:
explicit ROCMModuleNode(std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string hip_source)
: data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source) {
std::fill(module_.begin(), module_.end(), nullptr);
}
// destructor
~ROCMModuleNode() {
for (size_t i = 0; i < module_.size(); ++i) {
if (module_[i] != nullptr) {
ROCM_CALL(hipSetDevice(static_cast<int>(i)));
ROCM_DRIVER_CALL(hipModuleUnload(module_[i]));
}
}
}
const char* type_key() const final {
return "hip";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}
// get a CUfunction from primary context in device_id
hipFunction_t GetFunc(int device_id, const std::string& func_name) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
}
hipFunction_t func;
hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str());
if (result != hipSuccess) {
LOG(FATAL)
<< "ROCMError: hipModuleGetFunction " << func_name
<< " failed with error: " << hipGetErrorString(result);
}
return func;
}
// get a global var from primary context in device_id
hipDeviceptr_t GetGlobal(int device_id,
const std::string& global_name,
size_t expect_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
}
hipDeviceptr_t global = nullptr;
size_t nbytes = 0;
hipError_t result = hipSuccess;
// ROCM doesn't support hipModuleGetGlobal yet.
// hipError_t result = hipModuleGetGlobal(&global, &nbytes,
// module_[device_id], global_name.c_str());
CHECK_EQ(nbytes, expect_nbytes);
if (result != hipSuccess) {
LOG(FATAL)
<< "ROCMError: hipModuleGetGlobal " << global_name
<< " failed with error: " << hipGetErrorString(result);
}
return global;
}
private:
// the binary data
std::string data_;
// The format
std::string fmt_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The hip source.
std::string hip_source_;
// the internal modules per GPU, to be lazily initialized.
std::array<hipModule_t, kMaxNumGPUs> module_;
// internal mutex when updating the module
std::mutex mutex_;
};
// a wrapped function class to get packed fucn.
class ROCMWrappedFunc {
public:
// initialize the ROCM function.
void Init(ROCMModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
const std::string& func_name,
size_t num_void_args,
const std::vector<std::string>& thread_axis_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args,
TVMRetValue* rv,
void** void_args) const {
int device_id;
ROCM_CALL(hipGetDevice(&device_id));
if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
// HIP supports only extra_args.
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
wl.grid_dim(1),
wl.grid_dim(2),
wl.block_dim(0),
wl.block_dim(1),
wl.block_dim(2),
0, strm, void_args, 0));
}
private:
// internal module
ROCMModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// The name of the function.
std::string func_name_;
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
};
PackedFunc ROCMModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
ROCMWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
return PackFuncVoidAddr(f, info.arg_types);
}
Module ROCMModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string hip_source) {
std::shared_ptr<ROCMModuleNode> n =
std::make_shared<ROCMModuleNode>(data, fmt, fmap, hip_source);
return Module(n);
}
Module ROCMModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
stream->Read(&fmt);
stream->Read(&fmap);
stream->Read(&data);
return ROCMModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = ROCMModuleLoadBinary(args[0]);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_ROCM_RUNTIME
/*!
* Copyright (c) 2017 by Contributors
* \file rocm_module.h
* \brief Execution handling of ROCM kernels
*/
#ifndef TVM_RUNTIME_ROCM_ROCM_MODULE_H_
#define TVM_RUNTIME_ROCM_ROCM_MODULE_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/module.h>
#include <memory>
#include <vector>
#include <string>
#include "../meta_data.h"
namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPU supported in ROCMModule */
static constexpr const int kMaxNumGPUs = 32;
/*!
* \brief create a rocm module from data.
*
* \param data The module data, can be hsaco
* \param fmt The format of the data, can be "hsaco"
* \param fmap The map function information map of each function.
* \param rocm_source Optional, rocm source file
*/
Module ROCMModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string rocm_source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_
...@@ -6,6 +6,7 @@ def enabled_ctx_list(): ...@@ -6,6 +6,7 @@ def enabled_ctx_list():
('gpu', tvm.gpu(0)), ('gpu', tvm.gpu(0)),
('cl', tvm.opencl(0)), ('cl', tvm.opencl(0)),
('metal', tvm.metal(0)), ('metal', tvm.metal(0)),
('rocm', tvm.rocm(0)),
('vpi', tvm.vpi(0))] ('vpi', tvm.vpi(0))]
for k, v in ctx_list: for k, v in ctx_list:
assert tvm.context(k, 0) == v assert tvm.context(k, 0) == v
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment