Commit 6f2ef9dc by nhynes Committed by Tianqi Chen

Add SGXModule (#1019)

parent 3121441d
......@@ -58,6 +58,7 @@ ROCM_SRC = $(wildcard src/runtime/rocm/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
OPENGL_SRC = $(wildcard src/runtime/opengl/*.cc)
VULKAN_SRC = $(wildcard src/runtime/vulkan/*.cc)
SGX_SRC = $(wildcard src/runtime/sgx/untrusted/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
......@@ -72,6 +73,7 @@ ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
OPENGL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENGL_SRC))
VULKAN_OBJ = $(patsubst src/%.cc, build/%.o, $(VULKAN_SRC))
SGX_OBJ = $(patsubst src/%.cc, build/%.o, $(SGX_SRC)) build/runtime/sgx/untrusted/tvm_u.o
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
......@@ -172,6 +174,20 @@ else
CFLAGS += -DTVM_METAL_RUNTIME=0
endif
ifeq ($(USE_SGX), 1)
EDGER8R = $(SGX_SDK)/bin/x64/sgx_edger8r
ifneq ($(SGX_MODE), HW)
sgx_sim := _sim
endif
urts_library_name := sgx_urts$(sgx_sim)
CFLAGS += -DTVM_SGX_RUNTIME=1
SGX_CFLAGS = -include "build/runtime/sgx/untrusted/tvm_u.h" -I$(SGX_SDK)/include
LDFLAGS += -L$(SGX_SDK)/lib64 -l$(urts_library_name)
RUNTIME_DEP += $(SGX_OBJ)
else
CFLAGS += -DTVM_SGX_RUNTIME=0
endif
ifeq ($(USE_RPC), 1)
RUNTIME_DEP += $(RPC_OBJ)
endif
......@@ -254,6 +270,21 @@ build/runtime/metal/%.o: src/runtime/metal/%.mm
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/runtime/sgx/untrusted/tvm_u.h: src/runtime/sgx/tvm.edl
@mkdir -p $(@D)
$(EDGER8R) $< --untrusted --untrusted-dir $(@D) --search-path $(SGX_SDK)/include
mv $@ $@.in
awk 'NR==4{print "#include <tvm/runtime/c_runtime_api.h>"}1' $@.in > $@
build/runtime/sgx/untrusted/tvm_u.c: build/runtime/sgx/untrusted/tvm_u.h
build/runtime/sgx/untrusted/tvm_u.o: build/runtime/sgx/untrusted/tvm_u.c
$(CC) $(CFLAGS) $(SGX_CFLAGS) -c $< -o $@
build/runtime/sgx/untrusted/%.o: src/runtime/sgx/untrusted/%.cc build/runtime/sgx/untrusted/tvm_u.h
$(CXX) $(CFLAGS) $(SGX_CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) $(SGX_CFLAGS) -c $< -o $@
build/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
......
......@@ -39,6 +39,7 @@ enclave_cflags := -static -nostdinc\
-fvisibility=hidden -fpie -fstack-protector-strong\
-ffunction-sections -fdata-sections\
-DDMLC_CXX11_THREAD_LOCAL=0\
-include "lib/tvm_t.h"\
$(enclave_include_paths)\
enclave_cxxflags := -nostdinc++ $(enclave_cflags) -DTVM_SGX_MAX_CONCURRENCY=4
......@@ -53,59 +54,31 @@ enclave_ldflags :=\
-Wl,-pie,-eenclave_entry -Wl,--export-dynamic\
-Wl,--defsym,__ImageBase=0 -Wl,--gc-sections
app_cflags := -I$(SGX_SDK)/include -Ilib
app_ldflags := -L$(SGX_SDK)/lib64\
-l$(urts_library_name) -l$(uservice_library_name) -lpthread\
.PHONY: clean all
all: lib/test_addone.signed.so bin/test_addone
# Build rule for all-in-one TVM package library
lib/tvm_runtime_pack.o: tvm_runtime_pack.cc lib/test_addone_t.o
@mkdir -p $(@D)
$(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) -g
all: lib/test_addone.signed.so
# The code library built by TVM
lib/test_addone_sys.o: prepare_test_libs.py
python prepare_test_libs.py
# EDL files
lib/test_addone_u.c: $(sgx_edger8r) test_addone.edl
$(sgx_edger8r) --untrusted test_addone.edl --untrusted-dir lib --search-path $(SGX_SDK)/include
lib/test_addone_u.o: lib/test_addone_u.c
$(CC) $(enclave_cflags) -c $< -o $@
lib/test_addone_t.c: test_addone.edl
lib/tvm_t.h: ../../src/runtime/sgx/tvm.edl
$(sgx_edger8r) --trusted $< --trusted-dir lib --search-path $(SGX_SDK)/include
mv $@ $@.in
awk 'NR==4{print "#include <tvm/runtime/c_runtime_api.h>"}1' $@.in > $@
lib/test_addone_t.o: lib/test_addone_t.c
$(CC) $(enclave_cflags) -c $< -o $@
lib/tvm_t.c: lib/tvm_t.h
lib/tvm_t.o: lib/tvm_t.c
$(CC) $(enclave_cflags) $(pkg_cflags) -c $< -o $@ -include $(TVM_ROOT)/include/tvm/runtime/c_runtime_api.h
# The enclave library
lib/test_addone.so: enclave.cc lib/tvm_runtime_pack.o lib/test_addone_t.o lib/test_addone_sys.o
lib/test_addone.so: $(TVM_ROOT)/src/runtime/sgx/trusted/runtime.cc lib/tvm_t.o lib/test_addone_sys.o
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) $(enclave_ldflags) -g
# The signed enclave
lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml
$(sgx_enclave_signer) sign -key enclave_private.pem -enclave $< -out $@ -config enclave_config.xml
# An app that runs the enclave
bin/test_addone: app.cc lib/test_addone_u.o
@mkdir -p $(@D)
$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) $(pkg_cflags) -g
# Debugging runtime pack built without SGX (c.f. howto_deploy/tvm_runtime_pack.cc)
lib/tvm_runtime_pack_nosgx.o: tvm_runtime_pack.cc
@mkdir -p $(@D)
$(CXX) -c $< -o $@ $(pkg_cflags) $(pkg_ldflags) -g
# Debugging binary that runs TVM without SGX
bin/addone_nosgx: enclave.cc lib/tvm_runtime_pack_nosgx.o lib/test_addone_sys.o
@mkdir -p $(@D)
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) -g -lpthread
clean:
rm -rf lib bin
rm -rf lib
......@@ -5,7 +5,7 @@ This application demonstrates the use of a simple TVM model in the [Intel SGX](h
## Prerequisites
1. A GNU/Linux environment
2. TVM compiled with LLVM and the `tvm` Python module
2. TVM compiled with LLVM and SGX; and the `tvm` Python module
3. The [Linux SGX SDK](https://github.com/intel/linux-sgx) [link to pre-built libraries](https://01.org/intel-software-guard-extensions/downloads)
## Running the example
......
#include <cstdio>
#include <iostream>
#include "sgx_urts.h"
#include "sgx_eid.h"
#include "test_addone_u.h"
#include "../../sgx/runtime_u.cc"
#define TOKEN_FILENAME "bin/test_addone.token"
#define ENCLAVE_FILENAME "lib/test_addone.signed.so"
sgx_enclave_id_t tvm_sgx_eid;
typedef struct _sgx_errlist_t {
sgx_status_t err;
const char *msg;
} sgx_errlist_t;
/* Error code returned by sgx_create_enclave */
static sgx_errlist_t sgx_errlist[] = {
{ SGX_ERROR_DEVICE_BUSY, "SGX device was busy." },
{ SGX_ERROR_ENCLAVE_FILE_ACCESS, "Can't open enclave file." },
{ SGX_ERROR_ENCLAVE_LOST, "Power transition occurred." },
{ SGX_ERROR_INVALID_ATTRIBUTE, "Enclave was not authorized." },
{ SGX_ERROR_INVALID_ENCLAVE, "Invalid enclave image." },
{ SGX_ERROR_INVALID_ENCLAVE_ID, "Invalid enclave identification." },
{ SGX_ERROR_INVALID_METADATA, "Invalid enclave metadata." },
{ SGX_ERROR_INVALID_PARAMETER, "Invalid parameter." },
{ SGX_ERROR_INVALID_SIGNATURE, "Invalid enclave signature." },
{ SGX_ERROR_INVALID_VERSION, "Enclave version was invalid." },
{ SGX_ERROR_MEMORY_MAP_CONFLICT, "Memory map conflicted." },
{ SGX_ERROR_NO_DEVICE, "Invalid SGX device." },
{ SGX_ERROR_OUT_OF_EPC, "Out of EPC memory." },
{ SGX_ERROR_OUT_OF_MEMORY, "Out of memory." },
{ SGX_ERROR_UNEXPECTED, "Unexpected error occurred." },
};
/* Check error conditions for loading enclave */
void print_error_message(sgx_status_t status)
{
size_t idx = 0;
size_t ttl = sizeof sgx_errlist/sizeof sgx_errlist[0];
for (idx = 0; idx < ttl; idx++) {
if(status == sgx_errlist[idx].err) {
printf("Error: %s\n", sgx_errlist[idx].msg);
break;
}
}
if (idx == ttl)
printf("Error code is 0x%X. Please refer to the \"Intel SGX SDK Developer Reference\" for more details.\n", status);
}
/* Initialize the enclave:
* Step 1: try to retrieve the launch token saved by last transaction
* Step 2: call sgx_create_enclave to initialize an enclave instance
* Step 3: save the launch token if it is updated
*/
int initialize_enclave(void)
{
sgx_launch_token_t token = {0};
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
int updated = 0;
/* Step 1: try to retrieve the launch token saved by last transaction
* if there is no token, then create a new one.
*/
FILE *fp = fopen(TOKEN_FILENAME, "rb");
if (fp == NULL && (fp = fopen(TOKEN_FILENAME, "wb")) == NULL) {
printf("Warning: Failed to create/open the launch token file \"%s\".\n", TOKEN_FILENAME);
return -1;
}
/* read the token from saved file */
size_t read_num = fread(token, 1, sizeof(sgx_launch_token_t), fp);
if (read_num != 0 && read_num != sizeof(sgx_launch_token_t)) {
/* if token is invalid, clear the buffer */
memset(&token, 0x0, sizeof(sgx_launch_token_t));
printf("Warning: Invalid launch token read from \"%s\".\n", TOKEN_FILENAME);
}
/* Step 2: call sgx_create_enclave to initialize an enclave instance */
/* Debug Support: set 2nd parameter to 1 */
sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &tvm_sgx_eid, NULL);
if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status);
if (fp != NULL) fclose(fp);
return -1;
}
/* Step 3: save the launch token if it is updated */
if (updated == 0 || fp == NULL) {
/* if the token is not updated, or file handler is invalid, do not perform saving */
if (fp != NULL) fclose(fp);
return 0;
}
/* reopen the file with write capablity */
fp = freopen(TOKEN_FILENAME, "wb", fp);
if (fp == NULL) return 0;
size_t write_num = fwrite(token, 1, sizeof(sgx_launch_token_t), fp);
if (write_num != sizeof(sgx_launch_token_t))
printf("Warning: Failed to save launch token to \"%s\".\n", TOKEN_FILENAME);
fclose(fp);
return 0;
}
int SGX_CDECL main(int argc, char *argv[]) {
if(initialize_enclave() < 0) {
printf("Failed to initialize enclave.\n");
return -1;
}
/* Run TVM within the enclave */
int addone_status;
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ecall_run_module(tvm_sgx_eid, nullptr, &addone_status);
if (sgx_status != SGX_SUCCESS) {
print_error_message(sgx_status);
}
sgx_destroy_enclave(tvm_sgx_eid);
if (addone_status == 1) {
printf("It works!");
return 0;
}
printf("It doesn't work.");
return -1;
}
extern "C" {
void ocall_println(const char* str) {
std::cout << "Enclave says: " << str << std::endl;
}
}
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#ifndef _LIBCPP_SGX_CONFIG
#include <iostream>
#endif
/* This function mirrors the one in howto_deploy except without the iostream */
int Verify(tvm::runtime::Module mod, std::string fname) {
// Get the function from the module.
tvm::runtime::PackedFunc f = mod.GetFunction(fname);
// Allocate the DLPack data structures.
DLTensor* x;
DLTensor* y;
int ndim = 1;
int dtype_code = kDLFloat;
int dtype_bits = 32;
int dtype_lanes = 1;
int device_type = kDLCPU;
int device_id = 0;
int64_t shape[1] = {10};
TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
device_type, device_id, &x);
TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
device_type, device_id, &y);
for (int i = 0; i < shape[0]; ++i) {
static_cast<float*>(x->data)[i] = i;
}
// Invoke the function
f(x, y);
// check the output
bool all_eq = true;
for (int i = 0; i < shape[0]; ++i) {
all_eq = all_eq && static_cast<float*>(y->data)[i] == i + 1.0f;
}
return all_eq;
}
extern "C" {
void tvm_ecall_run_module(const void* tvm_args, void* tvm_return_value) {
tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))();
*(int*)tvm_return_value = Verify(mod_syslib, "addonesys");
}
}
#ifndef _LIBCPP_SGX_CONFIG
int main(void) {
tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("module._GetSystemLib"))();
if (Verify(mod_syslib, "addonesys")) {
std::cout << "It works!" << std::endl;
return 0;
}
std::cerr << "It doesn't work." << std::endl;
return -1;
}
#endif
<EnclaveConfiguration>
<ProdID>0</ProdID>
<ISVSVN>0</ISVSVN>
<StackMaxSize>0x100000</StackMaxSize>
<HeapMaxSize>0x100000</HeapMaxSize>
<StackMaxSize>0x2000</StackMaxSize>
<HeapMaxSize>0x2000</HeapMaxSize>
<TCSNum>5</TCSNum>
<TCSPolicy>1</TCSPolicy>
<DisableDebug>0</DisableDebug>
......
......@@ -15,7 +15,7 @@ def prepare_test_libs(base_path):
print(tvm.lower(s, [A, B], simple_mode=True))
# Compile library in system library mode
fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib', name='addonesys')
fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib')
syslib_path = osp.join(base_path, 'test_addone_sys.o')
fadd_syslib.save(syslib_path)
......
#!/bin/bash
sgx_sdk=${SGX_SDK:=/opt/sgxsdk}
mkdir -p bin lib
make
echo "========================="
LD_LIBRARY_PATH="$sgx_sdk/lib64":${LD_LIBRARY_PATH} bin/test_addone
LD_LIBRARY_PATH="$sgx_sdk/lib64":${LD_LIBRARY_PATH} TVM_CACHE_DIR=/tmp python test_addone.py
enclave {
from "../../sgx/tvm.edl" import *;
untrusted {
void ocall_println([in, string] const char *str);
};
};
import tvm
import numpy as np
ctx = tvm.context('cpu', 0)
fadd1 = tvm.module.load('lib/test_addone.signed.so')
n = 10
x = tvm.nd.array(np.random.uniform(size=n).astype('float32'), ctx)
y = tvm.nd.array(np.zeros(n, dtype='float32'), ctx)
fadd1(x, y)
np.testing.assert_allclose(y.asnumpy(), x.asnumpy() + 1)
print("It works!")
/*!
* \brief This is an all in one TVM runtime file for use in an SGX enclave.
*
* The files included here will be statically linked into the enclave.
* Please refer to the Makefile (rule lib/tvm_runtime_pack.o) for how to build.
*
*/
#ifdef _LIBCPP_SGX_CONFIG
#include "lib/test_addone_t.h"
#endif
#include "../../sgx/runtime_t.cc"
#ifndef _LIBCPP_SGX_CONFIG
#include "../../src/runtime/file_util.cc"
#endif
......@@ -76,7 +76,7 @@ class Registry {
// Internal class.
struct Manager;
private:
protected:
/*! \brief name of the function */
std::string name_;
/*! \brief internal packed function */
......
......@@ -44,6 +44,10 @@ USE_OPENCL = 0
# whether enable Metal during compile
USE_METAL = 0
# whether enable SGX during compile
USE_SGX = 0
SGX_SDK = /opt/sgxsdk
# Whether enable RPC during compile
USE_RPC = 1
......
/*!
* Copyright (c) 2018 by Contributors
* \file sgx_runtime.cc
*/
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/module_util.cc"
#include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc"
#include "../../src/runtime/system_lib_module.cc"
#ifndef _LIBCPP_SGX_CONFIG
#include "../../src/runtime/threading_backend.cc"
#else
#include "threading_backend.cc"
#endif
#include "../../src/runtime/thread_pool.cc"
#include <tvm/runtime/threading_backend.h>
#include "../../src/runtime/threading_backend.cc"
#include <iostream>
extern sgx_enclave_id_t tvm_sgx_eid;
extern "C" {
sgx_status_t tvm_ecall_run_worker(sgx_enclave_id_t eid, const void* cb);
}
namespace tvm {
namespace runtime {
namespace sgx {
static std::unique_ptr<tvm::runtime::threading::ThreadGroup> sgx_thread_group;
extern "C" {
void tvm_ocall_thread_group_launch(int num_tasks, void* cb) {
std::function<void(int)> runner = [cb](int _worker_id) {
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ecall_run_worker(tvm_sgx_eid, cb);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
};
sgx_thread_group.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
void tvm_ocall_thread_group_join() {
sgx_thread_group->Join();
}
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
enclave {
from "sgx_tstdc.edl" import *;
trusted {
public void tvm_ecall_run_module([user_check] const void* tvm_args,
[user_check] void* tvm_ret_value);
public void tvm_ecall_run_worker([user_check] const void* cb);
};
untrusted {
void tvm_ocall_thread_group_launch(int num_workers, [user_check] void* cb);
void tvm_ocall_thread_group_join();
};
};
......@@ -51,13 +51,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(GetFunctionAddr(fname));
if (faddr == nullptr) return PackedFunc();
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
return WrapPackedFunc(faddr, sptr_to_self);
}
void SaveToFile(const std::string& file_name,
......
......@@ -10,6 +10,9 @@
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#ifdef _LIBCPP_SGX_CONFIG
#include "sgx/trusted/runtime.h"
#endif
#include <array>
#include <algorithm>
#include <string>
......@@ -186,7 +189,11 @@ const char *TVMGetLastError() {
}
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg;
#else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}
int TVMModLoadFromFile(const char* file_name,
......
......@@ -53,6 +53,7 @@ std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx";
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1);
......@@ -64,6 +65,24 @@ std::string GetFileFormat(const std::string& file_name,
}
}
std::string GetCacheDir() {
char* env_cache_dir;
if ((env_cache_dir = getenv("TVM_CACHE_DIR"))) return env_cache_dir;
if ((env_cache_dir = getenv("XDG_CACHE_HOME"))) {
return std::string(env_cache_dir) + "/tvm";
}
if ((env_cache_dir = getenv("HOME"))) {
return std::string(env_cache_dir) + "/.cache/tvm";
}
return ".";
}
std::string GetFileBasename(const std::string& file_name) {
size_t last_slash = file_name.find_last_of("/");
if (last_slash == std::string::npos) return file_name;
return file_name.substr(last_slash + 1);
}
std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
......
......@@ -20,12 +20,25 @@ std::string GetFileFormat(const std::string& file_name,
const std::string& format);
/*!
* \return the directory in which TVM stores cached files.
* May be set using TVM_CACHE_DIR; defaults to system locations.
*/
std::string GetCacheDir();
/*!
* \brief Get meta file path given file name and format.
* \param file_name The name of the file.
*/
std::string GetMetaFilePath(const std::string& file_name);
/*!
* \brief Get file basename (i.e. without leading directories)
* \param file_name The name of the file.
* \return the base name
*/
std::string GetFileBasename(const std::string& file_name);
/*!
* \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file.
* \param data The data to be loaded.
......
/*!
* Copyright (c) 2018 by Contributors
* \file common.h
* \brief TVM SGX common API.
*/
#ifndef TVM_RUNTIME_SGX_COMMON_H_
#define TVM_RUNTIME_SGX_COMMON_H_
namespace tvm {
namespace runtime {
namespace sgx {
#define TVM_SGX_CHECKED_CALL(Function) \
sgx_status_t TVM_STR_CONCAT(__sgx_status_, __LINE__) = SGX_ERROR_UNEXPECTED; \
TVM_STR_CONCAT(__sgx_status_, __LINE__) = Function; \
CHECK_EQ(TVM_STR_CONCAT(__sgx_status_, __LINE__), SGX_SUCCESS) \
<< "SGX Error: " << TVM_STR_CONCAT(__sgx_status_, __LINE__);
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_COMMON_H_
/*!
* Copyright (c) 2018 by Contributors
* \file ecall_registry.h
* \brief The global registry of packed functions available via ecall_packed_func.
*/
#ifndef TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
#define TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <string>
#include <algorithm>
#include <vector>
namespace tvm {
namespace runtime {
namespace sgx {
class ECallRegistry: public Registry {
public:
explicit ECallRegistry(std::string name) {
name_ = name;
}
Registry& set_body(PackedFunc f) {
func_ = f;
return *this;
}
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
static Registry& Register(const std::string& name, bool override = false) {
for (auto& r : exports_) {
if (r.name_ == name) {
CHECK(override) << "ecall " << name << " is already registered";
return r;
}
}
TVM_SGX_CHECKED_CALL(
tvm_ocall_register_export(name.c_str(), exports_.size()));
exports_.emplace_back(name);
return exports_.back();
}
static bool Remove(const std::string& name) {
LOG(FATAL) << "Removing enclave exports is not supported.";
}
static const PackedFunc* Get(const std::string& name) {
for (const auto& r : exports_) {
if (r.name_ == name) return &r.func_;
}
return nullptr;
}
static const PackedFunc* Get(unsigned func_id) {
return func_id >= exports_.size() ? nullptr : &exports_[func_id].func_;
}
static std::vector<std::string> ListNames() {
std::vector<std::string> names;
names.resize(exports_.size());
std::transform(exports_.begin(), exports_.end(), names.begin(),
[](ECallRegistry r) { return r.name_; });
return names;
}
static std::vector<ECallRegistry> exports_;
};
std::vector<ECallRegistry> ECallRegistry::exports_;
/*!
* \brief Register a function callable via ecall_packed_func
* \code
* TVM_REGISTER_ENCLAVE_FUNC("DoThing")
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* });
* \endcode
*/
#define TVM_REGISTER_ENCLAVE_FUNC(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::sgx::ECallRegistry::Register(OpName)
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_TRUSTED_ECALL_REGISTRY_H_
/*!
* Copyright (c) 2018 by Contributors
* \file runtime_t.cc
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include "../../c_runtime_api.cc"
#include "../../cpu_device_api.cc"
#include "../../module.cc"
#include "../../module_util.cc"
#include "../../registry.cc"
#include "../../system_lib_module.cc"
#include "../../thread_pool.cc"
#include "../../workspace_pool.cc"
#include "./ecall_registry.h"
#include "./runtime.h"
#include "./threading_backend.cc"
namespace tvm {
namespace runtime {
namespace sgx {
extern "C" {
void tvm_ecall_init(TVMRetValueHandle ret) {}
void tvm_ecall_packed_func(int func_id,
const TVMValue* arg_values,
const int* type_codes,
int num_args,
TVMRetValueHandle ret) {
const PackedFunc* f = ECallRegistry::Get(func_id);
CHECK(f != nullptr) << "ecall function not found.";
TVMRetValue rv;
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
int ret_type_code = rv.type_code();
if (ret_type_code == kNull) return;
TVMValue ret_value;
if (ret_type_code == kBytes || ret_type_code == kStr) {
// allocate a buffer in untrusted, copy the values in
std::string bytes = rv;
void* ret_buf;
TVM_SGX_CHECKED_CALL(tvm_ocall_reserve_space(
&ret_buf, bytes.size() + sizeof(TVMByteArray)));
char* data_buf = static_cast<char*>(ret_buf) + sizeof(TVMByteArray);
memcpy(data_buf, bytes.data(), bytes.size());
TVMByteArray* arr = static_cast<TVMByteArray*>(ret_buf);
arr->data = data_buf;
arr->size = bytes.size();
ret_value = TVMValue{.v_handle = arr};
ret_type_code = kBytes;
} else {
rv.MoveToCHost(&ret_value, &ret_type_code);
}
TVM_SGX_CHECKED_CALL(tvm_ocall_set_return(ret, &ret_value, &ret_type_code, 1));
}
} // extern "C"
TVM_REGISTER_ENCLAVE_FUNC("__tvm_main__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module mod = (*Registry::Get("module._GetSystemLib"))();
mod.GetFunction("default_function").CallPacked(args, rv);
});
} // namespace sgx
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file trusted/runtime.h
* \brief TVM SGX trusted API.
*/
#ifndef TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#define TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#include <sgx_edger8r.h>
#include <tvm/runtime/packed_func.h>
#include <string>
#include "../common.h"
namespace tvm {
namespace runtime {
namespace sgx {
template<typename... Args>
inline TVMRetValue OCallPackedFunc(std::string name, Args&& ...args) {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMValue ret_val;
int ret_type_code;
TVM_SGX_CHECKED_CALL(tvm_ocall_packed_func(name.c_str(),
values,
type_codes,
kNumArgs,
&ret_val,
&ret_type_code));
TVMRetValue* rv = new TVMRetValue();
*rv = TVMArgValue(ret_val, ret_type_code);
return *rv;
}
} // namespace sgx
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
......@@ -8,11 +8,7 @@
#include <sgx_edger8r.h>
#include <sgx_trts.h>
#include <atomic>
extern "C" {
sgx_status_t SGX_CDECL tvm_ocall_thread_group_launch(int num_workers, void* cb);
sgx_status_t SGX_CDECL tvm_ocall_thread_group_join();
}
#include "runtime.h"
#ifndef TVM_SGX_MAX_CONCURRENCY
#define TVM_SGX_MAX_CONCURRENCY 1
......@@ -31,13 +27,12 @@ class ThreadGroup::Impl {
next_task_id_(exclude_worker0) {
CHECK(num_workers <= TVM_SGX_MAX_CONCURRENCY)
<< "Tried spawning more threads than allowed by TVM_SGX_MAX_CONCURRENCY.";
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED;
sgx_status = tvm_ocall_thread_group_launch(num_workers, this);
CHECK(sgx_status == SGX_SUCCESS) << "SGX Error: " << sgx_status;
sgx::OCallPackedFunc("__sgx_thread_group_launch__",
num_workers_, reinterpret_cast<void*>(this));
}
~Impl() {
tvm_ocall_thread_group_join();
sgx::OCallPackedFunc("__sgx_thread_group_join__");
}
void RunTask() {
......@@ -64,12 +59,12 @@ void Yield() {}
int MaxConcurrency() { return TVM_SGX_MAX_CONCURRENCY; }
extern "C" {
void tvm_ecall_run_worker(const void* impl) {
if (!sgx_is_within_enclave(impl, sizeof(ThreadGroup::Impl))) return;
((ThreadGroup::Impl*)impl)->RunTask();
}
}
TVM_REGISTER_ENCLAVE_FUNC("__tvm_run_worker__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
void* tg = args[0];
if (!sgx_is_within_enclave(tg, sizeof(ThreadGroup::Impl))) return;
reinterpret_cast<ThreadGroup::Impl*>(tg)->RunTask();
});
} // namespace threading
} // namespace runtime
......
enclave {
from "sgx_tstdc.edl" import *;
trusted {
public void tvm_ecall_init([isptr, user_check] TVMRetValueHandle ret);
public void tvm_ecall_packed_func(int func_id,
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[isptr, user_check] TVMRetValueHandle ret);
};
untrusted {
void tvm_ocall_packed_func([in, string] const char* name,
[in, count=num_args] const TVMValue* arg_values,
[in, count=num_args] const int* type_codes,
int num_args,
[out] TVMValue* ret_val,
[out] int* ret_type_code);
void tvm_ocall_set_return([isptr, user_check] TVMRetValueHandle ret,
[in, count=num_ret] const TVMValue* value,
[in, count=num_ret] const int* type_code,
int num_ret);
void tvm_ocall_register_export([in, string] const char* name, int func_id);
void* tvm_ocall_reserve_space(size_t num_bytes);
};
};
/*!
* Copyright (c) 2018 by Contributors
* \file sgx_module.cc
* \brief SGX enclave module.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#include <sgx_urts.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include "../common.h"
#include "../../file_util.h"
namespace tvm {
namespace runtime {
class SGXModuleNode;
namespace sgx {
class EnclaveContext {
public:
explicit EnclaveContext(SGXModuleNode* mod) {
CHECK(Context()->mod_ == nullptr)
<< "Tried overriding existing enclave context.";
CHECK(mod != nullptr) << "Tried setting null enclave context.";
Context()->mod_ = mod;
}
~EnclaveContext() {
Context()->mod_ = nullptr;
}
static SGXModuleNode* GetModule() {
SGXModuleNode* ctx = Context()->mod_;
CHECK(ctx != nullptr) << "No current enclave context";
return ctx;
}
private:
EnclaveContext() {}
SGXModuleNode* mod_;
static EnclaveContext* Context() {
static thread_local EnclaveContext inst;
return &inst;
}
};
} // namespace sgx
class SGXModuleNode : public ModuleNode {
public:
~SGXModuleNode() {
if (eid_) {
sgx::EnclaveContext ctx(this);
sgx_destroy_enclave(eid_);
}
}
void Init(const std::string& enclave_file) {
std::string token_file = GetCacheDir() + "/" +
GetFileBasename(enclave_file) + ".token";
sgx_launch_token_t token = {0};
int token_updated = 0;
try {
std::ifstream ifs(token_file, std::fstream::in | std::fstream::binary);
ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
ifs >> token;
} catch (std::ifstream::failure e) {
memset(&token, 0x0, sizeof(sgx_launch_token_t));
}
TVM_SGX_CHECKED_CALL(sgx_create_enclave(
enclave_file.c_str(), SGX_DEBUG_FLAG, &token, &token_updated, &eid_, NULL));
sgx::EnclaveContext ctx(this);
TVMRetValue rv;
TVM_SGX_CHECKED_CALL(tvm_ecall_init(eid_, &rv));
if (!token_updated) return;
try {
std::ofstream ofs(token_file, std::fstream::trunc | std::fstream::binary);
ofs.exceptions(std::ifstream::failbit | std::ifstream::badbit);
ofs << token;
} catch (std::ifstream::failure e) {
LOG(INFO) << "Could not save SGX launch token to " << token_file;
}
}
const char* type_key() const final {
return "sgx";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
auto exported = exports_.find(name);
if (exported == exports_.end()) return PackedFunc();
int func_id = exported->second;
return PackedFunc([this, func_id](TVMArgs args, TVMRetValue* rv) {
sgx::EnclaveContext ctx(this);
TVM_SGX_CHECKED_CALL(tvm_ecall_packed_func(eid_, func_id,
args.values, args.type_codes, args.num_args, rv));
});
}
void RunWorkers(int num_tasks, void* tg) {
std::function<void(int)> runner = [this, tg](int _worker_id) {
this->GetFunction("__tvm_run_worker__",
std::shared_ptr<SGXModuleNode>(nullptr))(tg);
};
thread_group_.reset(new tvm::runtime::threading::ThreadGroup(
num_tasks, runner, false /* include_main_thread */));
}
void JoinThreads() {
thread_group_->Join();
}
void RegisterExport(std::string name, int func_id) {
exports_[name] = func_id;
}
private:
// ID of the loaded enclave
sgx_enclave_id_t eid_;
// Names and IDs of functions exported by the enclave module
std::unordered_map<std::string, int> exports_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> thread_group_;
};
namespace sgx {
TVM_REGISTER_GLOBAL("__sgx_thread_group_launch__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->RunWorkers(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("__sgx_thread_group_join__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnclaveContext::GetModule()->JoinThreads();
});
TVM_REGISTER_GLOBAL("__sgx_set_last_error__")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string err = args[0];
TVMAPISetLastError(err.c_str());
});
extern "C" {
void tvm_ocall_register_export(const char* name, int func_id) {
EnclaveContext::GetModule()->RegisterExport(name, func_id);
}
void tvm_ocall_packed_func(const char* name,
const TVMValue* arg_values,
const int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr) << "ocall to nonexistent function \"" << name << "\"";
TVMRetValue rv;
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
rv.MoveToCHost(ret_val, ret_type_code);
}
// Allocates space for return values. The returned pointer is only valid between
// successive calls to `tvm_ocall_reserve_space`.
void* tvm_ocall_reserve_space(size_t num_bytes) {
static thread_local std::vector<uint64_t> buf;
buf.reserve(num_bytes);
return buf.data();
}
void tvm_ocall_set_return(TVMRetValueHandle ret,
const TVMValue* value,
const int* type_code,
int num_ret) {
CHECK_EQ(num_ret, 1) << "Only one return value is currently supported.";
CHECK(type_code[0] != kStr) << "Return kBytes, not kStr.";
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value[0], type_code[0]);
}
} // extern "C"
} // namespace sgx
TVM_REGISTER_GLOBAL("module.loadfile_sgx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<SGXModuleNode> node = std::make_shared<SGXModuleNode>();
node->Init(args[0]);
*rv = runtime::Module(node);
});
} // namespace runtime
} // namespace tvm
......@@ -55,16 +55,12 @@ class SystemLibModuleNode : public ModuleNode {
module_blob_ = ptr;
} else {
auto it = tbl_.find(name);
if (it != tbl_.end()) {
if (ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
tbl_[name] = ptr;
}
} else {
tbl_[name] = ptr;
if (it != tbl_.end() && ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
}
tbl_[name] = ptr;
}
}
......
......@@ -315,7 +315,11 @@ class ThreadPool {
}
int num_workers_;
// if excluding worker 0 and using master to run task 0
#ifndef _LIBCPP_SGX_CONFIG
bool exclude_worker0_{true};
#else
bool exclude_worker0_{false};
#endif
std::vector<std::unique_ptr<SpscTaskQueue> > queues_;
std::unique_ptr<tvm::runtime::threading::ThreadGroup> threads_;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment