/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file sgx_module.cc
 * \brief SGX enclave module.
 */
#include <dmlc/logging.h>
#include <sgx_urts.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include "../common.h"
#include "../../file_util.h"
#include "./tvm_u.h"

namespace tvm {
namespace runtime {

class SGXModuleNode;

namespace sgx {

class EnclaveContext {
 public:
  explicit EnclaveContext(SGXModuleNode* mod) {
    CHECK(Context()->mod_ == nullptr)
      << "Tried overriding existing enclave context.";
    CHECK(mod != nullptr) << "Tried setting null enclave context.";
    Context()->mod_ = mod;
  }
  ~EnclaveContext() {
    Context()->mod_ = nullptr;
  }

  static SGXModuleNode* GetModule() {
    SGXModuleNode* ctx = Context()->mod_;
    CHECK(ctx != nullptr) << "No current enclave context";
    return ctx;
  }

 private:
  EnclaveContext() {}
  SGXModuleNode* mod_;

  static EnclaveContext* Context() {
    static thread_local EnclaveContext inst;
    return &inst;
  }
};

}  // namespace sgx

class SGXModuleNode : public ModuleNode {
 public:
  ~SGXModuleNode() {
    if (eid_) {
      sgx::EnclaveContext ctx(this);
      sgx_destroy_enclave(eid_);
    }
  }

  void Init(const std::string& enclave_file) {
    std::string token_file = GetCacheDir() + "/" +
                             GetFileBasename(enclave_file) + ".token";
    sgx_launch_token_t token = {0};
    int token_updated = 0;

    try {
      std::ifstream ifs(token_file, std::fstream::in | std::fstream::binary);
      ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
      ifs >> token;
    } catch (std::ifstream::failure e) {
      memset(&token, 0x0, sizeof(sgx_launch_token_t));
    }

    TVM_SGX_CHECKED_CALL(sgx_create_enclave(
        enclave_file.c_str(), SGX_DEBUG_FLAG, &token, &token_updated, &eid_, NULL));

    sgx::EnclaveContext ctx(this);
    TVMRetValue rv;
    TVM_SGX_CHECKED_CALL(tvm_ecall_init(eid_, &rv));

    if (!token_updated) return;

    try {
      std::ofstream ofs(token_file, std::fstream::trunc | std::fstream::binary);
      ofs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
      ofs << token;
    } catch (std::ifstream::failure e) {
      LOG(INFO) << "Could not save SGX launch token to " << token_file;
    }
  }

  const char* type_key() const final {
    return "sgx";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    auto exported = exports_.find(name);
    if (exported == exports_.end()) return PackedFunc();
    int func_id = exported->second;
    return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) {
        sgx::EnclaveContext ctx(this);
        TVMValue ret_value;
        int ret_type_code;
        TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id,
              args.values, args.type_codes, args.num_args, &ret_value, &ret_type_code));
        *rv = TVMArgValue(ret_value, ret_type_code);
      });
  }

  void RunWorkers(int num_tasks) {
    std::function<void(int)> runner = [this](int _worker_id) {
      this->GetFunction("__tvm_run_worker__",
                        std::shared_ptr<SGXModuleNode>(nullptr))();
    };
    thread_group_.reset(new tvm::runtime::threading::ThreadGroup(
          num_tasks, runner, false /* include_main_thread */));
  }

  void JoinThreads() {
    thread_group_->Join();
  }

  void RegisterExport(std::string name, int func_id) {
    exports_[name] = func_id;
  }

 private:
  // ID of the loaded enclave
  sgx_enclave_id_t eid_;
  // Names and IDs of functions exported by the enclave module
  std::unordered_map<std::string, int> exports_;
  std::unique_ptr<tvm::runtime::threading::ThreadGroup> thread_group_;
};

namespace sgx {

// OCALL target: args[0] is the number of worker threads to start for the
// enclave module whose ECALL is active on this thread.
TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  SGXModuleNode* const owner = EnclaveContext::GetModule();
  const int num_tasks = args[0];
  owner->RunWorkers(num_tasks);
});

// OCALL target: block until the worker threads started for the current
// enclave module have finished.
TVM_REGISTER_GLOBAL("__sgx_thread_group_join__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  SGXModuleNode* const owner = EnclaveContext::GetModule();
  owner->JoinThreads();
});

// OCALL target: copy the error message in args[0] into the host runtime's
// last-error slot via TVMAPISetLastError.
TVM_REGISTER_GLOBAL("__sgx_set_last_error__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  const std::string message = args[0];
  TVMAPISetLastError(message.c_str());
});

// OCALL target: log all arguments on one INFO line, each followed by a space.
// Integers, unsigned integers, floats, strings and byte arrays are printed;
// arguments of any other type code contribute only the separator.
TVM_REGISTER_GLOBAL("__sgx_println__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  std::ostringstream line;
  for (int idx = 0; idx < args.num_args; ++idx) {
    const int code = args.type_codes[idx];
    if (code == kDLInt) {
      line << static_cast<int64_t>(args[idx]);
    } else if (code == kDLUInt) {
      line << static_cast<uint64_t>(args[idx]);
    } else if (code == kDLFloat) {
      line << static_cast<double>(args[idx]);
    } else if (code == kStr || code == kBytes) {
      std::string text = args[idx];
      line << text;
    }
    line << " ";
  }
  LOG(INFO) << line.str();
});

extern "C" {

void tvm_ocall_register_export(const char* name, int func_id) {
  EnclaveContext::GetModule()->RegisterExport(name, func_id);
}

void tvm_ocall_packed_func(const char* name,
                           const TVMValue* arg_values,
                           const int* type_codes,
                           int num_args,
                           TVMValue* ret_val,
                           int* ret_type_code) {
  const PackedFunc* f = Registry::Get(name);
  CHECK(f != nullptr) << "ocall to nonexistent function \"" << name << "\"";
  TVMRetValue rv;
  f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
  rv.MoveToCHost(ret_val, ret_type_code);
}

// Allocates space for return values. The returned pointer is only valid between
// successive calls to `__sgx_reserve_space__` on the same thread.
//
// args[0]: number of bytes required; args[1]: required alignment.
// The buffer is cached per thread (thread_local). It is reused when it is
// already at least as large and as aligned as requested; otherwise it is freed
// and replaced with a fresh CPU allocation.
TVM_REGISTER_GLOBAL("__sgx_reserve_space__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  size_t num_bytes = args[0];
  size_t alignment = args[1];

  static TVMContext ctx = { kDLCPU, 0 };
  static thread_local void* buf = nullptr;
  static thread_local size_t buf_size = 0;
  static thread_local size_t buf_align = 0;

  // Fast path: reuse the cached buffer when it already satisfies the request.
  // NOTE(review): `buf_align >= alignment` implies suitable alignment only if
  // alignments are powers of two — confirm with enclave-side callers.
  if (buf != nullptr && buf_size >= num_bytes && buf_align >= alignment) {
    *rv = buf;
    return;
  }

  if (buf != nullptr) {
    DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
    // Forget the freed buffer before allocating so that a throwing allocation
    // cannot leave a dangling pointer cached for the next call.
    buf = nullptr;
    buf_size = 0;
    buf_align = 0;
  }
  buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
  buf_size = num_bytes;
  buf_align = alignment;

  *rv = buf;
});

}  // extern "C"
}  // namespace sgx

// Loader for SGX modules: args[0] is the path of the enclave file. Creates an
// SGXModuleNode, initializes it from that file and returns it as a Module.
TVM_REGISTER_GLOBAL("module.loadfile_sgx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  auto module_node = std::make_shared<SGXModuleNode>();
  const std::string enclave_path = args[0];
  module_node->Init(enclave_path);
  *rv = runtime::Module(module_node);
});

}  // namespace runtime
}  // namespace tvm