Commit 12cf7754 by MORITA Kazutaka Committed by Tianqi Chen

[RUNTIME][OPENCL] show correct device type name (#1441)

parent fe0f99b2
......@@ -108,6 +108,8 @@ class OpenCLThreadEntry;
*/
class OpenCLWorkspace : public DeviceAPI {
public:
// type key
std::string type_key;
// global platform id
cl_platform_id platform_id;
// global platform name
......@@ -138,9 +140,10 @@ class OpenCLWorkspace : public DeviceAPI {
}
}
// Initialzie the device.
void Init(const std::string& device_type, const std::string& platform_name = "");
void Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name = "");
virtual void Init() {
Init("gpu");
Init("opencl", "gpu");
}
// Check whether the context is OpenCL or not.
virtual bool IsOpenCLDevice(TVMContext ctx) {
......@@ -240,7 +243,7 @@ class OpenCLModuleNode : public ModuleNode {
*/
virtual const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace();
virtual const char* type_key() const;
const char* type_key() const final { return workspace_->type_key.c_str(); }
PackedFunc GetFunction(
const std::string& name,
......
......@@ -227,7 +227,8 @@ bool MatchPlatformInfo(
return param_value.find(value) != std::string::npos;
}
void OpenCLWorkspace::Init(const std::string& device_type, const std::string& platform_name) {
void OpenCLWorkspace::Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name) {
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
......@@ -246,6 +247,7 @@ void OpenCLWorkspace::Init(const std::string& device_type, const std::string& pl
}
std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
if (devices_matched.size() > 0) {
this->type_key = type_key;
this->platform_id = platform_id;
this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
this->device_type = device_type;
......@@ -271,7 +273,7 @@ void OpenCLWorkspace::Init(const std::string& device_type, const std::string& pl
this->queues.push_back(
clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
LOG(INFO) << "opencl(" << i
LOG(INFO) << type_key << "(" << i
<< ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
<< "\' cl_device_id=" << did;
}
......
......@@ -100,10 +100,6 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace
return cl::OpenCLWorkspace::Global();
}
const char* OpenCLModuleNode::type_key() const {
return "opencl";
}
PackedFunc OpenCLModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
......
......@@ -20,7 +20,7 @@ const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
}
void SDAccelWorkspace::Init() {
OpenCLWorkspace::Init("accelerator", "Xilinx");
OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx");
}
bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) {
......
......@@ -21,17 +21,12 @@ class SDAccelModuleNode : public OpenCLModuleNode {
std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
const char* type_key() const final;
};
const std::shared_ptr<cl::OpenCLWorkspace>& SDAccelModuleNode::GetGlobalWorkspace() {
return cl::SDAccelWorkspace::Global();
}
const char* SDAccelModuleNode::type_key() const {
return "sdaccel";
}
Module SDAccelModuleCreate(
std::string data,
std::string fmt,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment