Commit 9c36d9f0 by Tianqi Chen Committed by GitHub

[RUNTIME] Fix Metal runtime compile (#241)

parent 9c954ada
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../workspace_pool.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -12,9 +12,10 @@ namespace tvm { ...@@ -12,9 +12,10 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace metal { namespace metal {
MetalWorkspace* MetalWorkspace::Global() { const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
static MetalWorkspace inst; static std::shared_ptr<MetalWorkspace> inst =
return &inst; std::make_shared<MetalWorkspace>();
return inst;
} }
void MetalWorkspace::GetAttr( void MetalWorkspace::GetAttr(
...@@ -254,7 +255,7 @@ MetalThreadEntry* MetalThreadEntry::ThreadLocal() { ...@@ -254,7 +255,7 @@ MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
TVM_REGISTER_GLOBAL("device_api.metal") TVM_REGISTER_GLOBAL("device_api.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MetalWorkspace::Global(); DeviceAPI* ptr = MetalWorkspace::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
......
...@@ -69,7 +69,7 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -69,7 +69,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
// get a CUfunction from primary context in device_id // get a CUfunction from primary context in device_id
id<MTLComputePipelineState> GetPipelineState( id<MTLComputePipelineState> GetPipelineState(
size_t device_id, const std::string& func_name) { size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
CHECK_LT(device_id, w->devices.size()); CHECK_LT(device_id, w->devices.size());
// start lock scope. // start lock scope.
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
...@@ -167,7 +167,7 @@ class MetalWrappedFunc { ...@@ -167,7 +167,7 @@ class MetalWrappedFunc {
size_t num_buffer_args, size_t num_buffer_args,
size_t num_pack_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) { const std::vector<std::string>& thread_axis_tags) {
w_ = metal::MetalWorkspace::Global(); w_ = metal::MetalWorkspace::Global().get();
m_ = m; m_ = m;
sptr_ = sptr; sptr_ = sptr;
func_name_ = func_name; func_name_ = func_name;
...@@ -255,8 +255,7 @@ Module MetalModuleCreate( ...@@ -255,8 +255,7 @@ Module MetalModuleCreate(
std::string fmt, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) { std::string source) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); metal::MetalWorkspace::Global()->Init();
w->Init();
std::shared_ptr<MetalModuleNode> n = std::shared_ptr<MetalModuleNode> n =
std::make_shared<MetalModuleNode>(data, fmt, fmap, source); std::make_shared<MetalModuleNode>(data, fmt, fmap, source);
return Module(n); return Module(n);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment