/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_device_api.mm
 */
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include "metal_common.h"

namespace tvm {
namespace runtime {
namespace metal {

/*!
 * \brief Access the process-wide Metal workspace.
 * \return Reference to the shared pointer owning the single MetalWorkspace,
 *         created the first time this function is called.
 */
const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
  // Function-local static: constructed once, on first use.
  static auto workspace = std::make_shared<MetalWorkspace>();
  return workspace;
}

/*!
 * \brief Query an attribute of a Metal device.
 * \param ctx  Context whose device_id selects the device.
 * \param kind Attribute being queried.
 * \param rv   Receives the answer; left untouched for attributes that are
 *             not reported by this backend.
 */
void MetalWorkspace::GetAttr(
    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
  this->Init();
  size_t index = static_cast<size_t>(ctx.device_id);
  // Existence is the only query that is valid for an out-of-range id.
  if (kind == kExist) {
    const bool present = index < devices.size();
    *rv = static_cast<int>(present);
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
    case kMaxThreadsPerBlock: {
      MTLSize limit = [devices[ctx.device_id] maxThreadsPerThreadgroup];
      *rv = static_cast<int>(limit.width);
      break;
    }
    case kWarpSize: {
      // Report a warp size of 1 for safety reasons.
      *rv = 1;
      break;
    }
    // Attributes below are not reported: rv is left unchanged.
    case kMaxSharedMemoryPerBlock:
    case kComputeVersion:
    case kDeviceName:
    case kMaxClockRate:
    case kMultiProcessorCount:
    case kMaxThreadDimensions:
    case kExist:
      break;
  }
}

/*!
 * \brief Metal shader source for a trivial element-wise copy kernel.
 *
 * Compiled at runtime by GetWarpSize, which builds a compute pipeline for
 * "CopyKernel" and reads the pipeline's threadExecutionWidth.
 */
static const char* kDummyKernel = R"A0B0(
using namespace metal;
// Simple copy kernel
// Just to get threadExecutionWidth from current Metal API.
kernel void CopyKernel(
  device float* dst [[buffer(0)]],
  device float* src [[buffer(1)]],
  ushort2 gid[[thread_position_in_grid]]) {
  dst[gid.x] = src[gid.x];
}
)A0B0";

/*!
 * \brief Probe a device for its thread execution width ("warp size").
 *
 * This is a workaround: the dummy copy kernel is compiled into a compute
 * pipeline and the pipeline's threadExecutionWidth is returned. In Metal,
 * threadExecutionWidth can vary per kernel (possibly because of resource
 * constraints), so the value may be smaller than the real warp size. For
 * safety, warp-aware optimization is turned off for now, but the probe is
 * kept.
 *
 * The library, function and pipeline state are created with `new...`
 * methods and are therefore owned by this function; they are released
 * before returning so the probe does not leak them.
 *
 * \param dev The Metal device to probe.
 * \return threadExecutionWidth of the dummy kernel's pipeline state.
 */
int GetWarpSize(id<MTLDevice> dev) {
  NSError* error_msg = nil;
  id<MTLLibrary> lib =
      [dev newLibraryWithSource:[NSString stringWithUTF8String:kDummyKernel]
                        options:nil
                          error:&error_msg];
  CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
  id<MTLFunction> f = [lib newFunctionWithName:@"CopyKernel"];
  CHECK(f!= nil);
  id<MTLComputePipelineState> state =
      [dev newComputePipelineStateWithFunction:f
                                         error:&error_msg];
  CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
  // Read the value before giving up ownership of the pipeline state.
  int width = static_cast<int>(state.threadExecutionWidth);
  // Balance the +1 references returned by the three `new...` calls above.
  [state release];
  [f release];
  [lib release];
  return width;
}

/*!
 * \brief Release every device and command queue held by the workspace.
 */
MetalWorkspace::~MetalWorkspace() {
  for (size_t i = 0; i < devices.size(); ++i) {
    [devices[i] release];
  }
  for (size_t i = 0; i < queues.size(); ++i) {
    [queues[i] release];
  }
}

/*!
 * \brief Enumerate Metal devices and create one command queue per device.
 *
 * Safe to call repeatedly; only the first call does any work. Each entry
 * pushed into `devices` and `queues` carries exactly one owned reference,
 * matching the single `release` per entry in the destructor:
 *  - `newCommandQueue` and `MTLCreateSystemDefaultDevice` already return an
 *    owned (+1) object, so they are stored without an extra retain;
 *  - devices obtained from the `MTLCopyAllDevices` array are borrowed, so
 *    they are retained, and the array itself is released afterwards.
 *
 * `initialized_` is set only after the vectors are populated, so a caller
 * that takes the unlocked fast path does not observe an empty device list.
 * NOTE(review): `initialized_` is declared outside this file; if it is a
 * plain bool the unlocked read is still formally a data race - confirm.
 */
void MetalWorkspace::Init() {
  if (initialized_) return;
  std::lock_guard<std::mutex> lock(this->mutex);
  if (initialized_) return;
  if (devices.size() == 0) {
#if TARGET_OS_IPHONE
    // on iPhone: only the system default device is used.
    id<MTLDevice> d = MTLCreateSystemDefaultDevice();
    // Skip registration when no Metal device exists, so that kExist
    // reports the device as absent instead of exposing a nil entry.
    if (d != nil) {
      devices.push_back(d);
      queues.push_back([d newCommandQueue]);
    }
#else
    NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
    for (size_t i = 0; i < devs.count; ++i) {
      id<MTLDevice> d = [devs objectAtIndex:i];
      devices.push_back([d retain]);
      queues.push_back([d newCommandQueue]);
      LOG(INFO) << "Intializing Metal device " << i
                <<  ", name=" << [d.name UTF8String];
      warp_size.push_back(GetWarpSize(d));
    }
    // The array was returned with ownership; the devices we keep were
    // retained individually above.
    [devs release];
#endif
  }
  initialized_ = true;
}

/*!
 * \brief Record the device id of ctx in the calling thread's entry.
 * \param ctx Context whose device_id is stored.
 */
void MetalWorkspace::SetDevice(TVMContext ctx) {
  MetalThreadEntry* entry = MetalThreadEntry::ThreadLocal();
  entry->context.device_id = ctx.device_id;
}

/*!
 * \brief Allocate a private-storage Metal buffer on the given device.
 *
 * `newBufferWithLength:options:` returns an owned (+1) buffer. That single
 * reference is handed to the caller as an opaque pointer and is balanced by
 * the one CFRelease in FreeDataSpace; taking an additional retain here
 * would leave every buffer alive after it is freed.
 *
 * \param ctx       Context selecting the device.
 * \param nbytes    Size of the buffer in bytes.
 * \param alignment Not used by this implementation.
 * \param type_hint Not used by this implementation.
 * \return Opaque pointer to the id<MTLBuffer>; release with FreeDataSpace.
 */
void* MetalWorkspace::AllocDataSpace(
    TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
  this->Init();
  id<MTLDevice> dev = GetDevice(ctx);
  // GPU memory only
  MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
  /*
  #if TARGET_OS_IPHONE
  storage_mode = MTLResourceStorageModeShared;
  #else
  storage_mode = MTLResourceStorageModeManaged;
  #endif
  */
  id<MTLBuffer> buf = [dev newBufferWithLength:nbytes
                                       options:storage_mode];
  CHECK(buf != nil);
  // Transfer the +1 reference from `new...` to the caller unchanged.
  return (__bridge void*)(buf);
}

/*!
 * \brief Release a buffer handle.
 * \param ctx Not used by this implementation.
 * \param ptr Opaque buffer pointer; a null pointer is ignored.
 */
void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // CFRelease crashes when passed NULL, so treat null as a no-op.
  if (ptr == nullptr) return;
  // Drop the reference held on behalf of the caller.
  CFRelease(ptr);
}

/*!
 * \brief Copy bytes between Metal buffers and/or host memory.
 *
 * Three directions are handled: Metal -> Metal (same device only),
 * Metal -> CPU and CPU -> Metal. Any other combination is a fatal error.
 * When the Metal-side buffer is not in shared storage mode, host transfers
 * go through the calling thread's temporary buffer and a blit command.
 *
 * \param from        Source: host pointer or opaque id<MTLBuffer> pointer.
 * \param from_offset Byte offset into the source.
 * \param to          Destination: host pointer or opaque id<MTLBuffer> pointer.
 * \param to_offset   Byte offset into the destination.
 * \param size        Number of bytes to copy.
 * \param ctx_from    Context of the source.
 * \param ctx_to      Context of the destination.
 * \param type_hint   Not used by this implementation.
 * \param stream      Must be nullptr.
 */
void MetalWorkspace::CopyDataFromTo(const void* from,
                                    size_t from_offset,
                                    void* to,
                                    size_t to_offset,
                                    size_t size,
                                    TVMContext ctx_from,
                                    TVMContext ctx_to,
                                    TVMType type_hint,
                                    TVMStreamHandle stream) {
  this->Init();
  CHECK(stream == nullptr);
  // Pick the context used to look up the command queue: the source
  // context, unless the source is the CPU, in which case the destination.
  TVMContext ctx = ctx_from;
  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];
  int from_dev_type = static_cast<int>(ctx_from.device_type);
  int to_dev_type = static_cast<int>(ctx_to.device_type);

  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
    // Device-to-device: encode a blit and commit it. This branch does not
    // wait for the command buffer to complete.
    CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
        << "Metal disallow cross device copy.";
    id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
    [encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
             sourceOffset:from_offset
             toBuffer:(__bridge id<MTLBuffer>)(to)
             destinationOffset:to_offset
             size:size];
    [encoder endEncoding];
    [cb commit];
  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
    // Device-to-host: blit into the thread-local temporary buffer first,
    // then memcpy from it into host memory.
    id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
    if (from_buf.storageMode != MTLStorageModeShared) {
      id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_from, size);
      id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
      [encoder copyFromBuffer:from_buf
               sourceOffset:from_offset
               toBuffer:temp
               destinationOffset:0
               size:size];
      [encoder endEncoding];
      [cb commit];
      // Wait for the blit to finish before reading the staging buffer.
      [cb waitUntilCompleted];
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([temp contents]),
             size);
    } else {
      // Shared storage: read the buffer contents directly.
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([from_buf contents]) + from_offset,
             size);
    }
  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
    // Host-to-device: fill the thread-local temporary buffer, then blit it
    // into the destination buffer.
    id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
    if (to_buf.storageMode != MTLStorageModeShared) {
      id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_to, size);
      memcpy([temp contents],
              static_cast<const char*>(from) + from_offset,
              size);
      id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
      [encoder copyFromBuffer:temp
               sourceOffset:0
               toBuffer:to_buf
               destinationOffset:to_offset
               size:size];
      [encoder endEncoding];
      [cb commit];
      // Wait for the blit to finish before returning to the caller.
      [cb waitUntilCompleted];
    } else {
      // Shared storage: write into the buffer contents directly.
      memcpy(static_cast<char*>([to_buf contents]) + to_offset,
             static_cast<const char*>(from) + from_offset,
             size);
    }
  } else {
    // Neither endpoint combination involves Metal in a supported way.
    LOG(FATAL) << "Expect copy from/to Metal or between Metal"
               << ", from=" << from_dev_type
               << ", to=" << to_dev_type;
  }
}

/*!
 * \brief Block until an empty command buffer on the device's queue completes.
 * \param ctx    Context selecting the command queue.
 * \param stream Must be nullptr.
 */
void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  // Submit an empty command buffer and wait for it to finish.
  id<MTLCommandBuffer> fence = [GetCommandQueue(ctx) commandBuffer];
  [fence commit];
  [fence waitUntilCompleted];
}

/*!
 * \brief Allocate temporary workspace from the calling thread's pool.
 * \param ctx       Context the workspace is requested for.
 * \param size      Requested size, forwarded to the pool.
 * \param type_hint Not used by this implementation.
 * \return Pointer returned by the thread-local pool.
 */
void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
                                     size_t size,
                                     TVMType type_hint) {
  MetalThreadEntry* entry = MetalThreadEntry::ThreadLocal();
  return entry->pool.AllocWorkspace(ctx, size);
}

/*!
 * \brief Hand a workspace back to the calling thread's pool.
 * \param ctx  Context the workspace belongs to.
 * \param data Pointer previously obtained from AllocWorkspace.
 */
void MetalWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
  MetalThreadEntry* entry = MetalThreadEntry::ThreadLocal();
  entry->pool.FreeWorkspace(ctx, data);
}

/*!
 * \brief Release every temporary buffer cached by this thread entry.
 */
MetalThreadEntry::~MetalThreadEntry() {
  for (size_t i = 0; i < temp_buffer_.size(); ++i) {
    // Slots for devices that never needed a buffer are nil; skip them.
    if (temp_buffer_[i] == nil) continue;
    [temp_buffer_[i] release];
  }
}

/*!
 * \brief Get this thread's shared-storage temporary buffer for a device.
 *
 * One buffer is cached per device id and reused while it is large enough;
 * otherwise the old buffer is released and a new one of `size` bytes is
 * created. `newBufferWithLength:options:` already returns an owned (+1)
 * buffer, so it is stored as-is: the cache holds exactly one reference,
 * matching the single release here and in the destructor.
 *
 * \param ctx  Context selecting the device (and cache slot).
 * \param size Minimum required length in bytes.
 * \return The cached buffer; ownership stays with this thread entry.
 */
id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
  // Grow the per-device cache on demand; new slots start out empty.
  if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
    temp_buffer_.resize(ctx.device_id + 1, nil);
  }
  if (temp_buffer_[ctx.device_id] == nil ||
      temp_buffer_[ctx.device_id].length < size) {
    // Look up the device before dropping the old buffer so that a failed
    // lookup does not leave a released pointer in the cache.
    id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
    if (temp_buffer_[ctx.device_id] != nil) {
      [temp_buffer_[ctx.device_id] release];
    }
    // `options:` takes MTLResourceOptions, hence the Resource-prefixed
    // constant rather than the MTLStorageMode one.
    temp_buffer_[ctx.device_id] =
        [dev newBufferWithLength:size
                         options:MTLResourceStorageModeShared];
  }
  return temp_buffer_[ctx.device_id];
}

/*! \brief Thread-local storage holding one MetalThreadEntry per thread. */
using MetalThreadStore = dmlc::ThreadLocalStore<MetalThreadEntry>;

/*!
 * \brief Get the MetalThreadEntry belonging to the calling thread.
 * \return Pointer obtained from the thread-local store.
 */
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
  MetalThreadEntry* entry = MetalThreadStore::Get();
  return entry;
}

// Register the Metal device API under "device_api.metal"; the registered
// function returns the global workspace as an opaque DeviceAPI pointer.
TVM_REGISTER_GLOBAL("device_api.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* api = MetalWorkspace::Global().get();
    *rv = static_cast<void*>(api);
  });

}  // namespace metal
}  // namespace runtime
}  // namespace tvm