Commit 022b285d, authored by Peter Yeh, committed by masahi

proper device query through rocm api (#4305)

parent 9e04298b
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -22,14 +22,13 @@ ...@@ -22,14 +22,13 @@
* \file rocm_device_api.cc * \file rocm_device_api.cc
* \brief GPU specific API * \brief GPU specific API
*/ */
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <hsa/hsa.h> #include <hsa/hsa.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include "../../../include/tvm/runtime/device_api.h"
#include "rocm_common.h" #include "rocm_common.h"
namespace tvm { namespace tvm {
...@@ -55,19 +54,57 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -55,19 +54,57 @@ class ROCMDeviceAPI final : public DeviceAPI {
break; break;
} }
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
value = 1024; ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
break; break;
} }
case kWarpSize: { case kWarpSize: {
value = 64; ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize,
ctx.device_id));
break; break;
} }
case kMaxSharedMemoryPerBlock: return; case kMaxSharedMemoryPerBlock: {
case kComputeVersion: ROCM_CALL(hipDeviceGetAttribute(
case kDeviceName: return; &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id));
case kMaxClockRate: return; break;
case kMultiProcessorCount: return; }
case kMaxThreadDimensions: return; case kComputeVersion: {
std::ostringstream os;
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
os << value << ".";
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
}
case kDeviceName:
return;
case kMaxClockRate: {
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate,
ctx.device_id));
break;
}
case kMultiProcessorCount: {
ROCM_CALL(hipDeviceGetAttribute(
&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
ROCM_CALL(hipDeviceGetAttribute(
&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
ROCM_CALL(hipDeviceGetAttribute(
&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
ROCM_CALL(hipDeviceGetAttribute(
&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));
std::stringstream ss;
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
case kGcnArch: { case kGcnArch: {
hipDeviceProp_t prop; hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
...@@ -77,14 +114,11 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -77,14 +114,11 @@ class ROCMDeviceAPI final : public DeviceAPI {
} }
*rv = value; *rv = value;
} }
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
size_t nbytes,
size_t alignment,
TVMType type_hint) final { TVMType type_hint) final {
ROCM_CALL(hipSetDevice(ctx.device_id)); ROCM_CALL(hipSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U) CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
<< "ROCM space is aligned at 256 bytes"; void* ret;
void *ret;
ROCM_CALL(hipMalloc(&ret, nbytes)); ROCM_CALL(hipMalloc(&ret, nbytes));
return ret; return ret;
} }
...@@ -94,14 +128,9 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -94,14 +128,9 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCM_CALL(hipFree(ptr)); ROCM_CALL(hipFree(ptr));
} }
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t from_offset, size_t to_offset, size_t size, TVMContext ctx_from,
void* to, TVMContext ctx_to, TVMType type_hint,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream); hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
...@@ -111,14 +140,15 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -111,14 +140,15 @@ class ROCMDeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) { if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
} else { } else {
hipMemcpyPeerAsync(to, ctx_to.device_id, hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size,
from, ctx_from.device_id, hip_stream);
size, hip_stream);
} }
} else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { } else if (ctx_from.device_type == kDLROCM &&
ctx_to.device_type == kDLCPU) {
ROCM_CALL(hipSetDevice(ctx_from.device_id)); ROCM_CALL(hipSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
} else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) { } else if (ctx_from.device_type == kDLCPU &&
ctx_to.device_type == kDLROCM) {
ROCM_CALL(hipSetDevice(ctx_to.device_id)); ROCM_CALL(hipSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
} else { } else {
...@@ -132,8 +162,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -132,8 +162,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
} }
void SetStream(TVMContext ctx, TVMStreamHandle stream) final { void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
ROCMThreadEntry::ThreadLocal() ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
->stream = static_cast<hipStream_t>(stream);
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
...@@ -151,11 +180,8 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -151,11 +180,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
} }
private: private:
static void GPUCopy(const void* from, static void GPUCopy(const void* from, void* to, size_t size,
void* to, hipMemcpyKind kind, hipStream_t stream) {
size_t size,
hipMemcpyKind kind,
hipStream_t stream) {
if (stream != 0) { if (stream != 0) {
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
} else { } else {
...@@ -166,19 +192,16 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -166,19 +192,16 @@ class ROCMDeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore; typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
ROCMThreadEntry::ROCMThreadEntry() ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
: pool(kDLROCM, ROCMDeviceAPI::Global()) {
}
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
return ROCMThreadStore::Get(); return ROCMThreadStore::Get();
} }
TVM_REGISTER_GLOBAL("device_api.rocm") TVM_REGISTER_GLOBAL("device_api.rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment