Commit e988e435 by Tianqi Chen Committed by GitHub

[METAL] Switch to manual ref counting (#114)

parent 305614a9
...@@ -31,7 +31,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) ...@@ -31,7 +31,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\ export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0 -Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
export OBJCFLAGS= -fobjc-arc export OBJCFLAGS= -fno-objc-arc
ifdef CUDA_PATH ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc NVCC=$(CUDA_PATH)/bin/nvcc
......
...@@ -81,7 +81,7 @@ void CodeGenStackVM::VisitExpr_(const Load* op) { ...@@ -81,7 +81,7 @@ void CodeGenStackVM::VisitExpr_(const Load* op) {
this->Push(op->buffer_var); this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type)); StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
if (const IntImm* index = op->index.as<IntImm>()) { if (const IntImm* index = op->index.as<IntImm>()) {
this->PushOp(code, op->index.as<IntImm>()->value); this->PushOp(code, index->value);
} else { } else {
this->Push(op->index); this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes()); this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
......
...@@ -40,6 +40,8 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -40,6 +40,8 @@ class MetalWorkspace final : public DeviceAPI {
bool initialized_{false}; bool initialized_{false};
// the mutex for initialization // the mutex for initialization
std::mutex mutex; std::mutex mutex;
// Destructor
~MetalWorkspace();
// Get command queue for given context. // Get command queue for given context.
id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) { id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kMetal); CHECK_EQ(ctx.device_type, kMetal);
...@@ -87,6 +89,7 @@ class MetalThreadEntry { ...@@ -87,6 +89,7 @@ class MetalThreadEntry {
context.device_id = 0; context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kMetal); context.device_type = static_cast<DLDeviceType>(kMetal);
} }
~MetalThreadEntry();
// Get temp buffer with at least size under ctx. // Get temp buffer with at least size under ctx.
id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size); id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);
// get the global workspace // get the global workspace
......
...@@ -83,6 +83,15 @@ int GetWarpSize(id<MTLDevice> dev) { ...@@ -83,6 +83,15 @@ int GetWarpSize(id<MTLDevice> dev) {
return state.threadExecutionWidth; return state.threadExecutionWidth;
} }
MetalWorkspace::~MetalWorkspace() {
for (auto x : devices) {
[x release];
}
for (auto x : queues) {
[x release];
}
}
void MetalWorkspace::Init() { void MetalWorkspace::Init() {
if (initialized_) return; if (initialized_) return;
std::lock_guard<std::mutex>(this->mutex); std::lock_guard<std::mutex>(this->mutex);
...@@ -92,8 +101,8 @@ void MetalWorkspace::Init() { ...@@ -92,8 +101,8 @@ void MetalWorkspace::Init() {
NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices(); NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) { for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i]; id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d); devices.push_back([d retain]);
queues.push_back([d newCommandQueue]); queues.push_back([[d newCommandQueue] retain]);
LOG(INFO) << "Intializing Metal device " << i LOG(INFO) << "Intializing Metal device " << i
<< ", name=" << d.name; << ", name=" << d.name;
warp_size.push_back(GetWarpSize(d)); warp_size.push_back(GetWarpSize(d));
...@@ -112,13 +121,12 @@ void* MetalWorkspace::AllocDataSpace( ...@@ -112,13 +121,12 @@ void* MetalWorkspace::AllocDataSpace(
id<MTLBuffer> buf = [ id<MTLBuffer> buf = [
dev newBufferWithLength:size dev newBufferWithLength:size
options:MTLResourceStorageModePrivate]; options:MTLResourceStorageModePrivate];
// retain ARC to keep it alive before release. return (__bridge void*)([buf retain]);
return (__bridge_retained void*)(buf);
} }
void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
// release the ptr. // release the ptr.
CFBridgingRelease(ptr); CFRelease(ptr);
} }
void MetalWorkspace::CopyDataFromTo(const void* from, void MetalWorkspace::CopyDataFromTo(const void* from,
...@@ -207,6 +215,12 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -207,6 +215,12 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
[cb waitUntilCompleted]; [cb waitUntilCompleted];
} }
MetalThreadEntry::~MetalThreadEntry() {
for (auto x : temp_buffer_) {
if (x != nil) [x release];
}
}
id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) { id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) { if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
temp_buffer_.resize(ctx.device_id + 1, nil); temp_buffer_.resize(ctx.device_id + 1, nil);
...@@ -214,9 +228,12 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) { ...@@ -214,9 +228,12 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
if (temp_buffer_[ctx.device_id] == nil || if (temp_buffer_[ctx.device_id] == nil ||
temp_buffer_[ctx.device_id].length < size) { temp_buffer_[ctx.device_id].length < size) {
id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx); id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
if (temp_buffer_[ctx.device_id] != nil) {
[temp_buffer_[ctx.device_id] release];
}
temp_buffer_[ctx.device_id] = [ temp_buffer_[ctx.device_id] = [
dev newBufferWithLength:size [dev newBufferWithLength:size
options:MTLStorageModeShared]; options:MTLStorageModeShared] retain];
} }
return temp_buffer_[ctx.device_id]; return temp_buffer_[ctx.device_id];
} }
......
...@@ -104,6 +104,7 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -104,6 +104,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
<< [[err_msg localizedDescription] UTF8String]; << [[err_msg localizedDescription] UTF8String];
} }
} }
[e.lib retain];
} }
id<MTLFunction> f = [ id<MTLFunction> f = [
e.lib e.lib
...@@ -121,7 +122,7 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -121,7 +122,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
// to the resource constraint in kernel, so it is not strictly hold // to the resource constraint in kernel, so it is not strictly hold
// Turn of warp aware optimziation for now. // Turn of warp aware optimziation for now.
// CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]); // CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
e.smap[func_name] = state; e.smap[func_name] = [state retain];
return state; return state;
} }
...@@ -132,6 +133,13 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -132,6 +133,13 @@ class MetalModuleNode final :public runtime::ModuleNode {
id<MTLLibrary> lib = nil; id<MTLLibrary> lib = nil;
// state cache; // state cache;
std::unordered_map<std::string, id<MTLComputePipelineState> > smap; std::unordered_map<std::string, id<MTLComputePipelineState> > smap;
~DeviceEntry() {
if (lib != nil) [lib release];
for (auto &&kv : smap) {
[kv.second release];
}
}
}; };
// the binary data // the binary data
std::string data_; std::string data_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment