Commit d0f40112 by nhynes Committed by Tianqi Chen

Add SGX random engine (#1113)

parent f2244b09
RANDOM_CONTRIB_SRC = $(wildcard src/contrib/random/*.cc) RANDOM_CONTRIB_SRC = $(wildcard src/contrib/random/random.cc)
RANDOM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(RANDOM_CONTRIB_SRC)) RANDOM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(RANDOM_CONTRIB_SRC))
ifeq ($(USE_RANDOM), 1) ifeq ($(USE_RANDOM), 1)
......
/*!
* Copyright (c) 2018 by Contributors
* \file random/mt_random_engine.cc
* \brief mt19937 random engine
*/
#include <dmlc/logging.h>
#include <algorithm>
#include <ctime>
#include <random>
namespace tvm {
namespace contrib {
/*!
* \brief An interface for generating [tensors of] random numbers.
*/
class RandomEngine {
public:
/*!
* \brief Creates a RandomEngine using a default seed.
*/
RandomEngine() {
this->Seed(time(0));
}
/*!
* \brief Creates a RandomEngine, suggesting the use of a provided seed.
*/
explicit RandomEngine(unsigned seed) {
this->Seed(seed);
}
/*!
* \brief Seeds the underlying RNG, if possible.
*/
inline void Seed(unsigned seed) {
rnd_engine_.seed(seed);
this->rseed_ = static_cast<unsigned>(seed);
}
/*!
* \return the seed associated with the underlying RNG.
*/
inline unsigned GetSeed() const {
return rseed_;
}
/*!
* \return a random integer sampled from the RNG.
*/
inline unsigned GetRandInt() {
return rnd_engine_();
}
/*!
* \brief Fills a tensor with values drawn from Unif(low, high)
*/
void SampleUniform(DLTensor* data, float low, float high) {
CHECK_GT(high, low) << "high must be bigger than low";
CHECK(data->strides == nullptr);
DLDataType dtype = data->dtype;
int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) {
size *= data->shape[i];
}
CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);
if (data->ctx.device_type == kDLCPU) {
std::uniform_real_distribution<float> uniform_dist(low, high);
std::generate_n(static_cast<float*>(data->data), size, [&] () {
return uniform_dist(rnd_engine_);
});
} else {
LOG(FATAL) << "Do not support random.randint on this device yet";
}
}
private:
std::mt19937 rnd_engine_;
unsigned rseed_;
};
} // namespace contrib
} // namespace tvm
...@@ -7,8 +7,11 @@ ...@@ -7,8 +7,11 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <algorithm> #include <algorithm>
#include <random> #ifndef _LIBCPP_SGX_CONFIG
#include <ctime> #include "./mt_random_engine.cc"
#else
#include "./sgx_random_engine.cc"
#endif
#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \ #define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \
if (type.code == kDLInt && type.bits == 32) { \ if (type.code == kDLInt && type.bits == 32) { \
...@@ -38,57 +41,6 @@ namespace contrib { ...@@ -38,57 +41,6 @@ namespace contrib {
using namespace runtime; using namespace runtime;
class RandomEngine {
public:
RandomEngine() {
this->Seed(time(0));
}
explicit RandomEngine(int seed) {
this->Seed(seed);
}
~RandomEngine() {}
inline void Seed(int seed) {
rnd_engine_.seed(seed);
this->rseed_ = static_cast<unsigned>(seed);
}
inline unsigned GetSeed() const {
return rseed_;
}
inline unsigned GetRandInt() {
return rnd_engine_();
}
void SampleUniform(DLTensor* data, float low, float high) {
CHECK_GT(high, low) << "high must be bigger than low";
CHECK(data->strides == nullptr);
DLDataType dtype = data->dtype;
int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) {
size *= data->shape[i];
}
CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);
if (data->ctx.device_type == kDLCPU) {
std::uniform_real_distribution<float> uniform_dist(low, high);
std::generate_n(static_cast<float*>(data->data), size, [&] () {
return uniform_dist(rnd_engine_);
});
} else {
LOG(FATAL) << "Do not support random.randint on this device yet";
}
}
private:
std::mt19937 rnd_engine_;
unsigned rseed_;
};
struct RandomThreadLocalEntry { struct RandomThreadLocalEntry {
RandomEngine random_engine; RandomEngine random_engine;
static RandomThreadLocalEntry* ThreadLocal(); static RandomThreadLocalEntry* ThreadLocal();
......
/*!
* Copyright (c) 2018 by Contributors
* \file random/sgx_random_engine.h
* \brief SGX trusted random engine
*/
#include <dmlc/logging.h>
#include <sgx_trts.h>
#include <algorithm>
#include "../../runtime/sgx/common.h"
namespace tvm {
namespace contrib {
/*!
* \brief An interface for generating [tensors of] random numbers.
*/
class RandomEngine {
public:
/*!
* \brief Creates a RandomEngine, suggesting the use of a provided seed.
*/
explicit RandomEngine(unsigned seed) {
LOG(WARNING) << "SGX RandomEngine does not support seeding.";
}
/*!
* \brief Seeds the underlying RNG, if possible.
*/
inline void Seed(unsigned seed) {
LOG(WARNING) << "SGX RandomEngine does not support seeding.";
}
/*!
* \return the seed associated with the underlying RNG.
*/
inline unsigned GetSeed() const {
LOG(WARNING) << "SGX RandomEngine does not support seeding.";
return 0;
}
/*!
* \return a random integer sampled from the RNG.
*/
inline unsigned GetRandInt() {
int rand_int;
TVM_SGX_CHECKED_CALL(
sgx_read_rand(reinterpret_cast<unsigned char*>(&rand_int), sizeof(int)));
return rand_int;
}
/*!
* \brief Fills a tensor with values drawn from Unif(low, high)
*/
void SampleUniform(DLTensor* data, float low, float high) {
CHECK_GT(high, low) << "high must be bigger than low";
CHECK(data->strides == nullptr);
DLDataType dtype = data->dtype;
int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) {
size *= data->shape[i];
}
CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);
std::generate_n(static_cast<float*>(data->data), size, [&] () {
float max_int = static_cast<float>(std::numeric_limits<unsigned>::max());
float unif01 = GetRandInt() / max_int;
return low + unif01 * (high - low);
});
}
};
} // namespace contrib
} // namespace tvm
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#ifndef TVM_RUNTIME_SGX_COMMON_H_ #ifndef TVM_RUNTIME_SGX_COMMON_H_
#define TVM_RUNTIME_SGX_COMMON_H_ #define TVM_RUNTIME_SGX_COMMON_H_
#include <sgx_error.h>
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
namespace sgx { namespace sgx {
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#ifndef TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_ #ifndef TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#define TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_ #define TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_
#include <sgx_edger8r.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <string> #include <string>
#include "../common.h" #include "../common.h"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment