/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_module.cc
 */
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <algorithm>
#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "metal_module.h"
#include "metal_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

// Metal runtime module.
//
// Holds the device code in data_ (Metal shader source when fmt_ == "metal",
// otherwise a precompiled library blob) together with the function table.
// Device state is built lazily: each Metal device gets its own MTLLibrary and
// a cache of compute pipeline states in finfo_, and every update to that table
// happens under mutex_.
//
// Memory management: this file sends retain/release explicitly (manual
// reference counting). Objects returned by alloc/new* are owned (+1) by the
// caller, so each one is released exactly once.
class MetalModuleNode final :public runtime::ModuleNode {
 public:
  explicit MetalModuleNode(std::string data,
                           std::string fmt,
                           std::unordered_map<std::string, FunctionInfo> fmap,
                           std::string source)
      : data_(std::move(data)), fmt_(std::move(fmt)),
        fmap_(std::move(fmap)), source_(std::move(source)) {
  }
  const char* type_key() const final {
    return "metal";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_)
        << "Can only save to format=" << fmt_;
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
    SaveBinaryToFile(file_name, data_);
  }

  void SaveToBinary(dmlc::Stream* stream) final {
    stream->Write(fmt_);
    stream->Write(fmap_);
    stream->Write(data_);
  }
  std::string GetSource(const std::string& format) final {
    if (format == fmt_) return data_;
    if (source_.length() != 0) {
      return source_;
    } else if (fmt_ == "metal") {
      return data_;
    } else {
      return "";
    }
  }
  // Returns the compute pipeline state of func_name on device device_id.
  // On first use, this builds the device's library (from source or from the
  // binary blob) and the pipeline, then caches the pipeline. The returned
  // object is owned by this module; callers must not release it. It stays
  // valid until the module is destroyed.
  id<MTLComputePipelineState> GetPipelineState(
      size_t device_id, const std::string& func_name) {
    metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
    CHECK_LT(device_id, w->devices.size());
    // start lock scope.
    std::lock_guard<std::mutex> lock(mutex_);
    if (finfo_.size() <= device_id) {
      finfo_.resize(device_id + 1);
    }
    DeviceEntry& e = finfo_[device_id];
    auto it = e.smap.find(func_name);
    if (it != e.smap.end()) return it->second;
    // compile
    NSError* err_msg = nil;
    if (e.lib == nil) {
      if (fmt_ == "metal") {
        MTLCompileOptions* opts = [[MTLCompileOptions alloc] init];
        // Use the Metal 1.2 for now.
        opts.languageVersion = MTLLanguageVersion1_2;
        opts.fastMathEnabled = YES;
        // newLibraryWithSource: returns an owned (+1) library; the
        // DeviceEntry destructor releases it.
        e.lib = [
            w->devices[device_id]
             newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
             options:opts
             error:&err_msg];
        [opts release];
        if (e.lib == nil) {
          LOG(FATAL) << "Fail to compile metal lib:"
                     << ErrorDescription(err_msg);
        }
        if (err_msg != nil) {
          // The library was built, but the compiler still reported diagnostics.
          LOG(INFO) << "Warning: "
                    << ErrorDescription(err_msg);
        }
      } else {
        // Build from library. The default destructor makes dispatch copy the
        // bytes, so the data object does not alias data_'s buffer and no
        // destructor queue is needed.
        dispatch_data_t data = dispatch_data_create(
            data_.c_str(), data_.length(), nullptr,
            DISPATCH_DATA_DESTRUCTOR_DEFAULT);
        e.lib = [
            w->devices[device_id]
             newLibraryWithData:data
             error:&err_msg];
        dispatch_release(data);
        // Check the return value, not the error pointer (Cocoa convention).
        if (e.lib == nil) {
          LOG(FATAL) << "Fail to compile metal lib:"
                     << ErrorDescription(err_msg);
        }
        if (err_msg != nil) {
          LOG(INFO) << "Warning: "
                    << ErrorDescription(err_msg);
        }
      }
    }
    // newFunctionWithName: returns an owned (+1) function. It is only needed
    // to build the pipeline state, so it is released right after.
    id<MTLFunction> f = [
        e.lib
         newFunctionWithName:
           [NSString stringWithUTF8String:func_name.c_str()]];
    CHECK(f != nil) << "cannot find function " << func_name;
    // Clear any library-compile diagnostic so a failure below does not report
    // a stale message.
    err_msg = nil;
    id<MTLComputePipelineState> state =
        [w->devices[device_id]
          newComputePipelineStateWithFunction:f
          error:&err_msg];
    [f release];
    CHECK(state != nil)
        << "cannot get state:" << " for function " << func_name
        << ErrorDescription(err_msg);
    // The state.threadExecutionWidth can change dynamically according
    // to the resource constraint in kernel, so it is not strictly hold
    // Turn of warp aware optimziation for now.
    // CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
    // The cache takes over the +1 reference from newComputePipelineState...;
    // the DeviceEntry destructor releases it.
    e.smap[func_name] = state;
    return state;
  }

 private:
  // Per-device cache. It owns one reference to lib and one to every pipeline
  // state in smap: copying retains, moving transfers, destruction releases.
  // Explicit copy/move members are required because std::vector relocates
  // entries when finfo_ grows. Implicit copies would share the raw pointers,
  // and destroying the old elements would over-release them.
  struct DeviceEntry {
    // library
    id<MTLLibrary> lib = nil;
    // state cache;
    std::unordered_map<std::string, id<MTLComputePipelineState> > smap;

    DeviceEntry() = default;

    DeviceEntry(const DeviceEntry& other) : lib(other.lib), smap(other.smap) {
      [lib retain];
      for (auto& kv : smap) {
        [kv.second retain];
      }
    }

    DeviceEntry(DeviceEntry&& other) noexcept
        : lib(other.lib), smap(std::move(other.smap)) {
      other.lib = nil;
      other.smap.clear();
    }

    // Copy-and-swap: handles both copy and move assignment.
    DeviceEntry& operator=(DeviceEntry other) noexcept {
      std::swap(lib, other.lib);
      smap.swap(other.smap);
      return *this;
    }

    ~DeviceEntry() {
      // Messaging nil is a no-op, so a never-built entry is safe.
      [lib release];
      for (auto& kv : smap) {
        [kv.second release];
      }
    }
  };

  // Describes err for log messages. Metal calls may fail without setting an
  // NSError, and streaming a null C string into a log stream is undefined.
  static std::string ErrorDescription(NSError* err) {
    if (err == nil) return "(no error information)";
    const char* msg = [[err localizedDescription] UTF8String];
    return msg != nullptr ? std::string(msg) : std::string();
  }

  // the binary data
  std::string data_;
  // The format
  std::string fmt_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The source
  std::string source_;
  // Per-device library and pipeline cache, indexed by device id.
  std::vector<DeviceEntry> finfo_;
  // internal mutex when updating the module
  std::mutex mutex_;
};

// Adapter that exposes one kernel of a MetalModuleNode as a packed function.
// Each call binds the buffer arguments and the packed scalar arguments, then
// encodes and commits a single compute dispatch on the calling thread's
// current Metal device.
class MetalWrappedFunc {
 public:
  // Initialize the Metal function and warm the pipeline cache for the calling
  // thread's current device.
  void Init(MetalModuleNode* m,
            std::shared_ptr<ModuleNode> sptr,
            const std::string& func_name,
            size_t num_buffer_args,
            size_t num_pack_args,
            const std::vector<std::string>& thread_axis_tags) {
    w_ = metal::MetalWorkspace::Global().get();
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    num_buffer_args_ = num_buffer_args;
    num_pack_args_ = num_pack_args;
    std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
    metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
    int dev_id = t->context.device_id;
    // scache_ is a fixed-size array; reject ids it cannot hold (negative ids
    // wrap to huge values and fail too).
    CHECK_LT(static_cast<size_t>(dev_id), scache_.size())
        << "Metal device id exceeds kMetalMaxNumDevice";
    // Pipeline construction may create autoreleased Foundation temporaries.
    // Drain them here, since this thread may not have its own pool.
    @autoreleasepool {
      scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
    }
  }
  // invoke the function with void arguments
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
                  const ArgUnion* pack_args) const {
    // The command buffer and encoder are autoreleased objects. Without a pool
    // on this (possibly non-main) thread, they would never be freed.
    @autoreleasepool {
      metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
      int device_id = t->context.device_id;
      CHECK_LT(static_cast<size_t>(device_id), scache_.size())
          << "Metal device id exceeds kMetalMaxNumDevice";
      // Lazily fetch the pipeline for devices not seen at Init time.
      if (scache_[device_id] == nil) {
        scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
      }
      ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
      id<MTLCommandQueue> queue = w_->GetCommandQueue(t->context);
      id<MTLCommandBuffer> cb = [queue commandBuffer];
      id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
      [encoder setComputePipelineState:scache_[device_id]];
      // Buffer arguments occupy indices [0, num_buffer_args_).
      for (size_t i = 0; i < num_buffer_args_; ++i) {
        void* buf = args[static_cast<int>(i)];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(buf) offset:0 atIndex:i];
      }
      // All non-buffer arguments go in one inline byte block right after the
      // buffers.
      if (num_pack_args_ != 0) {
        [encoder setBytes:pack_args
                   length:num_pack_args_ * sizeof(ArgUnion)
                  atIndex:num_buffer_args_];
      }
      // launch
      MTLSize dimGrid = MTLSizeMake(
          wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
      MTLSize dimBlock = MTLSizeMake(
          wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
      [encoder dispatchThreadgroups: dimGrid
               threadsPerThreadgroup: dimBlock];
      [encoder endEncoding];
      [cb commit];
    }
  }

 private:
  // Reference to global workspace.
  metal::MetalWorkspace* w_;
  // internal module
  MetalModuleNode* m_;
  // Keeps the module (and thus its cached pipelines) alive while this
  // function exists.
  std::shared_ptr<ModuleNode> sptr_;
  // The name of the function.
  std::string func_name_;
  // Number of buffer arguments
  size_t num_buffer_args_;
  // number of packed arguments.
  size_t num_pack_args_;
  // Device state cache per device.
  // mark as mutable, to enable lazy initialization
  mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
  // thread axis configuration
  ThreadAxisConfig thread_axis_cfg_;
};

// Looks up `name` in the function table and wraps it as a packed function.
// Returns an empty PackedFunc when `name` is not in fmap_. The wrapper
// captures sptr_to_self, so the module outlives the returned function.
PackedFunc MetalModuleNode::GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device function do not have main";
  auto entry = fmap_.find(name);
  if (entry == fmap_.end()) {
    return PackedFunc();
  }
  const FunctionInfo& finfo = entry->second;
  // Leading arguments are buffers; the remainder are packed scalars.
  const size_t n_buffers = NumBufferArgs(finfo.arg_types);
  const size_t n_packed = finfo.arg_types.size() - n_buffers;
  MetalWrappedFunc wrapped;
  wrapped.Init(this, sptr_to_self, name,
               n_buffers, n_packed,
               finfo.thread_axis_tags);
  return PackFuncNonBufferArg(wrapped, finfo.arg_types);
}

// Creates a Metal module from device code `data` in format `fmt`, its
// function table `fmap`, and optional `source` text. The global Metal
// workspace is initialized before the module node is built.
Module MetalModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string source) {
  metal::MetalWorkspace::Global()->Init();
  // The parameters are owned copies, so hand them to the node instead of
  // copying the (possibly large) code blob again.
  std::shared_ptr<MetalModuleNode> n =
      std::make_shared<MetalModuleNode>(std::move(data), std::move(fmt),
                                        std::move(fmap), std::move(source));
  return Module(n);
}

// Loads a Metal module from `file_name` and its companion metadata file
// (path from GetMetaFilePath). `format` is resolved through GetFileFormat.
Module MetalModuleLoadFile(const std::string& file_name,
                           const std::string& format) {
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  // Locals are not used afterwards; move them to avoid copying the blob.
  return MetalModuleCreate(std::move(data), std::move(fmt),
                           std::move(fmap), "");
}

// Deserializes a Metal module from a dmlc::Stream passed as `strm`.
// Fields are read in the order SaveToBinary writes them: fmt, fmap, data.
Module MetalModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  stream->Read(&fmt);
  stream->Read(&fmap);
  stream->Read(&data);
  // Locals are not used afterwards; move them to avoid copying the blob.
  return MetalModuleCreate(std::move(data), std::move(fmt),
                           std::move(fmap), "");
}

// Packed entry point: module.loadfile_metal(file_name, format) -> Module.
TVM_REGISTER_GLOBAL("module.loadfile_metal")
.set_body([](TVMArgs call_args, TVMRetValue* ret) {
    std::string file_name = call_args[0];
    std::string format = call_args[1];
    *ret = MetalModuleLoadFile(file_name, format);
  });

// Packed entry point: module.loadbinary_metal(stream_handle) -> Module.
TVM_REGISTER_GLOBAL("module.loadbinary_metal")
.set_body([](TVMArgs call_args, TVMRetValue* ret) {
    void* stream_handle = call_args[0];
    *ret = MetalModuleLoadBinary(stream_handle);
  });
}  // namespace runtime
}  // namespace tvm