Commit 819728db by Tianqi Chen Committed by GitHub

Update halideIR, add more device query for shared memory (#1087)

parent ca7b8322
Subproject commit e20e5e9abb3aa43147a90a4ffb3e190f62862970
Subproject commit a3698398faff7fec1c0fa4e4479357651382db75
......@@ -70,7 +70,11 @@ Target CreateTarget(const std::string& target_name,
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
t->device_type = static_cast<int>(target_name == "rocm" ? kDLROCM : kDLOpenCL);
if (target_name == "opencl") {
t->device_type = kDLOpenCL;
} else {
t->device_type = kDLROCM;
}
t->keys_array.push_back(ir::StringImm::make("rocm"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256;
......@@ -78,14 +82,21 @@ Target CreateTarget(const std::string& target_name,
t->thread_warp_size = 16;
}
} else if (target_name == "metal" || target_name == "vulkan") {
t->device_type = static_cast<int>(target_name == "metal" ? kDLMetal : kDLVulkan);
if (target_name == "metal") {
t->device_type = kDLMetal;
} else {
t->device_type = kDLVulkan;
}
t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "opengl") {
t->device_type = kDLGPU;
t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImm::make("opengl"));
} else if (target_name == "stackvm" || target_name == "ext_dev") {
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
t->device_type = kExtDev;
} else {
LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm();
......
......@@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr(
*rv = 1;
break;
}
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: return;
case kExist: break;
}
......
......@@ -32,9 +32,9 @@ void OpenCLWorkspace::GetAttr(
}
CHECK_LT(index, devices.size())
<< "Invalid device id " << index;
size_t value;
switch (kind) {
case kMaxThreadsPerBlock: {
size_t value;
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE,
sizeof(size_t), &value, nullptr));
......@@ -45,6 +45,14 @@ void OpenCLWorkspace::GetAttr(
*rv = 1;
break;
}
case kMaxSharedMemoryPerBlock: {
cl_ulong value;
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_LOCAL_MEM_SIZE,
sizeof(cl_ulong), &value, nullptr));
*rv = static_cast<int64_t>(value);
break;
}
case kComputeVersion: return;
case kExist: break;
}
......
......@@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
value = 64;
break;
}
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
......
......@@ -51,6 +51,13 @@ void VulkanWorkspace::GetAttr(
*rv = value;
break;
}
case kMaxSharedMemoryPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
*rv = value;
break;
}
case kWarpSize: {
*rv = 1;
break;
......
......@@ -15,6 +15,15 @@ TEST(Expr, Basic) {
}
TEST(ExprNodeRef, Basic) {
using namespace tvm;
Var x("x");
Expr z = max(x + 1 + 2, 100);
const ir::Max* op = z.as<ir::Max>();
CHECK(op->GetNodeRef().same_as(z));
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment