/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#![feature(try_from)]

#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate tvm;

use std::{
  convert::{TryFrom, TryInto},
  sync::Mutex,
};

use tvm::{
  ffi::runtime::DLTensor,
  runtime::{
    load_param_dict, sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue, Tensor,
  },
};

lazy_static! {
  // Module used to construct the graph executor; built with `SystemLibModule::default()`.
  static ref SYSLIB: SystemLibModule = SystemLibModule::default();

  // Process-wide graph executor, built on first access from the graph definition and
  // parameter blob embedded at compile time from `$BUILD_DIR`. Guarded by a mutex so
  // that callers obtain exclusive access before mutating it.
  static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = {
    let graph_src = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json"));
    let param_blob = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));

    // Both artifacts are baked into the binary, so a failure to decode either of them
    // is a build problem and is allowed to panic.
    let param_dict = load_param_dict(param_blob).unwrap();
    let graph_def: Graph = Graph::try_from(graph_src).unwrap();

    let mut executor = GraphExecutor::new(graph_def, &*SYSLIB).unwrap();
    executor.load_params(param_dict);
    Mutex::new(executor)
  };
}

/// Ecall that forces construction of the lazily-initialized `MODEL`.
///
/// The arguments are ignored. Always returns a `TVMRetValue` built from `0`.
fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue {
  // Touch the static so the (potentially expensive) model setup happens now
  // rather than on the first inference call.
  lazy_static::initialize(&MODEL);

  let status = TVMRetValue::from(0);
  status
}

/// Ecall that runs the model once.
///
/// * `args[0]` is converted and bound to the graph input named `"data"`.
/// * `args[1]` is converted to a `Tensor` that receives a copy of output `0`.
///
/// Always returns a `TVMRetValue` built from `1`.
///
/// # Panics
///
/// Panics if the `MODEL` mutex is poisoned, if fewer than two arguments are
/// supplied, if either argument fails to convert, or if output `0` is unavailable.
fn ecall_main(args: &[TVMArgValue<'static>]) -> TVMRetValue {
  // The lock is taken before the arguments are decoded and held until return.
  let mut executor = MODEL.lock().unwrap();

  let input = args[0].try_into().unwrap();
  let mut output: Tensor = args[1].try_into().unwrap();

  executor.set_input("data", input);
  executor.run();

  // NOTE(review): presumably stops the SGX worker threads once the run is done;
  // confirm against `tvm::runtime::sgx::shutdown`.
  sgx::shutdown();

  let result = executor.get_output(0).unwrap();
  output.copy(result);

  TVMRetValue::from(1)
}

pub mod ecalls {
  //! Enclave entry points (ecalls) exported to the caller through the C ABI.
  //!
  //! `ECALLS` lists the exported names and `ECALL_FUNCS` the matching handlers; the
  //! position of an entry is the function id that `tvm_ecall_packed_func` dispatches on.
  //!
  //! TODO: generate this using proc_macros

  use super::*;

  use std::{
    ffi::CString,
    mem,
    os::raw::{c_char, c_int, c_void},
    slice,
  };

  use tvm::{
    ffi::runtime::{TVMRetValueHandle, TVMValue},
    runtime::{
      sgx::{ocall_packed_func, run_worker, SgxStatus},
      DataType, PackedFunc,
    },
  };

  // Maps the status returned by an ocall to a `Result`: `0` is success, any other
  // value is returned unchanged as the error.
  macro_rules! tvm_ocall {
    ($func: expr) => {
      match $func {
        0 => Ok(()),
        err => Err(err),
      }
    };
  }

  /// Names under which the ecalls are registered. Must stay in the same order as
  /// `ECALL_FUNCS`, because the index is used as the function id.
  const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"];

  /// Signature shared by every ecall handler.
  pub type EcallPackedFunc = Box<Fn(&[TVMArgValue<'static>]) -> TVMRetValue + Send + Sync>;

  lazy_static! {
    // Handlers indexed by function id; order mirrors `ECALLS`.
    static ref ECALL_FUNCS: Vec<EcallPackedFunc> = {
      vec![
        Box::new(run_worker),
        Box::new(ecall_main),
        Box::new(ecall_init),
      ]
    };
  }

  extern "C" {
    fn __tvm_module_startup() -> ();
    fn tvm_ocall_register_export(name: *const c_char, func_id: c_int) -> SgxStatus;
  }

  /// Runs the module startup hook and registers every name in `ECALLS`, using its
  /// index as the function id.
  ///
  /// # Panics
  ///
  /// Panics if a name contains an interior NUL byte or if registration of any
  /// ecall reports a non-zero status.
  #[no_mangle]
  pub extern "C" fn tvm_ecall_init(_ret: TVMRetValueHandle) {
    // SAFETY: `__tvm_module_startup` takes no arguments; it is declared above as an
    // external symbol and is assumed to be provided by the linked module.
    unsafe { __tvm_module_startup() };

    for (i, ecall) in ECALLS.iter().enumerate() {
      // Keep the C string in a named local so that it demonstrably outlives the
      // pointer handed to the ocall.
      let name = CString::new(*ecall).unwrap();
      // SAFETY: `name` is a valid NUL-terminated string that lives until the end of
      // this iteration, i.e. for the whole duration of the call.
      let status = unsafe { tvm_ocall_register_export(name.as_ptr(), i as c_int) };
      // Build the panic message lazily; it is only needed on failure.
      tvm_ocall!(status)
        .unwrap_or_else(|err| panic!("Error registering `{}`: {:?}", ecall, err));
    }
  }

  /// Dispatches a packed-function call to the handler registered under `func_id`.
  ///
  /// `arg_values` and `type_codes` are parallel arrays of `num_args` elements. A
  /// non-positive `num_args` or a null array pointer is treated as an empty argument
  /// list instead of being turned into a slice. The handler's result is written to
  /// `ret_val` / `ret_type_code`, which must point to writable storage.
  ///
  /// # Panics
  ///
  /// Panics if `func_id` does not index an entry of `ECALL_FUNCS`, or if the
  /// selected handler panics.
  #[no_mangle]
  pub extern "C" fn tvm_ecall_packed_func(
    func_id: c_int,
    arg_values: *const TVMValue,
    type_codes: *const c_int,
    num_args: c_int,
    ret_val: *mut TVMValue,
    ret_type_code: *mut i64,
  ) {
    // `slice::from_raw_parts` requires non-null pointers even for empty slices and a
    // length that fits the allocation, so reject null/negative inputs up front.
    let args: Vec<TVMArgValue<'static>> =
      if num_args <= 0 || arg_values.is_null() || type_codes.is_null() {
        Vec::new()
      } else {
        let len = num_args as usize;
        // SAFETY: both pointers were checked to be non-null and `len` is positive;
        // the caller is trusted to supply arrays holding at least `len` elements.
        unsafe {
          let values = slice::from_raw_parts(arg_values, len);
          let codes = slice::from_raw_parts(type_codes, len);
          values
            .iter()
            .zip(codes.iter())
            .map(|(v, t)| TVMArgValue::new(*v, *t as i64))
            .collect()
        }
      };

    // A negative `func_id` wraps to a huge index and is caught by the bounds check.
    let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value();
    // SAFETY: the caller is trusted to pass valid, writable output pointers.
    unsafe {
      *ret_val = rv;
      *ret_type_code = tc;
    }
  }
}