Commit 2e17e850 by Lianmin Zheng Committed by Tianqi Chen

add query for shared memory size (#1083)

parent 3d67ea17
......@@ -19,7 +19,8 @@ enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2,
kComputeVersion = 3,
kMaxSharedMemoryPerBlock = 3,
kComputeVersion = 4,
};
/*! \brief Number of bytes each allocation must align to */
......
......@@ -141,6 +141,12 @@ class TVMContext(ctypes.Structure):
self.device_type, self.device_id, 2)
@property
def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes"""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3)
@property
def compute_version(self):
"""Get compute verison number in string.
......@@ -152,7 +158,7 @@ class TVMContext(ctypes.Structure):
The version string in `major.minor` format.
"""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3)
self.device_type, self.device_id, 4)
def sync(self):
"""Synchronize until jobs finished at the context."""
......
......@@ -40,6 +40,11 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrWarpSize, ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
break;
}
case kComputeVersion: {
std::ostringstream os;
CUDA_CALL(cudaDeviceGetAttribute(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment