Commit 819728db by Tianqi Chen Committed by GitHub

Update halideIR, add more device query for shared memory (#1087)

parent ca7b8322
Subproject commit e20e5e9abb3aa43147a90a4ffb3e190f62862970 Subproject commit a3698398faff7fec1c0fa4e4479357651382db75
...@@ -70,7 +70,11 @@ Target CreateTarget(const std::string& target_name, ...@@ -70,7 +70,11 @@ Target CreateTarget(const std::string& target_name,
t->thread_warp_size = 32; t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") { } else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl // For now assume rocm schedule for opencl
t->device_type = static_cast<int>(target_name == "rocm" ? kDLROCM : kDLOpenCL); if (target_name == "opencl") {
t->device_type = kDLOpenCL;
} else {
t->device_type = kDLROCM;
}
t->keys_array.push_back(ir::StringImm::make("rocm")); t->keys_array.push_back(ir::StringImm::make("rocm"));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
...@@ -78,14 +82,21 @@ Target CreateTarget(const std::string& target_name, ...@@ -78,14 +82,21 @@ Target CreateTarget(const std::string& target_name,
t->thread_warp_size = 16; t->thread_warp_size = 16;
} }
} else if (target_name == "metal" || target_name == "vulkan") { } else if (target_name == "metal" || target_name == "vulkan") {
t->device_type = static_cast<int>(target_name == "metal" ? kDLMetal : kDLVulkan); if (target_name == "metal") {
t->device_type = kDLMetal;
} else {
t->device_type = kDLVulkan;
}
t->keys_array.push_back(ir::StringImm::make(target_name)); t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
} else if (target_name == "opengl") { } else if (target_name == "opengl") {
t->device_type = kDLGPU; t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImm::make("opengl")); t->keys_array.push_back(ir::StringImm::make("opengl"));
} else if (target_name == "stackvm" || target_name == "ext_dev") { } else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
t->device_type = kExtDev;
} else { } else {
LOG(ERROR) << "Unknown target name " << target_name; LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm(); return target::stackvm();
......
...@@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr( ...@@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr(
*rv = 1; *rv = 1;
break; break;
} }
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: return; case kComputeVersion: return;
case kExist: break; case kExist: break;
} }
......
...@@ -32,9 +32,9 @@ void OpenCLWorkspace::GetAttr( ...@@ -32,9 +32,9 @@ void OpenCLWorkspace::GetAttr(
} }
CHECK_LT(index, devices.size()) CHECK_LT(index, devices.size())
<< "Invalid device id " << index; << "Invalid device id " << index;
size_t value;
switch (kind) { switch (kind) {
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
size_t value;
OPENCL_CALL(clGetDeviceInfo( OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE,
sizeof(size_t), &value, nullptr)); sizeof(size_t), &value, nullptr));
...@@ -45,8 +45,16 @@ void OpenCLWorkspace::GetAttr( ...@@ -45,8 +45,16 @@ void OpenCLWorkspace::GetAttr(
*rv = 1; *rv = 1;
break; break;
} }
case kComputeVersion: return; case kMaxSharedMemoryPerBlock: {
case kExist: break; cl_ulong value;
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_LOCAL_MEM_SIZE,
sizeof(cl_ulong), &value, nullptr));
*rv = static_cast<int64_t>(value);
break;
}
case kComputeVersion: return;
case kExist: break;
} }
} }
......
...@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
value = 64; value = 64;
break; break;
} }
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: { case kComputeVersion: {
hipDeviceProp_t prop; hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
......
...@@ -51,6 +51,13 @@ void VulkanWorkspace::GetAttr( ...@@ -51,6 +51,13 @@ void VulkanWorkspace::GetAttr(
*rv = value; *rv = value;
break; break;
} }
case kMaxSharedMemoryPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
*rv = value;
break;
}
case kWarpSize: { case kWarpSize: {
*rv = 1; *rv = 1;
break; break;
......
...@@ -15,6 +15,15 @@ TEST(Expr, Basic) { ...@@ -15,6 +15,15 @@ TEST(Expr, Basic) {
} }
// Builds max(x + 1 + 2, 100), obtains the underlying ir::Max node pointer via
// as<ir::Max>(), and checks that the reference returned by GetNodeRef() on
// that node satisfies same_as() against the original expression.
TEST(ExprNodeRef, Basic) {
  using namespace tvm;
  Var x("x");
  Expr z = max(x + 1 + 2, 100);
  const ir::Max* op = z.as<ir::Max>();
  // as<T>() is expected to return nullptr on a node-type mismatch (confirm
  // against the NodeRef API); report that as a check failure instead of
  // dereferencing a null pointer on the next line.
  CHECK(op != nullptr) << "max(...) did not produce an ir::Max node";
  CHECK(op->GetNodeRef().same_as(z));
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment