Commit dc62760e by nhynes Committed by Tianqi Chen

Improve SGXModule (#1104)

* Improve SGXModule

* Address code review comments
parent 877254f4
...@@ -82,7 +82,7 @@ std::vector<ECallRegistry> ECallRegistry::exports_; ...@@ -82,7 +82,7 @@ std::vector<ECallRegistry> ECallRegistry::exports_;
*/ */
#define TVM_REGISTER_ENCLAVE_FUNC(OpName) \ #define TVM_REGISTER_ENCLAVE_FUNC(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::sgx::ECallRegistry::Register(OpName) ::tvm::runtime::sgx::ECallRegistry::Register(OpName, true)
} // namespace sgx } // namespace sgx
} // namespace runtime } // namespace runtime
......
...@@ -45,7 +45,7 @@ void tvm_ecall_packed_func(int func_id, ...@@ -45,7 +45,7 @@ void tvm_ecall_packed_func(int func_id,
void* ret_buf; void* ret_buf;
TVM_SGX_CHECKED_CALL(tvm_ocall_reserve_space( TVM_SGX_CHECKED_CALL(tvm_ocall_reserve_space(
&ret_buf, bytes.size() + sizeof(TVMByteArray))); &ret_buf, bytes.size() + sizeof(TVMByteArray), sizeof(uint64_t)));
char* data_buf = static_cast<char*>(ret_buf) + sizeof(TVMByteArray); char* data_buf = static_cast<char*>(ret_buf) + sizeof(TVMByteArray);
memcpy(data_buf, bytes.data(), bytes.size()); memcpy(data_buf, bytes.data(), bytes.size());
......
...@@ -22,7 +22,7 @@ enclave { ...@@ -22,7 +22,7 @@ enclave {
[in, count=num_ret] const int* type_code, [in, count=num_ret] const int* type_code,
int num_ret); int num_ret);
void tvm_ocall_register_export([in, string] const char* name, int func_id); void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes); void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment);
}; };
}; };
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h> #include <tvm/runtime/threading_backend.h>
#include <sgx_urts.h> #include <sgx_urts.h>
...@@ -156,6 +157,26 @@ TVM_REGISTER_GLOBAL("__sgx_set_last_error__") ...@@ -156,6 +157,26 @@ TVM_REGISTER_GLOBAL("__sgx_set_last_error__")
TVMAPISetLastError(err.c_str()); TVMAPISetLastError(err.c_str());
}); });
TVM_REGISTER_GLOBAL("__sgx_println__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::ostringstream msg;
for (int i = 0; i < args.num_args; ++i) {
switch (args.type_codes[i]) {
case kDLInt: msg << static_cast<int64_t>(args[i]); break;
case kDLUInt: msg << static_cast<uint64_t>(args[i]); break;
case kDLFloat: msg << static_cast<double>(args[i]); break;
case kStr:
case kBytes: {
std::string val = args[i];
msg << val;
}
break;
}
msg << " ";
}
LOG(INFO) << msg.str();
});
extern "C" { extern "C" {
void tvm_ocall_register_export(const char* name, int func_id) { void tvm_ocall_register_export(const char* name, int func_id) {
...@@ -177,10 +198,20 @@ void tvm_ocall_packed_func(const char* name, ...@@ -177,10 +198,20 @@ void tvm_ocall_packed_func(const char* name,
// Allocates space for return values. The returned pointer is only valid between // Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`. // successive calls to `tvm_ocall_reserve_space`.
void* tvm_ocall_reserve_space(size_t num_bytes) { void* tvm_ocall_reserve_space(size_t num_bytes, size_t alignment) {
static thread_local std::vector<uint64_t> buf; static TVMContext ctx = { kDLCPU, 0 };
buf.reserve(num_bytes); static thread_local void* buf = nullptr;
return buf.data(); static thread_local size_t buf_size = 0;
static thread_local size_t buf_align = 0;
if (buf_size >= num_bytes && buf_align >= alignment) return buf;
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, buf);
buf = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, num_bytes, alignment, {});
buf_size = num_bytes;
buf_align = alignment;
return buf;
} }
void tvm_ocall_set_return(TVMRetValueHandle ret, void tvm_ocall_set_return(TVMRetValueHandle ret,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment