Commit 12cf7754 by MORITA Kazutaka Committed by Tianqi Chen

[RUNTIME][OPENCL] show correct device type name (#1441)

parent fe0f99b2
...@@ -108,6 +108,8 @@ class OpenCLThreadEntry; ...@@ -108,6 +108,8 @@ class OpenCLThreadEntry;
*/ */
class OpenCLWorkspace : public DeviceAPI { class OpenCLWorkspace : public DeviceAPI {
public: public:
// type key
std::string type_key;
// global platform id // global platform id
cl_platform_id platform_id; cl_platform_id platform_id;
// global platform name // global platform name
...@@ -138,9 +140,10 @@ class OpenCLWorkspace : public DeviceAPI { ...@@ -138,9 +140,10 @@ class OpenCLWorkspace : public DeviceAPI {
} }
} }
// Initialzie the device. // Initialzie the device.
void Init(const std::string& device_type, const std::string& platform_name = ""); void Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name = "");
virtual void Init() { virtual void Init() {
Init("gpu"); Init("opencl", "gpu");
} }
// Check whether the context is OpenCL or not. // Check whether the context is OpenCL or not.
virtual bool IsOpenCLDevice(TVMContext ctx) { virtual bool IsOpenCLDevice(TVMContext ctx) {
...@@ -240,7 +243,7 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -240,7 +243,7 @@ class OpenCLModuleNode : public ModuleNode {
*/ */
virtual const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace(); virtual const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace();
virtual const char* type_key() const; const char* type_key() const final { return workspace_->type_key.c_str(); }
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
......
...@@ -227,7 +227,8 @@ bool MatchPlatformInfo( ...@@ -227,7 +227,8 @@ bool MatchPlatformInfo(
return param_value.find(value) != std::string::npos; return param_value.find(value) != std::string::npos;
} }
void OpenCLWorkspace::Init(const std::string& device_type, const std::string& platform_name) { void OpenCLWorkspace::Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name) {
if (initialized_) return; if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu); std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return; if (initialized_) return;
...@@ -246,6 +247,7 @@ void OpenCLWorkspace::Init(const std::string& device_type, const std::string& pl ...@@ -246,6 +247,7 @@ void OpenCLWorkspace::Init(const std::string& device_type, const std::string& pl
} }
std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type); std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
if (devices_matched.size() > 0) { if (devices_matched.size() > 0) {
this->type_key = type_key;
this->platform_id = platform_id; this->platform_id = platform_id;
this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME); this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
this->device_type = device_type; this->device_type = device_type;
...@@ -271,7 +273,7 @@ void OpenCLWorkspace::Init(const std::string& device_type, const std::string& pl ...@@ -271,7 +273,7 @@ void OpenCLWorkspace::Init(const std::string& device_type, const std::string& pl
this->queues.push_back( this->queues.push_back(
clCreateCommandQueue(this->context, did, 0, &err_code)); clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code); OPENCL_CHECK_ERROR(err_code);
LOG(INFO) << "opencl(" << i LOG(INFO) << type_key << "(" << i
<< ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME) << ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
<< "\' cl_device_id=" << did; << "\' cl_device_id=" << did;
} }
......
...@@ -100,10 +100,6 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace ...@@ -100,10 +100,6 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace
return cl::OpenCLWorkspace::Global(); return cl::OpenCLWorkspace::Global();
} }
const char* OpenCLModuleNode::type_key() const {
return "opencl";
}
PackedFunc OpenCLModuleNode::GetFunction( PackedFunc OpenCLModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const std::shared_ptr<ModuleNode>& sptr_to_self) {
......
...@@ -20,7 +20,7 @@ const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() { ...@@ -20,7 +20,7 @@ const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
} }
void SDAccelWorkspace::Init() { void SDAccelWorkspace::Init() {
OpenCLWorkspace::Init("accelerator", "Xilinx"); OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx");
} }
bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) {
......
...@@ -21,17 +21,12 @@ class SDAccelModuleNode : public OpenCLModuleNode { ...@@ -21,17 +21,12 @@ class SDAccelModuleNode : public OpenCLModuleNode {
std::string source) std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {} : OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final; const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
const char* type_key() const final;
}; };
const std::shared_ptr<cl::OpenCLWorkspace>& SDAccelModuleNode::GetGlobalWorkspace() { const std::shared_ptr<cl::OpenCLWorkspace>& SDAccelModuleNode::GetGlobalWorkspace() {
return cl::SDAccelWorkspace::Global(); return cl::SDAccelWorkspace::Global();
} }
const char* SDAccelModuleNode::type_key() const {
return "sdaccel";
}
Module SDAccelModuleCreate( Module SDAccelModuleCreate(
std::string data, std::string data,
std::string fmt, std::string fmt,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment