Commit 3121441d by Lianmin Zheng Committed by Tianqi Chen

add device name to context attribute (#1090)

* add device name to context attribute

* update for other backends
parent ee514eac
...@@ -21,6 +21,7 @@ enum DeviceAttrKind : int { ...@@ -21,6 +21,7 @@ enum DeviceAttrKind : int {
kWarpSize = 2, kWarpSize = 2,
kMaxSharedMemoryPerBlock = 3, kMaxSharedMemoryPerBlock = 3,
kComputeVersion = 4, kComputeVersion = 4,
kDeviceName = 5
}; };
/*! \brief Number of bytes each allocation must align to */ /*! \brief Number of bytes each allocation must align to */
......
...@@ -142,7 +142,7 @@ class TVMContext(ctypes.Structure): ...@@ -142,7 +142,7 @@ class TVMContext(ctypes.Structure):
@property @property
def max_shared_memory_per_block(self): def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes""" """Total amount of shared memory per block in bytes."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3) self.device_type, self.device_id, 3)
...@@ -160,6 +160,12 @@ class TVMContext(ctypes.Structure): ...@@ -160,6 +160,12 @@ class TVMContext(ctypes.Structure):
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 4) self.device_type, self.device_id, 4)
@property
def device_name(self):
"""Return the string name of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 5)
def sync(self): def sync(self):
"""Synchronize until jobs finished at the context.""" """Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
......
...@@ -30,7 +30,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -30,7 +30,7 @@ class CUDADeviceAPI final : public DeviceAPI {
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess); == cudaSuccess);
break; break;
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
break; break;
...@@ -56,6 +56,12 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -56,6 +56,12 @@ class CUDADeviceAPI final : public DeviceAPI {
*rv = os.str(); *rv = os.str();
return; return;
} }
case kDeviceName: {
cudaDeviceProp props;
CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id));
*rv = std::string(props.name);
return;
}
} }
*rv = value; *rv = value;
} }
......
...@@ -41,6 +41,7 @@ void MetalWorkspace::GetAttr( ...@@ -41,6 +41,7 @@ void MetalWorkspace::GetAttr(
} }
case kMaxSharedMemoryPerBlock: return; case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: return; case kComputeVersion: return;
case kDeviceName: return;
case kExist: break; case kExist: break;
} }
} }
......
...@@ -54,6 +54,14 @@ void OpenCLWorkspace::GetAttr( ...@@ -54,6 +54,14 @@ void OpenCLWorkspace::GetAttr(
break; break;
} }
case kComputeVersion: return; case kComputeVersion: return;
case kDeviceName: {
char value[128] = {0};
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_NAME,
sizeof(value) - 1, value, nullptr));
*rv = std::string(value);
break;
}
case kExist: break; case kExist: break;
} }
} }
......
...@@ -93,9 +93,11 @@ void OpenGLWorkspace::GetAttr( ...@@ -93,9 +93,11 @@ void OpenGLWorkspace::GetAttr(
*rv = 1; *rv = 1;
break; break;
} }
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: { case kComputeVersion: {
break; break;
} }
case kDeviceName: return;
} }
} }
......
...@@ -51,6 +51,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -51,6 +51,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
*rv = prop.gcnArch; *rv = prop.gcnArch;
return; return;
} }
case kDeviceName: return;
} }
*rv = value; *rv = value;
} }
......
...@@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr( ...@@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr(
*rv = os.str(); *rv = os.str();
break; break;
} }
case kDeviceName: return;
case kExist: break; case kExist: break;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment