Commit e988e435 by Tianqi Chen Committed by GitHub

[METAL] Switch to manual ref counting (#114)

parent 305614a9
......@@ -31,7 +31,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
export OBJCFLAGS= -fobjc-arc
export OBJCFLAGS= -fno-objc-arc
ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc
......
......@@ -81,7 +81,7 @@ void CodeGenStackVM::VisitExpr_(const Load* op) {
this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
if (const IntImm* index = op->index.as<IntImm>()) {
this->PushOp(code, op->index.as<IntImm>()->value);
this->PushOp(code, index->value);
} else {
this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
......
......@@ -40,6 +40,8 @@ class MetalWorkspace final : public DeviceAPI {
bool initialized_{false};
// the mutex for initialization
std::mutex mutex;
// Destructor
~MetalWorkspace();
// Get command queue for given context.
id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kMetal);
......@@ -87,6 +89,7 @@ class MetalThreadEntry {
context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kMetal);
}
~MetalThreadEntry();
// Get temp buffer with at least size under ctx.
id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);
// get the global workspace
......
......@@ -83,6 +83,15 @@ int GetWarpSize(id<MTLDevice> dev) {
return state.threadExecutionWidth;
}
// Destructor: gives up the references this workspace holds on every entry
// of `devices` and `queues`. Both containers are filled with retained
// objects in Init(), so each element is released exactly once here.
MetalWorkspace::~MetalWorkspace() {
  // Devices first, then the command queues, one release per stored element.
  for (auto dev_it = devices.begin(); dev_it != devices.end(); ++dev_it) {
    [*dev_it release];
  }
  for (auto queue_it = queues.begin(); queue_it != queues.end(); ++queue_it) {
    [*queue_it release];
  }
}
void MetalWorkspace::Init() {
if (initialized_) return;
std::lock_guard<std::mutex>(this->mutex);
......@@ -92,8 +101,8 @@ void MetalWorkspace::Init() {
NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
queues.push_back([d newCommandQueue]);
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
LOG(INFO) << "Intializing Metal device " << i
<< ", name=" << d.name;
warp_size.push_back(GetWarpSize(d));
......@@ -112,13 +121,12 @@ void* MetalWorkspace::AllocDataSpace(
id<MTLBuffer> buf = [
dev newBufferWithLength:size
options:MTLResourceStorageModePrivate];
// retain ARC to keep it alive before release.
return (__bridge_retained void*)(buf);
return (__bridge void*)([buf retain]);
}
// Frees a buffer that was handed out as an opaque pointer by AllocDataSpace.
//
// ctx: device context of the allocation; not consulted here.
// ptr: the opaque buffer handle; a null handle is ignored.
void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // CFRelease traps on a NULL argument, so a null handle must be skipped.
  if (ptr == nullptr) return;
  // Give up ownership exactly once. Releasing the same handle through both
  // CFBridgingRelease and CFRelease would drop two references and
  // over-release the buffer.
  // NOTE(review): AllocDataSpace obtains the buffer from a new* method and
  // also sends it retain; confirm that a single release here balances it.
  CFRelease(ptr);
}
void MetalWorkspace::CopyDataFromTo(const void* from,
......@@ -207,6 +215,12 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
[cb waitUntilCompleted];
}
// Destructor: releases every temporary buffer cached in temp_buffer_.
MetalThreadEntry::~MetalThreadEntry() {
  // Slots that were never filled hold nil; sending release to nil is a
  // no-op, so no explicit nil test is needed.
  for (auto buf_it = temp_buffer_.begin(); buf_it != temp_buffer_.end(); ++buf_it) {
    [*buf_it release];
  }
}
id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
temp_buffer_.resize(ctx.device_id + 1, nil);
......@@ -214,9 +228,12 @@ id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
if (temp_buffer_[ctx.device_id] == nil ||
temp_buffer_[ctx.device_id].length < size) {
id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
if (temp_buffer_[ctx.device_id] != nil) {
[temp_buffer_[ctx.device_id] release];
}
temp_buffer_[ctx.device_id] = [
dev newBufferWithLength:size
options:MTLStorageModeShared];
[dev newBufferWithLength:size
options:MTLStorageModeShared] retain];
}
return temp_buffer_[ctx.device_id];
}
......
......@@ -104,6 +104,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
<< [[err_msg localizedDescription] UTF8String];
}
}
[e.lib retain];
}
id<MTLFunction> f = [
e.lib
......@@ -121,7 +122,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
// to the resource constraint in kernel, so it does not strictly hold.
// Turn off warp-aware optimization for now.
// CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
e.smap[func_name] = state;
e.smap[func_name] = [state retain];
return state;
}
......@@ -132,6 +133,13 @@ class MetalModuleNode final :public runtime::ModuleNode {
id<MTLLibrary> lib = nil;
// state cache;
std::unordered_map<std::string, id<MTLComputePipelineState> > smap;
// Destructor: releases the library reference and every pipeline state
// stored in the smap cache.
~DeviceEntry() {
  // lib starts out as nil; messaging nil is a no-op, so release it
  // unconditionally.
  [lib release];
  for (auto entry = smap.begin(); entry != smap.end(); ++entry) {
    [entry->second release];
  }
}
};
// the binary data
std::string data_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment