#include "mps_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  DLTensor *buf = args[0];
  DLTensor *img = args[1];
  // copy to temp
  id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
  runtime::metal::MetalThreadEntry *rt =
      runtime::metal::MetalThreadEntry::ThreadLocal();
  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
  id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
  entry_ptr->metal_api->CopyDataFromTo(
      (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
      buf->ctx, buf->ctx, nullptr
  );

  MPSImageDescriptor *desc = [MPSImageDescriptor
      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
                                 width:buf->shape[2]
                                height:buf->shape[1]
                       featureChannels:buf->shape[3]];

  MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);

  [mpsimg writeBytes:[temp contents]
          dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
          imageIndex:0];

  img->data = (__bridge void *)mpsimg;

  [mpsimg readBytes:[temp contents]
         dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
         imageIndex:0];
  
  });

TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  DLTensor *img = args[0];
  DLTensor *buf = args[1];
  id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
  MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
  runtime::metal::MetalThreadEntry *rt =
      runtime::metal::MetalThreadEntry::ThreadLocal();
  id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);

  [mpsimg readBytes:[temp contents]
         dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
         imageIndex:0];

  entry_ptr->metal_api->CopyDataFromTo(
      (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
      buf->ctx, buf->ctx, nullptr);

    });

TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  // MPS-NHWC
  DLTensor *data = args[0];
  DLTensor *weight = args[1];
  DLTensor *output = args[2];
  int pad = args[3];
  int stride = args[4];

  CHECK_EQ(data->ndim, 4);
  CHECK_EQ(weight->ndim, 4);
  CHECK_EQ(output->ndim, 4);
  CHECK(output->strides == nullptr);
  CHECK(weight->strides == nullptr);
  CHECK(data->strides == nullptr);

  CHECK_EQ(data->shape[0], 1);
  CHECK_EQ(output->shape[0], 1);

  int oCh = weight->shape[0];
  int kH = weight->shape[1];
  int kW = weight->shape[2];
  int iCh = weight->shape[3];

  auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
  auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
  // Get Metal device API
  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
  runtime::metal::MetalThreadEntry *rt =
      runtime::metal::MetalThreadEntry::ThreadLocal();
  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
  id<MTLCommandQueue> queue =
      entry_ptr->metal_api->GetCommandQueue(data->ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];
  // data to MPSImage
  DLTensor tmp_in;
  (*f_buf2img)(data, &tmp_in);
  MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
  // weight to temp memory
  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
  id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
  entry_ptr->metal_api->CopyDataFromTo(
      (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
      weight->ctx, weight->ctx, nullptr);
  float *ptr_w = (float *)[tempB contents];
  // output to MPSImage
  DLTensor tmp_out;
  (*f_buf2img)(output, &tmp_out);
  MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
  // conv desc

  MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
      cnnConvolutionDescriptorWithKernelWidth:kW
                                 kernelHeight:kH
                         inputFeatureChannels:iCh
                        outputFeatureChannels:oCh];
  [conv_desc setStrideInPixelsX:stride];
  [conv_desc setStrideInPixelsY:stride];

  MPSCNNConvolution *conv =
      [[MPSCNNConvolution alloc] initWithDevice:dev
                          convolutionDescriptor:conv_desc
                                  kernelWeights:ptr_w
                                      biasTerms:nil
                                          flags:MPSCNNConvolutionFlagsNone];
  if (pad == 0) {
    conv.padding = [MPSNNDefaultPadding
        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
                          MPSNNPaddingMethodAlignCentered |
                          MPSNNPaddingMethodSizeSame];
  } else if (pad == 1) {
    conv.padding = [MPSNNDefaultPadding
        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
                          MPSNNPaddingMethodAlignCentered |
                          MPSNNPaddingMethodSizeValidOnly];
  }
  [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];

  [cb commit];
  id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
  [encoder synchronizeResource:tempC.texture];
  [encoder endEncoding];
  [cb waitUntilCompleted];

  (*f_img2buf)(&tmp_out, output);
  
  });

} // namespace contrib
} // namespace tvm