Commit dc62760e by nhynes Committed by Tianqi Chen

Improve SGXModule (#1104)

* Improve SGXModule

* Address code review comments
parent 877254f4
......@@ -82,7 +82,7 @@ std::vector<ECallRegistry> ECallRegistry::exports_;
*/
#define TVM_REGISTER_ENCLAVE_FUNC(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::sgx::ECallRegistry::Register(OpName)
::tvm::runtime::sgx::ECallRegistry::Register(OpName, true)
} // namespace sgx
} // namespace runtime
......
......@@ -45,7 +45,7 @@ void tvm_ecall_packed_func(int func_id,
void* ret_buf;
TVM_SGX_CHECKED_CALL(tvm_ocall_reserve_space(
&ret_buf, bytes.size() + sizeof(TVMByteArray)));
&ret_buf, bytes.size() + sizeof(TVMByteArray), sizeof(uint64_t)));
char* data_buf = static_cast<char*>(ret_buf) + sizeof(TVMByteArray);
memcpy(data_buf, bytes.data(), bytes.size());
......
......@@ -22,7 +22,7 @@ enclave {
[in, count=num_ret] const int* type_code,
int num_ret);
void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes);
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment);
};
};
......@@ -5,6 +5,7 @@
*/
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#include <sgx_urts.h>
......@@ -156,6 +157,26 @@ TVM_REGISTER_GLOBAL("__sgx_set_last_error__")
TVMAPISetLastError(err.c_str());
});
TVM_REGISTER_GLOBAL("__sgx_println__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::ostringstream msg;
for (int i = 0; i < args.num_args; ++i) {
switch (args.type_codes[i]) {
case kDLInt: msg << static_cast<int64_t>(args[i]); break;
case kDLUInt: msg << static_cast<uint64_t>(args[i]); break;
case kDLFloat: msg << static_cast<double>(args[i]); break;
case kStr:
case kBytes: {
std::string val = args[i];
msg << val;
}
break;
}
msg << " ";
}
LOG(INFO) << msg.str();
});
extern "C" {
void tvm_ocall_register_export(const char* name, int func_id) {
......@@ -177,10 +198,20 @@ void tvm_ocall_packed_func(const char* name,
// Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`.
void* tvm_ocall_reserve_space(size_t num_bytes) {
static thread_local std::vector<uint64_t> buf;
buf.reserve(num_bytes);
return buf.data();
void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) {
static TVMContext ctx = { kDLCPU, 0 };
static thread_local void* buf = nullptr;
static thread_local size_t buf_size = 0;
static thread_local size_t buf_align = 0;
if (buf_size >= num_bytes && buf_align >= alignment) return buf;
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
buf_size = num_bytes;
buf_align = alignment;
return buf;
}
void tvm_ocall_set_return(TVMRetValueHandle ret,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment