/*!
 *  Copyright (c) 2017 by Contributors
 * \file graph_executor_ext.cc
 */
#include "./graph_executor.h"

namespace tvm {
namespace contrib {

/*!
 * \brief Serialize a DLTensor into a binary stream.
 *
 *  Fields are written in this order: the kTVMNDArrayMagic header, a
 *  reserved 64-bit zero, ctx, ndim, dtype, ndim int64_t shape entries,
 *  the payload size in bytes (int64_t) and finally the raw payload.
 *
 * \param strm The stream to write into.
 * \param tensor The tensor to be serialized.
 * \return Always true.
 *
 * \note The payload is read straight from the tensor's memory and
 *  tensor->strides is never consulted, so this presumably expects a
 *  compact tensor whose memory is addressable from the host
 *  (TODO confirm against callers).
 */
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
    uint64_t header = kTVMNDArrayMagic, reserved = 0;
    strm->Write(&header, sizeof(header));
    strm->Write(&reserved, sizeof(reserved));

    strm->Write(&tensor->ctx, sizeof(tensor->ctx));
    strm->Write(&tensor->ndim, sizeof(tensor->ndim));
    strm->Write(&tensor->dtype, sizeof(tensor->dtype));

    int ndim = tensor->ndim;
    strm->Write(tensor->shape, sizeof(int64_t) * ndim);

    // NOTE(review): bits / 8 truncates for sub-byte types and dtype.lanes is
    // not taken into account. Left unchanged on purpose: the matching reader
    // lives outside this file and may validate the byte count - confirm
    // before changing the formula.
    int type_size = tensor->dtype.bits / 8;
    int64_t size = 1;
    for (int i = 0; i < ndim; ++i) {
      size *= tensor->shape[i];
    }
    int64_t data_byte_size = type_size * size;
    strm->Write(&data_byte_size, sizeof(data_byte_size));

    // The first element of a DLTensor lives at data + byte_offset, not at
    // data itself, so honour the offset when locating the payload.
    const char* payload =
        static_cast<const char*>(tensor->data) + tensor->byte_offset;
    strm->Write(payload, data_byte_size);
    return true;
}

/*!
 * \brief Packed function that saves a list of named tensors to a file.
 *
 *  Arguments: args[0] is the output file name, args[1] is the number of
 *  parameters, followed by that many (name, DLTensor*) pairs.
 *
 *  File layout: the kTVMNDArrayListMagic header, a reserved 64-bit zero,
 *  the vector of names, the tensor count as uint64_t, then every tensor
 *  as written by SaveDLTensor.
 */
TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    std::string path = args[0];
    int count = args[1];

    std::vector<std::string> keys;
    std::vector<DLTensor*> tensors;
    keys.reserve(count);
    tensors.reserve(count);
    // Pair k occupies argument slots 2 + 2k (name) and 3 + 2k (tensor).
    for (int k = 0; k < count; ++k) {
      const int slot = 2 + 2 * k;
      keys.emplace_back(args[slot].operator std::string());
      tensors.emplace_back(args[slot + 1].operator DLTensor*());
    }

    std::unique_ptr<dmlc::Stream> out(
        dmlc::Stream::Create(path.c_str(), "w"));

    // File-level header.
    uint64_t magic = kTVMNDArrayListMagic;
    uint64_t reserved = 0;
    out->Write(&magic, sizeof(magic));
    out->Write(&reserved, sizeof(reserved));

    // Names first, then the tensor count followed by the tensors.
    out->Write(keys);
    uint64_t num_tensors = static_cast<uint64_t>(tensors.size());
    out->Write(&num_tensors, sizeof(num_tensors));
    for (DLTensor* item : tensors) {
      SaveDLTensor(out.get(), item);
    }
  });

/*!
 * \brief Create a GraphExecutor and wrap it as a runtime module.
 * \param g The graph handed to GraphExecutor::Init.
 * \param ctx The context handed to GraphExecutor::Init.
 * \return A tvm::runtime::Module constructed from the new executor.
 */
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) {
  auto executor = std::make_shared<GraphExecutor>();
  executor->Init(g, ctx);
  return tvm::runtime::Module(executor);
}

/*!
 * \brief Packed function that builds an executor module from a graph handle.
 *
 *  Arguments: args[0] is a handle that is cast to nnvm::Graph*, args[1] is
 *  the device type as an integer and args[2] is the device id. The result
 *  of CreateExecutor is stored in the return value.
 */
TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    void* handle = args[0];
    int dev_type = args[1];
    int dev_id = args[2];

    TVMContext target;
    target.device_type = static_cast<DLDeviceType>(dev_type);
    target.device_id = dev_id;

    // Work on a copy of the graph behind the handle.
    nnvm::Graph graph = *static_cast<nnvm::Graph*>(handle);
    *rv = CreateExecutor(graph, target);
  });


/*!
 * \brief Packed function that returns the "module" attribute of a graph.
 *
 *  args[0] is a handle that is cast to nnvm::Graph*; the attribute is
 *  fetched with MoveCopyAttr and stored in the return value.
 */
TVM_REGISTER_GLOBAL("tvm_graph._get_module_from_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    void* handle = args[0];
    nnvm::Graph& graph = *static_cast<nnvm::Graph*>(handle);
    *rv = graph.MoveCopyAttr<tvm::runtime::Module>("module");
  });
}  // namespace contrib
}  // namespace tvm