/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file metal_device_api.mm
 */
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include "metal_common.h"

namespace tvm {
namespace runtime {
namespace metal {

const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
  // Function-local static: the singleton workspace is constructed on first
  // use and shared by every caller for the lifetime of the process.
  static std::shared_ptr<MetalWorkspace> instance =
      std::make_shared<MetalWorkspace>();
  return instance;
}

/*!
 * \brief Query a device attribute of the given Metal device.
 *
 * Attributes this backend does not report (shared memory size, compute
 * version, device name, clock rate, processor count, thread dimensions,
 * gcn arch) return without setting *rv.
 */
void MetalWorkspace::GetAttr(
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
  this->Init();
  size_t index = static_cast<size_t>(ctx.device_id);
  if (kind == kExist) {
    // Existence check must not trip the CHECK below for out-of-range ids.
    *rv = int(index < devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
    case kMaxThreadsPerBlock: {
      *rv = static_cast<int>(
          [devices[index] maxThreadsPerThreadgroup].width);
      break;
    }
    case kWarpSize: {
      // Set warp size to be 1 for safety reason.
      // threadExecutionWidth can vary per kernel, see GetWarpSize.
      *rv = 1;
      break;
    }
    case kMaxSharedMemoryPerBlock: return;
    case kComputeVersion: return;
    case kDeviceName: return;
    case kMaxClockRate: return;
    case kMultiProcessorCount: return;
    case kMaxThreadDimensions: return;
    case kExist: break;  // already handled above; keeps the switch exhaustive
    case kGcnArch: return;
  }
}

// Minimal Metal shader source compiled at init time purely to obtain a
// pipeline state whose threadExecutionWidth can be read (see GetWarpSize).
// It is never dispatched.
static const char* kDummyKernel = R"A0B0(
using namespace metal;
// Simple copy kernel
// Just to get threadExecutionWidth from current Metal API.
kernel void CopyKernel(
  device float* dst [[buffer(0)]],
  device float* src [[buffer(1)]],
  ushort2 gid[[thread_position_in_grid]]) {
  dst[gid.x] = src[gid.x];
}
)A0B0";

// Hack to get Warp size from device.
// Note that in Metal
// state.threadExecutionWidth can vary per kernel
// maybe due to resource constraint.
// so state.threadExecutionWidth can be smaller than warp size
// For safety, turn off warp-aware optimization for now
// But we keep this code.
int GetWarpSize(id<MTLDevice> dev) {
  NSError* error_msg = nil;
  // Compile the dummy kernel; any compute pipeline state exposes
  // threadExecutionWidth for its device.
  id<MTLLibrary> lib =
      [dev
        newLibraryWithSource:
          [NSString stringWithUTF8String:kDummyKernel]
        options:nil
        error:&error_msg];
  CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
  id<MTLFunction> f =
      [lib
        newFunctionWithName:
          [NSString stringWithUTF8String:"CopyKernel"]];
  CHECK(f != nil);
  id<MTLComputePipelineState> state =
      [dev
        newComputePipelineStateWithFunction:f
        error:&error_msg];
  CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
  int warp_size = static_cast<int>(state.threadExecutionWidth);
  // The new* family returns owned (+1) objects and this file is compiled
  // without ARC (see the explicit retain/release elsewhere), so release
  // them here to avoid leaking once per device at init.
  [state release];
  [f release];
  [lib release];
  return warp_size;
}

MetalWorkspace::~MetalWorkspace() {
  // Balance the explicit retains taken in Init() (this file is non-ARC).
  for (auto dev : devices) {
    [dev release];
  }
  for (auto queue : queues) {
    [queue release];
  }
}

/*!
 * \brief Lazily discover Metal devices and create one command queue each.
 *        Safe to call from multiple threads; only the first call does work.
 */
void MetalWorkspace::Init() {
  // Fast path, double-checked against the flag set under the mutex below.
  if (initialized_) return;
  std::lock_guard<std::mutex> lock(this->mutex);
  if (initialized_) return;
  initialized_ = true;
  if (devices.size() != 0) return;
#if TARGET_OS_IPHONE
  // iOS exposes a single system default device.
  id<MTLDevice> d = MTLCreateSystemDefaultDevice();
  devices.push_back([d retain]);
  queues.push_back([[d newCommandQueue] retain]);
#else
  // macOS may have several GPUs; register each with its own queue.
  NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
  for (size_t i = 0; i < devs.count; ++i) {
    id<MTLDevice> d = [devs objectAtIndex:i];
    devices.push_back([d retain]);
    queues.push_back([[d newCommandQueue] retain]);
    // Fixed typo in the log message: "Intializing" -> "Initializing".
    LOG(INFO) << "Initializing Metal device " << i
              << ", name=" << [d.name UTF8String];
    warp_size.push_back(GetWarpSize(d));
  }
#endif
}

void MetalWorkspace::SetDevice(TVMContext ctx) {
  // Remember the active device id in this thread's local state.
  MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}

/*!
 * \brief Allocate nbytes of device-private GPU memory.
 * \return An owned MTLBuffer reference as an opaque pointer;
 *         released via FreeDataSpace.
 */
void* MetalWorkspace::AllocDataSpace(
    TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
  this->Init();
  id<MTLDevice> dev = GetDevice(ctx);
  // GPU memory only
  MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
  /*
  #if TARGET_OS_IPHONE
  storage_mode = MTLResourceStorageModeShared;
  #else
  storage_mode = MTLResourceStorageModeManaged;
  #endif
  */
  id<MTLBuffer> buf = [dev newBufferWithLength:nbytes
                                       options:storage_mode];
  CHECK(buf != nil);
  // Hand out a retained reference; FreeDataSpace balances it with CFRelease.
  return (__bridge void*)([buf retain]);
}

void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // Drop the reference handed out by AllocDataSpace.
  CFRelease(ptr);
}

/*!
 * \brief Copy a byte range between CPU and/or Metal buffers.
 *
 * GPU->GPU copies are committed asynchronously (StreamSync drains them);
 * copies that end on the CPU, or stage through a temp buffer, block until
 * the blit completes.
 */
void MetalWorkspace::CopyDataFromTo(const void* from,
                                    size_t from_offset,
                                    void* to,
                                    size_t to_offset,
                                    size_t size,
                                    TVMContext ctx_from,
                                    TVMContext ctx_to,
                                    TVMType type_hint,
                                    TVMStreamHandle stream) {
  this->Init();
  CHECK(stream == nullptr);
  // The command queue must belong to the GPU endpoint of the copy.
  TVMContext gpu_ctx = ctx_from;
  if (ctx_from.device_type == kDLCPU) gpu_ctx = ctx_to;
  id<MTLCommandQueue> queue = GetCommandQueue(gpu_ctx);
  id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
  int from_dev_type = static_cast<int>(ctx_from.device_type);
  int to_dev_type = static_cast<int>(ctx_to.device_type);

  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
    // GPU -> GPU: a single blit on the same device.
    CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
        << "Metal disallow cross device copy.";
    id<MTLBlitCommandEncoder> blit = [cmd_buf blitCommandEncoder];
    [blit copyFromBuffer:(__bridge id<MTLBuffer>)(from)
             sourceOffset:from_offset
             toBuffer:(__bridge id<MTLBuffer>)(to)
             destinationOffset:to_offset
             size:size];
    [blit endEncoding];
    [cmd_buf commit];
  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
    id<MTLBuffer> src_buf = (__bridge id<MTLBuffer>)(from);
    if (src_buf.storageMode != MTLStorageModeShared) {
      // Non-shared storage is not CPU-visible: blit into a shared staging
      // buffer, wait for completion, then memcpy out of it.
      id<MTLBuffer> staging = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_from, size);
      id<MTLBlitCommandEncoder> blit = [cmd_buf blitCommandEncoder];
      [blit copyFromBuffer:src_buf
               sourceOffset:from_offset
               toBuffer:staging
               destinationOffset:0
               size:size];
      [blit endEncoding];
      [cmd_buf commit];
      [cmd_buf waitUntilCompleted];
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([staging contents]),
             size);
    } else {
      // Shared storage: the CPU can read the contents pointer directly.
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([src_buf contents]) + from_offset,
             size);
    }
  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
    id<MTLBuffer> dst_buf = (__bridge id<MTLBuffer>)(to);
    if (dst_buf.storageMode != MTLStorageModeShared) {
      // Stage through a shared buffer, then blit into non-shared storage.
      id<MTLBuffer> staging = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_to, size);
      memcpy([staging contents],
             static_cast<const char*>(from) + from_offset,
             size);
      id<MTLBlitCommandEncoder> blit = [cmd_buf blitCommandEncoder];
      [blit copyFromBuffer:staging
               sourceOffset:0
               toBuffer:dst_buf
               destinationOffset:to_offset
               size:size];
      [blit endEncoding];
      [cmd_buf commit];
      // Must wait: the staging buffer may be reused by the next copy.
      [cmd_buf waitUntilCompleted];
    } else {
      // Shared storage: write straight into the buffer from the CPU.
      memcpy(static_cast<char*>([dst_buf contents]) + to_offset,
             static_cast<const char*>(from) + from_offset,
             size);
    }
  } else {
    LOG(FATAL) << "Expect copy from/to Metal or between Metal"
               << ", from=" << from_dev_type
               << ", to=" << to_dev_type;
  }
}

void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  // An empty command buffer enqueued after all prior work acts as a fence:
  // waiting on it drains everything previously committed to the queue.
  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
  id<MTLCommandBuffer> fence = [queue commandBuffer];
  [fence commit];
  [fence waitUntilCompleted];
}

void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
                                     size_t size,
                                     TVMType type_hint) {
  // Scratch allocations are served from the thread-local workspace pool.
  return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}

void MetalWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  // Return the scratch allocation to the thread-local workspace pool.
  MetalThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

MetalThreadEntry::~MetalThreadEntry() {
  // Release any per-device staging buffers created by GetTempBuffer.
  for (auto buf : temp_buffer_) {
    if (buf != nil) [buf release];
  }
}

/*!
 * \brief Get (or grow) a thread-local shared staging buffer of at least
 *        `size` bytes for the given device, used for CPU<->GPU copies.
 */
id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
  // Grow the per-device table on demand.
  if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
    temp_buffer_.resize(ctx.device_id + 1, nil);
  }
  if (temp_buffer_[ctx.device_id] == nil ||
      temp_buffer_[ctx.device_id].length < size) {
    id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
    if (temp_buffer_[ctx.device_id] != nil) {
      [temp_buffer_[ctx.device_id] release];
    }
    // Fix 1: the options: parameter takes MTLResourceOptions; the previous
    // MTLStorageModeShared is a different enum that only worked because
    // both constants happen to be 0.
    // Fix 2: newBufferWithLength: already returns an owned (+1) reference
    // under MRC, so the former extra retain leaked one reference per
    // (re)allocation; the destructor's single release now balances it.
    temp_buffer_[ctx.device_id] =
        [dev newBufferWithLength:size
                         options:MTLResourceStorageModeShared];
  }
  return temp_buffer_[ctx.device_id];
}

// Per-thread storage for Metal state (workspace pool, staging buffers,
// active device context).
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;

// Return this thread's MetalThreadEntry, creating it on first access.
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
  return MetalThreadStore::Get();
}

TVM_REGISTER_GLOBAL("device_api.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Expose the singleton workspace as an opaque DeviceAPI handle.
    DeviceAPI* ptr = MetalWorkspace::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace metal
}  // namespace runtime
}  // namespace tvm