diff --git a/make/contrib/random.mk b/make/contrib/random.mk index aea6770101d47f96e827bb426bb62bc173c73c6d..3f8c03f5e4e8aaf20bd2d06d100e0a8cb739a5a7 100644 --- a/make/contrib/random.mk +++ b/make/contrib/random.mk @@ -1,4 +1,4 @@ -RANDOM_CONTRIB_SRC = $(wildcard src/contrib/random/*.cc) +RANDOM_CONTRIB_SRC = $(wildcard src/contrib/random/random.cc) RANDOM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(RANDOM_CONTRIB_SRC)) ifeq ($(USE_RANDOM), 1) diff --git a/src/contrib/random/mt_random_engine.cc b/src/contrib/random/mt_random_engine.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7f8db4b85962d53443101e12cf16dc2b11a1b6a --- /dev/null +++ b/src/contrib/random/mt_random_engine.cc @@ -0,0 +1,86 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file random/mt_random_engine.cc + * \brief mt19937 random engine + */ +#include <dmlc/logging.h> +#include <algorithm> +#include <ctime> +#include <random> + +namespace tvm { +namespace contrib { + +/*! + * \brief An interface for generating [tensors of] random numbers. + */ +class RandomEngine { + public: + /*! + * \brief Creates a RandomEngine using a default seed. + */ + RandomEngine() { + this->Seed(time(0)); + } + + /*! + * \brief Creates a RandomEngine, suggesting the use of a provided seed. + */ + explicit RandomEngine(unsigned seed) { + this->Seed(seed); + } + + /*! + * \brief Seeds the underlying RNG, if possible. + */ + inline void Seed(unsigned seed) { + rnd_engine_.seed(seed); + this->rseed_ = static_cast<unsigned>(seed); + } + + /*! + * \return the seed associated with the underlying RNG. + */ + inline unsigned GetSeed() const { + return rseed_; + } + + /*! + * \return a random integer sampled from the RNG. + */ + inline unsigned GetRandInt() { + return rnd_engine_(); + } + + /*! + * \brief Fills a tensor with values drawn from Unif(low, high) + */ + void SampleUniform(DLTensor* data, float low, float high) { + CHECK_GT(high, low) << "high must be bigger than low"; + CHECK(data->strides == nullptr); + + DLDataType dtype = data->dtype; + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) { + size *= data->shape[i]; + } + + CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); + + if (data->ctx.device_type == kDLCPU) { + std::uniform_real_distribution<float> uniform_dist(low, high); + std::generate_n(static_cast<float*>(data->data), size, [&] () { + return uniform_dist(rnd_engine_); + }); + } else { + LOG(FATAL) << "Do not support random.randint on this device yet"; + } + } + + private: + std::mt19937 rnd_engine_; + unsigned rseed_; +}; + +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/random/random.cc b/src/contrib/random/random.cc index d0bcb18cb76bcd238c94c62fec380499898bd4ef..de60da906adc0006b8ceb98f9d04431af6a09a2b 100644 --- a/src/contrib/random/random.cc +++ b/src/contrib/random/random.cc @@ -7,8 +7,11 @@ #include <dmlc/logging.h> #include <dmlc/thread_local.h> #include <algorithm> -#include <random> -#include <ctime> +#ifndef _LIBCPP_SGX_CONFIG +#include "./mt_random_engine.cc" +#else +#include "./sgx_random_engine.cc" +#endif #define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \ if (type.code == kDLInt && type.bits == 32) { \ @@ -38,57 +41,6 @@ namespace contrib { using namespace runtime; -class RandomEngine { - public: - RandomEngine() { - this->Seed(time(0)); - } - explicit RandomEngine(int seed) { - this->Seed(seed); - } - - ~RandomEngine() {} - - inline void Seed(int seed) { - rnd_engine_.seed(seed); - this->rseed_ = static_cast<unsigned>(seed); - } - - inline unsigned GetSeed() const { - return rseed_; - } - - inline unsigned GetRandInt() { - return rnd_engine_(); - } - - void SampleUniform(DLTensor* data, float low, float high) { - CHECK_GT(high, low) << "high must be bigger than low"; - CHECK(data->strides == nullptr); - - DLDataType dtype = data->dtype; - int64_t size = 1; - for (int i = 0; i < data->ndim; ++i) { - size *= data->shape[i]; - } - - CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); - - if (data->ctx.device_type == kDLCPU) { - std::uniform_real_distribution<float> uniform_dist(low, high); - std::generate_n(static_cast<float*>(data->data), size, [&] () { - return uniform_dist(rnd_engine_); - }); - } else { - LOG(FATAL) << "Do not support random.randint on this device yet"; - } - } - - private: - std::mt19937 rnd_engine_; - unsigned rseed_; -}; - struct RandomThreadLocalEntry { RandomEngine random_engine; static RandomThreadLocalEntry* ThreadLocal(); diff --git a/src/contrib/random/sgx_random_engine.cc b/src/contrib/random/sgx_random_engine.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bdcf529ba1edcadf78766098ec18252c23fe401 --- /dev/null +++ b/src/contrib/random/sgx_random_engine.cc @@ -0,0 +1,75 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file random/sgx_random_engine.h + * \brief SGX trusted random engine + */ +#include <dmlc/logging.h> +#include <sgx_trts.h> +#include <algorithm> +#include "../../runtime/sgx/common.h" + +namespace tvm { +namespace contrib { + +/*! + * \brief An interface for generating [tensors of] random numbers. + */ +class RandomEngine { + public: + /*! + * \brief Creates a RandomEngine, suggesting the use of a provided seed. + */ + explicit RandomEngine(unsigned seed) { + LOG(WARNING) << "SGX RandomEngine does not support seeding."; + } + + /*! + * \brief Seeds the underlying RNG, if possible. + */ + inline void Seed(unsigned seed) { + LOG(WARNING) << "SGX RandomEngine does not support seeding."; + } + + /*! + * \return the seed associated with the underlying RNG. + */ + inline unsigned GetSeed() const { + LOG(WARNING) << "SGX RandomEngine does not support seeding."; + return 0; + } + + /*! + * \return a random integer sampled from the RNG. + */ + inline unsigned GetRandInt() { + int rand_int; + TVM_SGX_CHECKED_CALL( + sgx_read_rand(reinterpret_cast<unsigned char*>(&rand_int), sizeof(int))); + return rand_int; + } + + /*! + * \brief Fills a tensor with values drawn from Unif(low, high) + */ + void SampleUniform(DLTensor* data, float low, float high) { + CHECK_GT(high, low) << "high must be bigger than low"; + CHECK(data->strides == nullptr); + + DLDataType dtype = data->dtype; + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) { + size *= data->shape[i]; + } + + CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); + + std::generate_n(static_cast<float*>(data->data), size, [&] () { + float max_int = static_cast<float>(std::numeric_limits<unsigned>::max()); + float unif01 = GetRandInt() / max_int; + return low + unif01 * (high - low); + }); + } +}; + +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/sgx/common.h b/src/runtime/sgx/common.h index a36b33d48b1aeaf80e8b43bafe9e60ac9979bdbc..a375bcd21dd22d3fb859ae33f892ddce1c8d7d2f 100644 --- a/src/runtime/sgx/common.h +++ b/src/runtime/sgx/common.h @@ -6,6 +6,8 @@ #ifndef TVM_RUNTIME_SGX_COMMON_H_ #define TVM_RUNTIME_SGX_COMMON_H_ +#include <sgx_error.h> + namespace tvm { namespace runtime { namespace sgx { diff --git a/src/runtime/sgx/trusted/runtime.h b/src/runtime/sgx/trusted/runtime.h index ded76fd52d5455431233d8a948c5a08dfe747a2c..9bd834e0513f098572b0fcad8fd5c72f51bd70d3 100644 --- a/src/runtime/sgx/trusted/runtime.h +++ b/src/runtime/sgx/trusted/runtime.h @@ -6,7 +6,6 @@ #ifndef TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_ #define TVM_RUNTIME_SGX_TRUSTED_RUNTIME_H_ -#include <sgx_edger8r.h> #include <tvm/runtime/packed_func.h> #include <string> #include "../common.h"