Commit 0235d283 by Peter Yeh Committed by masahi

[RUNTIME] Add device query for AMD GcnArch (#4341)

* add gcnArch query

* kGcnArch query for cuda is a no-op
parent 1e2c525b
...@@ -42,7 +42,8 @@ enum DeviceAttrKind : int { ...@@ -42,7 +42,8 @@ enum DeviceAttrKind : int {
kDeviceName = 5, kDeviceName = 5,
kMaxClockRate = 6, kMaxClockRate = 6,
kMultiProcessorCount = 7, kMultiProcessorCount = 7,
kMaxThreadDimensions = 8 kMaxThreadDimensions = 8,
kGcnArch = 9
}; };
/*! \brief Number of bytes each allocation must align to */ /*! \brief Number of bytes each allocation must align to */
......
...@@ -174,7 +174,7 @@ inline int DetectROCMComputeVersion(const std::string& target) { ...@@ -174,7 +174,7 @@ inline int DetectROCMComputeVersion(const std::string& target) {
TVMRetValue val; TVMRetValue val;
api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) { if (val.operator int() == 1) {
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val); tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kGcnArch, &val);
return val.operator int(); return val.operator int();
} }
} }
......
...@@ -105,6 +105,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -105,6 +105,7 @@ class CUDADeviceAPI final : public DeviceAPI {
*rv = ss.str(); *rv = ss.str();
return; return;
} }
case kGcnArch: return;
} }
*rv = value; *rv = value;
} }
......
...@@ -63,6 +63,7 @@ void MetalWorkspace::GetAttr( ...@@ -63,6 +63,7 @@ void MetalWorkspace::GetAttr(
case kMultiProcessorCount: return; case kMultiProcessorCount: return;
case kMaxThreadDimensions: return; case kMaxThreadDimensions: return;
case kExist: break; case kExist: break;
case kGcnArch: return;
} }
} }
......
...@@ -114,6 +114,7 @@ void OpenCLWorkspace::GetAttr( ...@@ -114,6 +114,7 @@ void OpenCLWorkspace::GetAttr(
*rv = ss.str(); *rv = ss.str();
break; break;
} }
case kGcnArch: return;
} }
} }
......
...@@ -117,6 +117,7 @@ void OpenGLWorkspace::GetAttr( ...@@ -117,6 +117,7 @@ void OpenGLWorkspace::GetAttr(
case kMaxClockRate: return; case kMaxClockRate: return;
case kMultiProcessorCount: return; case kMultiProcessorCount: return;
case kMaxThreadDimensions: return; case kMaxThreadDimensions: return;
case kGcnArch: return;
} }
} }
......
...@@ -26,9 +26,10 @@ ...@@ -26,9 +26,10 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <hsa/hsa.h> #include <hsa/hsa.h>
#include <tvm/runtime/registry.h>
#include "../../../include/tvm/runtime/device_api.h"
#include "rocm_common.h" #include "rocm_common.h"
namespace tvm { namespace tvm {
...@@ -62,16 +63,17 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -62,16 +63,17 @@ class ROCMDeviceAPI final : public DeviceAPI {
break; break;
} }
case kMaxSharedMemoryPerBlock: return; case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: { case kComputeVersion:
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kGcnArch: {
hipDeviceProp_t prop; hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
*rv = prop.gcnArch; *rv = prop.gcnArch;
return; return;
} }
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
} }
*rv = value; *rv = value;
} }
......
...@@ -398,6 +398,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* ...@@ -398,6 +398,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
break; break;
case kMaxThreadDimensions: case kMaxThreadDimensions:
break; break;
case kGcnArch:
return;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment