/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_device_api.mm
 */
#include "./metal_common.h"

#if TVM_METAL_RUNTIME
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>

namespace tvm {
namespace runtime {
namespace metal {

const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
  static std::shared_ptr<MetalWorkspace> inst =
      std::make_shared<MetalWorkspace>();
  return inst;
}

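// Query a device attribute. kExist is answered before the bounds check
// so probing an out-of-range device id reports 0 instead of aborting.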
void MetalWorkspace::GetAttr(
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
  this->Init();
  size_t index = static_cast<size_t>(ctx.device_id);
  if (kind == kExist) {
    *rv = static_cast<int>(index < devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
    case kMaxThreadsPerBlock: {
      *rv = static_cast<int>(
          [devices[ctx.device_id] maxThreadsPerThreadgroup].width);
      break;
    }
    case kWarpSize: {
      // Set the warp size to 1 for safety (see the note above GetWarpSize).
      *rv = 1;
      break;
    }
    case kComputeVersion: return;
    case kExist: break;
  }
}

static const char* kDummyKernel = R"A0B0(
using namespace metal;
// Simple copy kernel, used only to query threadExecutionWidth
// from the current Metal API.
kernel void CopyKernel(
  device float* dst [[buffer(0)]],
  device float* src [[buffer(1)]],
  ushort2 gid [[thread_position_in_grid]]) {
  dst[gid.x] = src[gid.x];
}
)A0B0";

// Hack to get the warp size from the device.
// Note that in Metal, state.threadExecutionWidth can vary per kernel,
// possibly due to resource constraints, so threadExecutionWidth can be
// smaller than the true warp size. For safety, warp-aware optimization
// is turned off for now (kWarpSize reports 1), but this code is kept.
int GetWarpSize(id<MTLDevice> dev) {
  NSError* error_msg = nil;
  id<MTLLibrary> lib =
      [dev
        newLibraryWithSource:
          [NSString stringWithUTF8String:kDummyKernel]
        options:nil
        error:&error_msg];
  CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
  id<MTLFunction> f =
      [lib
        newFunctionWithName:
          [NSString stringWithUTF8String:"CopyKernel"]];
  CHECK(f != nil);
  id<MTLComputePipelineState> state =
      [dev
        newComputePipelineStateWithFunction:f
        error:&error_msg];
  CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
  return static_cast<int>(state.threadExecutionWidth);
}

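// Release the device and command queue handles retained in Init().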
MetalWorkspace::~MetalWorkspace() {
  for (auto x : devices) {
    [x release];
  }
  for (auto x : queues) {
    [x release];
  }
}

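// Lazily enumerate devices and create one command queue per device.
// Uses double-checked locking so concurrent callers initialize at most once.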
void MetalWorkspace::Init() {
  if (initialized_) return;
  // Name the lock guard; an unnamed temporary would release the mutex
  // immediately and leave initialization unsynchronized.
  std::lock_guard<std::mutex> lock(this->mutex);
  if (initialized_) return;
  initialized_ = true;
  if (devices.size() != 0) return;
#if TARGET_OS_IPHONE
  // iOS exposes only the system default device.
  id<MTLDevice> d = MTLCreateSystemDefaultDevice();
  devices.push_back([d retain]);
  queues.push_back([[d newCommandQueue] retain]);
#else
  NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
  for (size_t i = 0; i < devs.count; ++i) {
    id<MTLDevice> d = [devs objectAtIndex:i];
    devices.push_back([d retain]);
    queues.push_back([[d newCommandQueue] retain]);
    LOG(INFO) << "Initializing Metal device " << i
              << ", name=" << d.name;
    warp_size.push_back(GetWarpSize(d));
  }
#endif
}

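// Record the active device id in the thread-local context; later calls
// on this thread resolve against it.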
void MetalWorkspace::SetDevice(TVMContext ctx) {
  MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}

void* MetalWorkspace::AllocDataSpace(
    TVMContext ctx, size_t size, size_t alignment) {
  this->Init();
  id<MTLDevice> dev = GetDevice(ctx);
  // Allocate the buffer in GPU-only (private) storage mode.
  id<MTLBuffer> buf =
      [dev newBufferWithLength:size
                       options:MTLResourceStorageModePrivate];
  CHECK(buf != nil);
  return (__bridge void*)([buf retain]);
}

void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // Release the buffer; balances the retain in AllocDataSpace.
  CFRelease(ptr);
}

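// Copy between host and device memory. Handles three cases:
// device->device (blit on one device), device->host, and host->device,
// staging through a shared temporary buffer whenever the Metal buffer
// is not host-visible.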
void MetalWorkspace::CopyDataFromTo(const void* from,
                                    size_t from_offset,
                                    void* to,
                                    size_t to_offset,
                                    size_t size,
                                    TVMContext ctx_from,
                                    TVMContext ctx_to,
                                    TVMStreamHandle stream) {
  this->Init();
  CHECK(stream == nullptr);
  TVMContext ctx = ctx_from;
  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];
  int from_dev_type = static_cast<int>(ctx_from.device_type);
  int to_dev_type = static_cast<int>(ctx_to.device_type);

  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
    CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
        << "Metal disallow cross device copy.";
    id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
    [encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
             sourceOffset:from_offset
             toBuffer:(__bridge id<MTLBuffer>)(to)
             destinationOffset:to_offset
             size:size];
    [encoder endEncoding];
    [cb commit];
  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
    // Blit into a host-visible staging buffer first, then memcpy to the host.
    id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
    if (from_buf.storageMode != MTLStorageModeShared) {
      id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_from, size);
      id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
      [encoder copyFromBuffer:from_buf
               sourceOffset:from_offset
               toBuffer:temp
               destinationOffset:0
               size:size];
      [encoder endEncoding];
      [cb commit];
      [cb waitUntilCompleted];
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([temp contents]),
             size);
    } else {
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([from_buf contents]) + from_offset,
             size);
    }
  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
    id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
    if (to_buf.storageMode != MTLStorageModeShared) {
      id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_to, size);
      memcpy([temp contents],
             static_cast<const char*>(from) + from_offset,
             size);
      id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
      [encoder copyFromBuffer:temp
               sourceOffset:0
               toBuffer:to_buf
               destinationOffset:to_offset
               size:size];
      [encoder endEncoding];
      [cb commit];
      [cb waitUntilCompleted];
    } else {
      memcpy(static_cast<char*>([to_buf contents]) + to_offset,
             static_cast<const char*>(from) + from_offset,
             size);
    }
  } else {
    LOG(FATAL) << "Expect copy from/to Metal or between Metal"
               << ", from=" << from_dev_type
               << ", to=" << to_dev_type;
  }
}

void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  // commit an empty command buffer and wait until it completes.
  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];
  [cb commit];
  [cb waitUntilCompleted];
}

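// Temporary workspaces come from a per-thread WorkspacePool rather than
// raw Metal allocations, so repeated alloc/free cycles stay cheap.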
void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
  return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}

void MetalWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  MetalThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

MetalThreadEntry::~MetalThreadEntry() {
  for (auto x : temp_buffer_) {
    if (x != nil) [x release];
  }
}

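// Return a host-visible staging buffer for ctx's device, growing (but
// never shrinking) the cached per-thread buffer to at least size bytes.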
id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
  if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
    temp_buffer_.resize(ctx.device_id + 1, nil);
  }
  if (temp_buffer_[ctx.device_id] == nil ||
      temp_buffer_[ctx.device_id].length < size) {
    id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
    if (temp_buffer_[ctx.device_id] != nil) {
      [temp_buffer_[ctx.device_id] release];
    }
    // newBufferWithLength takes an MTLResourceOptions value, hence
    // MTLResourceStorageModeShared (host-visible staging memory).
    temp_buffer_[ctx.device_id] =
        [[dev newBufferWithLength:size
                          options:MTLResourceStorageModeShared] retain];
  }
  return temp_buffer_[ctx.device_id];
}

typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;

MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
  return MetalThreadStore::Get();
}

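// Hedged usage sketch (an assumption for illustration, not part of this
// file): a client can fetch this DeviceAPI through the runtime registry
// along these lines:
//
//   const PackedFunc* f = Registry::Get("device_api.metal");
//   void* handle = (*f)();  // returns MetalWorkspace::Global().get()
//   DeviceAPI* api = static_cast<DeviceAPI*>(handle);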
TVM_REGISTER_GLOBAL("device_api.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* ptr = MetalWorkspace::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace metal
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_METAL_RUNTIME