Commit 5cf08d6c by Tianqi Chen

[REFACTOR] copy DMLC headers, move operator to example (#20)

parent 0538a9fc
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC
-Iinclude -fPIC
# specify tensor path
.PHONY: clean all test lint doc cython cython3 cyclean
all: lib/libnnvm.so lib/libnnvm.a cli_test
all: lib/libnnvm.a lib/libnnvm_example.so
SRC = $(wildcard src/*.cc src/*/*.cc example/*.cc)
SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(filter-out build/test_main.o, $(ALL_OBJ))
ALL_DEP = $(ALL_OBJ)
include tests/cpp/unittest.mk
......@@ -20,16 +20,14 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@
lib/libnnvm.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libnnvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
cli_test: $(ALL_DEP) build/test_main.o
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libnnvm_example.so: example/src/operator.cc lib/libnnvm.a
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cc, $^) $(LDFLAGS) -Wl,--whole-archive lib/libnnvm.a -Wl,--no-whole-archive
cython:
cd python; python setup.py build_ext --inplace
......
......@@ -96,7 +96,7 @@ NNVM_REGISTER_OP(__add_symbol__)
.set_num_inputs(2);
NNVM_REGISTER_OP(exp)
.describe("take exponmential")
.describe("take exponential")
.set_num_inputs(1)
.attr("inplace_pair", std::make_pair(0, 0))
.attr<FInferShape>("FInferShape", SameShape);
......
This folder is synced from dmlc-core/include/dmlc
Contains useful utility headers for the project.
/*!
* Copyright (c) 2016 by Contributors
* \file any.h
* \brief Container to hold any data type.
*/
#ifndef DMLC_ANY_H_
#define DMLC_ANY_H_
// This code need c++11 to compile
#include <typeinfo>
#include <type_traits>
#include <utility>
#include <algorithm>
#include "./base.h"
#include "./logging.h"
namespace dmlc {
// forward declare any;
class any;
/*!
* Get a reference to content stored in the any as type T.
* This will cause an error if
* T does not match the type stored.
* This function is not part of std::any standard.
*
* \param src The source source any container.
* \return The reference of content
* \tparam T The type of the value to be fetched.
*/
template<typename T>
inline T& get(any& src); // NOLINT(*)
/*!
* Get the const reference content stored in the any as type T.
* This will cause an error if
* T does not match the type stored.
* This function is not part of std::any standard.
*
* \param src The source source any container.
* \return The reference of content
* \tparam T The type of the value to be fetched.
*/
template<typename T>
inline const T& get(const any& src);
/*!
* \brief An any class that is compatible to std::any in c++17.
*
* \code
* dmlc::any a = std::string("mydear"), b = 1;
* // get reference out and add it
* dmlc::get<int>(b) += 1;
* // a is now string
* LOG(INFO) << dmlc::get<std::string>(a);
* // a is now 2, the string stored will be properly destructed
* a = std::move(b);
* LOG(INFO) << dmlc::get<int>(a);
* \endcode
* \sa get
*/
class any {
public:
/*! \brief default constructor */
inline any() = default;
/*!
* \brief move constructor from another any
* \param other The other any to be moved
*/
inline any(any&& other); // NOLINT(*)
/*!
* \brief copy constructor
* \param other The other any to be copied
*/
inline any(const any& other); // NOLINT(*)
/*!
* \brief constructor from any types
* \param other The other types to be constructed into any.
* \tparam T The value type of other.
*/
template<typename T>
inline any(T&& other); // NOLINT(*)
/*! \brief destructor */
inline ~any();
/*!
* \brief assign operator from other
* \param other The other any to be copy or moved.
* \return self
*/
inline any& operator=(any&& other);
/*!
* \brief assign operator from other
* \param other The other any to be copy or moved.
* \return self
*/
inline any& operator=(const any& other);
/*!
* \brief assign operator from any type.
* \param other The other any to be copy or moved.
* \tparam T The value type of other.
* \return self
*/
template<typename T>
inline any& operator=(T&& other);
/*!
* \return whether the container is empty.
*/
inline bool empty() const;
/*!
* \return clear the content of container
*/
inline void clear();
/*!
* swap current content with other
* \param other The other data to be swapped.
*/
inline void swap(any& other); // NOLINT(*)
/*!
* \return The type_info about the stored type.
*/
inline const std::type_info& type() const;
private:
//! \cond Doxygen_Suppress
// declare of helper class
template<typename T>
class TypeOnHeap;
template<typename T>
class TypeOnStack;
template<typename T>
class TypeInfo;
// size of stack space, it takes 32 bytes for one any type.
static const size_t kStack = sizeof(void*) * 3;
static const size_t kAlign = sizeof(void*);
// container use dynamic storage only when space runs lager
union Data {
// stack space
std::aligned_storage<kStack, kAlign>::type stack;
// pointer to heap space
void* pheap;
};
// type specific information
struct Type {
// destructor function
void (*destroy)(Data* data);
// copy constructor
void (*create_from_data)(Data* dst, const Data& src);
// the type info function
const std::type_info* ptype_info;
};
// constant to check if data can be stored on heap.
template<typename T>
struct data_on_stack {
static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack;
};
// declare friend with
template<typename T>
friend T& get(any& src); // NOLINT(*)
template<typename T>
friend const T& get(const any& src);
// internal construct function
inline void construct(any&& other);
// internal construct function
inline void construct(const any& other);
// internal function to check if type is correct.
template<typename T>
inline void check_type() const;
// internal type specific information
const Type* type_{nullptr};
// internal data
Data data_;
};
template<typename T>
inline any::any(T&& other) {
typedef typename std::decay<T>::type DT;
if (std::is_same<DT, any>::value) {
this->construct(std::forward<T>(other));
} else {
static_assert(std::is_copy_constructible<DT>::value,
"Any can only hold value that is copy constructable");
type_ = TypeInfo<DT>::get_type();
if (data_on_stack<DT>::value) {
new (&(data_.stack)) DT(std::forward<T>(other));
} else {
data_.pheap = new DT(std::forward<T>(other));
}
}
}
inline any::any(any&& other) {
this->construct(std::move(other));
}
inline any::any(const any& other) {
this->construct(other);
}
inline void any::construct(any&& other) {
type_ = other.type_;
data_ = other.data_;
other.type_ = nullptr;
}
inline void any::construct(const any& other) {
type_ = other.type_;
if (type_ != nullptr) {
type_->create_from_data(&data_, other.data_);
}
}
inline any::~any() {
this->clear();
}
inline any& any::operator=(any&& other) {
any(std::move(other)).swap(*this);
return *this;
}
inline any& any::operator=(const any& other) {
any(other).swap(*this);
return *this;
}
template<typename T>
inline any& any::operator=(T&& other) {
any(std::forward<T>(other)).swap(*this);
return *this;
}
inline void any::swap(any& other) { // NOLINT(*)
std::swap(type_, other.type_);
std::swap(data_, other.data_);
}
inline void any::clear() {
if (type_ != nullptr) {
if (type_->destroy != nullptr) {
type_->destroy(&data_);
}
type_ = nullptr;
}
}
inline bool any::empty() const {
return type_ == nullptr;
}
inline const std::type_info& any::type() const {
if (type_ != nullptr) {
return *(type_->ptype_info);
} else {
return typeid(void);
}
}
template<typename T>
inline void any::check_type() const {
CHECK(type_ != nullptr)
<< "The any container is empty";
CHECK(type_->ptype_info == &typeid(T))
<< "The stored type mismatch"
<< " stored=" << type_->ptype_info->name()
<< " requested=" << typeid(T).name();
}
template<typename T>
inline const T& get(const any& src) {
src.check_type<T>();
return *any::TypeInfo<T>::get_ptr(&(src.data_));
}
template<typename T>
inline T& get(any& src) { // NOLINT(*)
src.check_type<T>();
return *any::TypeInfo<T>::get_ptr(&(src.data_));
}
template<typename T>
class any::TypeOnHeap {
public:
inline static T* get_ptr(any::Data* data) {
return static_cast<T*>(data->pheap);
}
inline static const T* get_ptr(const any::Data* data) {
return static_cast<const T*>(data->pheap);
}
inline static void create_from_data(any::Data* dst, const any::Data& data) {
dst->pheap = new T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
delete static_cast<T*>(data->pheap);
}
};
template<typename T>
class any::TypeOnStack {
public:
inline static T* get_ptr(any::Data* data) {
return reinterpret_cast<T*>(&(data->stack));
}
inline static const T* get_ptr(const any::Data* data) {
return reinterpret_cast<const T*>(&(data->stack));
}
inline static void create_from_data(any::Data* dst, const any::Data& data) {
new (&(dst->stack)) T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
T* dptr = reinterpret_cast<T*>(&(data->stack));
dptr->~T();
}
};
template<typename T>
class any::TypeInfo
: public std::conditional<any::data_on_stack<T>::value,
any::TypeOnStack<T>,
any::TypeOnHeap<T> >::type {
public:
inline static const Type* get_type() {
static TypeInfo<T> tp;
return &(tp.type_);
}
private:
// local type
Type type_;
// constructor
TypeInfo() {
if (std::is_pod<T>::value) {
type_.destroy = nullptr;
} else {
type_.destroy = TypeInfo<T>::destroy;
}
type_.create_from_data = TypeInfo<T>::create_from_data;
type_.ptype_info = &typeid(T);
}
};
//! \endcond
} // namespace dmlc
#endif // DMLC_ANY_H_
/*!
* Copyright (c) 2016 by Contributors
* \file array_view.h
* \brief Read only data structure to reference array
*/
#ifndef DMLC_ARRAY_VIEW_H_
#define DMLC_ARRAY_VIEW_H_
#include <vector>
#include <array>
namespace dmlc {
/*!
* \brief Read only data structure to reference continuous memory region of array.
* Provide unified view for vector, array and C style array.
* This data structure do not guarantee aliveness of referenced array.
*
* Make sure do not use array_view to record data in async function closures.
* Also do not use array_view to create reference to temporary data structure.
*
* \tparam ValueType The value
*
* \code
* std::vector<int> myvec{1,2,3};
* dmlc::array_view<int> view(myvec);
* // indexed visit to the view.
* LOG(INFO) << view[0];
*
* for (int v : view) {
* // visit each element in the view
* }
* \endcode
*/
template<typename ValueType>
class array_view {
public:
/*! \brief default constructor */
array_view() = default;
/*!
* \brief default copy constructor
* \param other another array view.
*/
array_view(const array_view<ValueType> &other) = default; // NOLINT(*)
/*!
* \brief default move constructor
* \param other another array view.
*/
array_view(array_view<ValueType>&& other) = default; // NOLINT(*)
/*!
* \brief default assign constructor
* \param other another array view.
* \return self.
*/
array_view<ValueType>& operator=(const array_view<ValueType>& other) = default; // NOLINT(*)
/*!
* \brief construct array view std::vector
* \param other vector container
*/
array_view(const std::vector<ValueType>& other) { // NOLINT(*)
if (other.size() != 0) {
begin_ = &other[0]; size_ = other.size();
}
}
/*!
* \brief construct array std::array
* \param other another array view.
*/
template<std::size_t size>
array_view(const std::array<ValueType, size>& other) { // NOLINT(*)
if (size != 0) {
begin_ = &other[0]; size_ = size;
}
}
/*!
* \brief construct array view from continuous segment
* \param begin beginning pointre
* \param end end pointer
*/
array_view(const ValueType* begin, const ValueType* end) {
if (begin < end) {
begin_ = begin;
size_ = end - begin;
}
}
/*! \return size of the array */
inline size_t size() const {
return size_;
}
/*! \return begin of the array */
inline const ValueType* begin() const {
return begin_;
}
/*! \return end point of the array */
inline const ValueType* end() const {
return begin_ + size_;
}
/*!
* \brief get i-th element from the view
* \param i The index.
* \return const reference to i-th element.
*/
inline const ValueType& operator[](size_t i) const {
return begin_[i];
}
private:
/*! \brief the begin of the view */
const ValueType* begin_{nullptr};
/*! \brief The size of the view */
size_t size_{0};
};
} // namespace dmlc
#endif // DMLC_ARRAY_VIEW_H_
/*!
* Copyright (c) 2015 by Contributors
* \file base.h
* \brief defines configuration macros
*/
#ifndef DMLC_BASE_H_
#define DMLC_BASE_H_
/*! \brief whether use glog for logging */
#ifndef DMLC_USE_GLOG
#define DMLC_USE_GLOG 0
#endif
/*!
* \brief whether throw dmlc::Error instead of
* directly calling abort when FATAL error occured
* NOTE: this may still not be perfect.
* do not use FATAL and CHECK in destructors
*/
#ifndef DMLC_LOG_FATAL_THROW
#define DMLC_LOG_FATAL_THROW 1
#endif
/*!
* \brief whether always log a message before throw
* This can help identify the error that cannot be catched.
*/
#ifndef DMLC_LOG_BEFORE_THROW
#define DMLC_LOG_BEFORE_THROW 1
#endif
/*!
* \brief Whether to use customized logger,
* whose output can be decided by other libraries.
*/
#ifndef DMLC_LOG_CUSTOMIZE
#define DMLC_LOG_CUSTOMIZE 0
#endif
/*! \brief whether compile with hdfs support */
#ifndef DMLC_USE_HDFS
#define DMLC_USE_HDFS 0
#endif
/*! \brief whether compile with s3 support */
#ifndef DMLC_USE_S3
#define DMLC_USE_S3 0
#endif
/*! \brief whether or not use parameter server */
#ifndef DMLC_USE_PS
#define DMLC_USE_PS 0
#endif
/*! \brief whether or not use c++11 support */
#ifndef DMLC_USE_CXX11
#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
#pragma message("Will need g++-4.6 or higher to compile all" \
"the features in dmlc-core, " \
"compile without c++0x, some features may be disabled")
#undef DMLC_USE_CXX11
#define DMLC_USE_CXX11 0
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
*/
#ifndef DMLC_ENABLE_STD_THREAD
#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11
#endif
/*! \brief whether enable regex support, actually need g++-4.9 or higher*/
#ifndef DMLC_USE_REGEX
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
/*!
* \brief Disable copy constructor and assignment operator.
*
* If C++11 is supported, both copy and move constructors and
* assignment operators are deleted explicitly. Otherwise, they are
* only declared but not implemented. Place this macro in private
* section if C++11 is not available.
*/
#ifndef DISALLOW_COPY_AND_ASSIGN
# if DMLC_USE_CXX11
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&) = delete; \
T(T&&) = delete; \
T& operator=(T const&) = delete; \
T& operator=(T&&) = delete
# else
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&); \
T& operator=(T const&)
# endif
#endif
///
/// code block to handle optionally loading
///
#if !defined(__GNUC__)
#define fopen64 std::fopen
#endif
#if (defined __MINGW32__) && !(defined __MINGW64__)
#define fopen64 std::fopen
#endif
#ifdef _MSC_VER
#if _MSC_VER < 1900
// NOTE: sprintf_s is not equivalent to snprintf,
// they are equivalent when success, which is sufficient for our case
#define snprintf sprintf_s
#define vsnprintf vsprintf_s
#endif
#else
#ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif
#endif
#ifdef __APPLE__
#define off64_t off_t
#define fopen64 std::fopen
#endif
extern "C" {
#include <sys/types.h>
}
#endif
#ifdef _MSC_VER
//! \cond Doxygen_Suppress
typedef signed char int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
//! \endcond
#else
#include <inttypes.h>
#endif
#include <string>
#include <vector>
#if defined(_MSC_VER) && _MSC_VER < 1900
#define noexcept_true throw ()
#define noexcept_false
#define noexcept(a) noexcept_##a
#endif
#if DMLC_USE_CXX11
#define DMLC_THROW_EXCEPTION noexcept(false)
#define DMLC_NO_EXCEPTION noexcept(true)
#else
#define DMLC_THROW_EXCEPTION
#define DMLC_NO_EXCEPTION
#endif
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
* \brief safely get the beginning address of a vector
* \param vec input vector
* \return beginning address of a vector
*/
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*!
* \brief get the beginning address of a vector
* \param vec input vector
* \return beginning address of a vector
*/
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*!
* \brief get the beginning address of a vector
* \param str input string
* \return beginning address of a string
*/
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return NULL;
return &str[0];
}
/*!
* \brief get the beginning address of a vector
* \param str input string
* \return beginning address of a string
*/
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL;
return &str[0];
}
} // namespace dmlc
#if defined(_MSC_VER) && _MSC_VER < 1900
#define constexpr const
#define alignof __alignof
#endif
#endif // DMLC_BASE_H_
/*!
* Copyright (c) 2015 by Contributors
* \file json.h
* \brief Lightweight JSON Reader/Writer that read save into C++ data structs.
* This includes STL composites and structures.
*/
#ifndef DMLC_JSON_H_
#define DMLC_JSON_H_
// This code requires C++11 to compile
#include <vector>
#include <iostream>
#include <cctype>
#include <string>
#include <algorithm>
#include <map>
#include <list>
#include <utility>
#include "./base.h"
#include "./logging.h"
#include "./type_traits.h"
#if DMLC_USE_CXX11
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "./any.h"
#endif // DMLC_USE_CXX11
namespace dmlc {
/*!
* \brief Lightweight JSON Reader to read any STL compositions and structs.
* The user need to know the schema of the
*
*/
class JSONReader {
public:
/*!
* \brief Constructor.
* \param is the input stream.
*/
explicit JSONReader(std::istream *is)
: is_(is),
line_count_r_(0),
line_count_n_(0) {}
/*!
* \brief Parse next JSON string.
* \param out_str the output string.
* \throw dmlc::Error when next token is not string
*/
inline void ReadString(std::string *out_str);
/*!
* \brief Read Number.
* \param out_value output value;
* \throw dmlc::Error when next token is not number of ValueType.
* \tparam ValueType type of the number
*/
template<typename ValueType>
inline void ReadNumber(ValueType *out_value);
/*!
* \brief Begin parsing an object.
* \code
* std::string key;
* // value can be any type that is json serializable.
* std::string value;
* reader->BeginObject();
* while (reader->NextObjectItem(&key)) {
* // do somthing to key value
* reader->Read(&value);
* }
* \endcode
*/
inline void BeginObject();
/*!
* \brief Begin parsing an array.
* \code
* // value can be any type that is json serializable.
* std::string value;
* reader->BeginArray();
* while (reader->NextObjectArrayItem(&value)) {
* // do somthing to value
* }
* \endcode
*/
inline void BeginArray();
/*!
* \brief Try to move to next object item.
* If this call is successful, user can proceed to call
* reader->Read to read in the value.
* \param out_key the key to the next object.
* \return true if the read is successful, false if we are at end of the object.
*/
inline bool NextObjectItem(std::string *out_key);
/*!
* \brief Try to read the next element in the array.
* If this call is successful, user can proceed to call
* reader->Read to read in the value.
* \return true if the read is successful, false if we are at end of the array.
*/
inline bool NextArrayItem();
/*!
* \brief Read next ValueType.
* \param out_value any STL or json readable type to be read
* \throw dmlc::Error when the read of ValueType is not successful.
* \tparam ValueType the data type to be read.
*/
template<typename ValueType>
inline void Read(ValueType *out_value);
/*! \return current line count */
inline std::string line_info() const {
char temp[64];
std::ostringstream os;
os << " Line " << std::max(line_count_r_, line_count_n_);
is_->getline(temp, 64);
os << ", around ^`" << temp << "`";
return os.str();
}
private:
/*! \brief internal reader stream */
std::istream *is_;
/*! \brief "\\r" counter */
size_t line_count_r_;
/*! \brief "\\n" counter */
size_t line_count_n_;
/*!
* \brief record how many element processed in
* current array/object scope.
*/
std::vector<size_t> scope_counter_;
/*!
* \brief Read next nonspace character.
* \return the next nonspace character.
*/
inline int NextNonSpace();
/*!
* \brief Read just before next nonspace but not read that.
* \return the next nonspace character.
*/
inline int PeekNextNonSpace();
};
/*!
* \brief Lightweight json to write any STL compositions.
*/
class JSONWriter {
public:
/*!
* \brief Constructor.
* \param os the output stream.
*/
explicit JSONWriter(std::ostream *os)
: os_(os) {}
/*!
* \brief Write a string that do not contain escape characters.
* \param s the string to be written.
*/
inline void WriteNoEscape(const std::string &s);
/*!
* \brief Write a string that can contain escape characters.
* \param s the string to be written.
*/
inline void WriteString(const std::string &s);
/*!
* \brief Write a string that can contain escape characters.
* \param v the value to be written.
* \tparam ValueType The value type to be written.
*/
template<typename ValueType>
inline void WriteNumber(const ValueType &v);
/*!
* \brief Start beginning of array.
* \param multi_line whether to start an multi_line array.
* \code
* writer->BeginArray();
* for (auto& v : vdata) {
* writer->WriteArrayItem(v);
* }
* writer->EndArray();
* \endcode
*/
inline void BeginArray(bool multi_line = true);
/*! \brief Finish writing an array. */
inline void EndArray();
/*!
* \brief Start beginning of array.
* \param multi_line whether to start an multi_line array.
* \code
* writer->BeginObject();
* for (auto& kv : vmap) {
* writer->WriteObjectKeyValue(kv.first, kv.second);
* }
* writer->EndObject();
* \endcode
*/
inline void BeginObject(bool multi_line = true);
/*! \brief Finish writing object. */
inline void EndObject();
/*!
* \brief Write key value pair in the object.
* \param key the key of the object.
* \param value the value of to be written.
* \tparam ValueType The value type to be written.
*/
template<typename ValueType>
inline void WriteObjectKeyValue(const std::string &key,
const ValueType &value);
/*!
* \brief Write seperator of array, before writing next element.
* User can proceed to call writer->Write to write next item
*/
inline void WriteArraySeperator();
/*!
* \brief Write value into array.
* \param value The value of to be written.
* \tparam ValueType The value type to be written.
*/
template<typename ValueType>
inline void WriteArrayItem(const ValueType &value);
/*!
* \brief Write value to json.
* \param value any STL or json readable that can be written.
* \tparam ValueType the data type to be write.
*/
template<typename ValueType>
inline void Write(const ValueType &value);
private:
/*! \brief Output stream */
std::ostream *os_;
/*!
* \brief record how many element processed in
* current array/object scope.
*/
std::vector<size_t> scope_counter_;
/*! \brief Record whether current is a multiline scope */
std::vector<bool> scope_multi_line_;
/*!
* \brief Write seperating space and newlines
*/
inline void WriteSeperator();
};
/*!
* \brief Helper class to read JSON into a class or struct object.
* \code
* struct Param {
* std::string name;
* int value;
* // define load function from JSON
* inline void Load(dmlc::JSONReader *reader) {
* dmlc::JSONStructReadHelper helper;
* helper.DeclareField("name", &name);
* helper.DeclareField("value", &value);
* helper.ReadAllFields(reader);
* }
* };
* \endcode
*/
class JSONObjectReadHelper {
public:
/*!
* \brief Declare field of type T
* \param key the key of the of field.
* \param addr address of the data type.
* \tparam T the data type to be read, must be STL composition of JSON serializable.
*/
template<typename T>
inline void DeclareField(const std::string &key, T *addr) {
DeclareFieldInternal(key, addr, false);
}
/*!
* \brief Declare optional field of type T
* \param key the key of the of field.
* \param addr address of the data type.
* \tparam T the data type to be read, must be STL composition of JSON serializable.
*/
template<typename T>
inline void DeclareOptionalField(const std::string &key, T *addr) {
DeclareFieldInternal(key, addr, true);
}
/*!
* \brief Read in all the declared fields.
* \param reader the JSONReader to read the json.
*/
inline void ReadAllFields(JSONReader *reader);
private:
/*!
* \brief Internal function to declare field.
* \param key the key of the of field.
* \param addr address of the data type.
* \param optional if set to true, no error will be reported if the key is not presented.
* \tparam T the data type to be read, must be STL composition of JSON serializable.
*/
template<typename T>
inline void DeclareFieldInternal(const std::string &key, T *addr, bool optional);
/*!
* \brief The internal reader function.
* \param reader The reader to read.
* \param addr The memory address to read.
*/
template<typename T>
inline static void ReaderFunction(JSONReader *reader, void *addr);
/*! \brief callback type to reader function */
typedef void (*ReadFunction)(JSONReader *reader, void *addr);
/*! \brief internal data entry */
struct Entry {
/*! \brief the reader function */
ReadFunction func;
/*! \brief the address to read */
void *addr;
/*! \brief whether it is optional */
bool optional;
};
/*! \brief the internal map of reader callbacks */
std::map<std::string, Entry> map_;
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
/*!
* \def DMLC_JSON_ENABLE_ANY
* \brief Macro to enable save/load JSON of dmlc:: whose actual type is Type.
* Any type will be saved as json array [KeyName, content]
*
* \param Type The type to be registered.
* \param KeyName The Type key assigned to the type, must be same during load.
*/
#define DMLC_JSON_ENABLE_ANY(Type, KeyName) \
DMLC_STR_CONCAT(DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName), __COUNTER__) = \
::dmlc::json::AnyJSONManager::Global()->EnableType<Type>(#KeyName) \
//! \cond Doxygen_Suppress
namespace json {
/*!
* \brief generic serialization handler
* \tparam T the type to be serialized
*/
template<typename T>
struct Handler;
template<typename ValueType>
struct NumericHandler {
inline static void Write(JSONWriter *writer, const ValueType &value) {
writer->WriteNumber<ValueType>(value);
}
inline static void Read(JSONReader *reader, ValueType *value) {
reader->ReadNumber<ValueType>(value);
}
};
template<typename ContainerType>
struct ArrayHandler {
inline static void Write(JSONWriter *writer, const ContainerType &array) {
typedef typename ContainerType::value_type ElemType;
writer->BeginArray(array.size() > 10 || !dmlc::is_pod<ElemType>::value);
for (typename ContainerType::const_iterator it = array.begin();
it != array.end(); ++it) {
writer->WriteArrayItem(*it);
}
writer->EndArray();
}
inline static void Read(JSONReader *reader, ContainerType *array) {
typedef typename ContainerType::value_type ElemType;
array->clear();
reader->BeginArray();
while (reader->NextArrayItem()) {
ElemType value;
Handler<ElemType>::Read(reader, &value);
array->insert(array->end(), value);
}
}
};
template<typename ContainerType>
struct MapHandler{
inline static void Write(JSONWriter *writer, const ContainerType &map) {
writer->BeginObject(map.size() > 1);
for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) {
writer->WriteObjectKeyValue(it->first, it->second);
}
writer->EndObject();
}
inline static void Read(JSONReader *reader, ContainerType *map) {
typedef typename ContainerType::mapped_type ElemType;
map->clear();
reader->BeginObject();
std::string key;
while (reader->NextObjectItem(&key)) {
ElemType value;
reader->Read(&value);
(*map)[key] = value;
}
}
};
template<typename T>
struct CommonJSONSerializer {
inline static void Write(JSONWriter *writer, const T &value) {
value.Save(writer);
}
inline static void Read(JSONReader *reader, T *value) {
value->Load(reader);
}
};
template<>
struct Handler<std::string> {
inline static void Write(JSONWriter *writer, const std::string &value) {
writer->WriteString(value);
}
inline static void Read(JSONReader *reader, std::string *str) {
reader->ReadString(str);
}
};
template<typename T>
struct Handler<std::vector<T> > : public ArrayHandler<std::vector<T> > {
};
template<typename K, typename V>
struct Handler<std::pair<K, V> > {
inline static void Write(JSONWriter *writer, const std::pair<K, V> &kv) {
writer->BeginArray();
writer->WriteArrayItem(kv.first);
writer->WriteArrayItem(kv.second);
writer->EndArray();
}
inline static void Read(JSONReader *reader, std::pair<K, V> *kv) {
reader->BeginArray();
CHECK(reader->NextArrayItem())
<< "Expect array of length 2";
Handler<K>::Read(reader, &(kv->first));
CHECK(reader->NextArrayItem())
<< "Expect array of length 2";
Handler<V>::Read(reader, &(kv->second));
CHECK(!reader->NextArrayItem())
<< "Expect array of length 2";
}
};
template<typename T>
struct Handler<std::list<T> > : public ArrayHandler<std::list<T> > {
};
template<typename V>
struct Handler<std::map<std::string, V> > : public MapHandler<std::map<std::string, V> > {
};
#if DMLC_USE_CXX11
template<typename V>
struct Handler<std::unordered_map<std::string, V> >
: public MapHandler<std::unordered_map<std::string, V> > {
};
#endif // DMLC_USE_CXX11
template<typename T>
struct Handler {
inline static void Write(JSONWriter *writer, const T &data) {
typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value,
NumericHandler<T>,
CommonJSONSerializer<T> >::Type THandler;
THandler::Write(writer, data);
}
inline static void Read(JSONReader *reader, T *data) {
typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value,
NumericHandler<T>,
CommonJSONSerializer<T> >::Type THandler;
THandler::Read(reader, data);
}
};
#if DMLC_USE_CXX11
// Manager to store json serialization strategy.
class AnyJSONManager {
public:
template<typename T>
inline AnyJSONManager& EnableType(const std::string& type_name) { // NOLINT(*)
std::type_index tp = std::type_index(typeid(T));
if (type_name_.count(tp) != 0) {
CHECK(type_name_.at(tp) == type_name)
<< "Type has already been registered as another typename " << type_name_.at(tp);
return *this;
}
CHECK(type_map_.count(type_name) == 0)
<< "Type name " << type_name << " already registered in registry";
Entry e;
e.read = ReadAny<T>;
e.write = WriteAny<T>;
type_name_[tp] = type_name;
type_map_[type_name] = e;
return *this;
}
// return global singleton
inline static AnyJSONManager* Global() {
static AnyJSONManager inst;
return &inst;
}
private:
AnyJSONManager() {}
template<typename T>
inline static void WriteAny(JSONWriter *writer, const any &data) {
writer->Write(dmlc::get<T>(data));
}
template<typename T>
inline static void ReadAny(JSONReader *reader, any* data) {
T temp;
reader->Read(&temp);
*data = std::move(temp);
}
// data entry to store vtable for any type
struct Entry {
void (*read)(JSONReader* reader, any *data);
void (*write)(JSONWriter* reader, const any& data);
};
template<typename T>
friend struct Handler;
std::unordered_map<std::type_index, std::string> type_name_;
std::unordered_map<std::string, Entry> type_map_;
};
template<>
struct Handler<any> {
inline static void Write(JSONWriter *writer, const any &data) {
std::unordered_map<std::type_index, std::string>&
nmap = AnyJSONManager::Global()->type_name_;
std::type_index id = std::type_index(data.type());
auto it = nmap.find(id);
CHECK(it != nmap.end() && it->first == id)
<< "Type " << id.name() << " has not been registered via DMLC_JSON_ENABLE_ANY";
std::string type_name = it->second;
AnyJSONManager::Entry e = AnyJSONManager::Global()->type_map_.at(type_name);
writer->BeginArray(false);
writer->WriteArrayItem(type_name);
writer->WriteArraySeperator();
e.write(writer, data);
writer->EndArray();
}
inline static void Read(JSONReader *reader, any *data) {
std::string type_name;
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid any json format";
Handler<std::string>::Read(reader, &type_name);
std::unordered_map<std::string, AnyJSONManager::Entry>&
tmap = AnyJSONManager::Global()->type_map_;
auto it = tmap.find(type_name);
CHECK(it != tmap.end() && it->first == type_name)
<< "Typename " << type_name << " has not been registered via DMLC_JSON_ENABLE_ANY";
AnyJSONManager::Entry e = it->second;
CHECK(reader->NextArrayItem()) << "invalid any json format";
e.read(reader, data);
CHECK(!reader->NextArrayItem()) << "invalid any json format";
}
};
#endif // DMLC_USE_CXX11
} // namespace json
// implementations of JSONReader/Writer
inline int JSONReader::NextNonSpace() {
int ch;
do {
ch = is_->get();
if (ch == '\n') ++line_count_n_;
if (ch == '\r') ++line_count_r_;
} while (isspace(ch));
return ch;
}
inline int JSONReader::PeekNextNonSpace() {
int ch;
while (true) {
ch = is_->peek();
if (ch == '\n') ++line_count_n_;
if (ch == '\r') ++line_count_r_;
if (!isspace(ch)) break;
is_->get();
}
return ch;
}
inline void JSONReader::ReadString(std::string *out_str) {
int ch = NextNonSpace();
CHECK_EQ(ch, '\"')
<< "Error at" << line_info()
<< ", Expect \'\"\' but get \'" << static_cast<char>(ch) << '\'';
std::ostringstream os;
while (true) {
ch = is_->get();
if (ch == '\\') {
char sch = static_cast<char>(is_->get());
switch (sch) {
case 'r': os << "\r"; break;
case 'n': os << "\n"; break;
case '\\': os << "\\"; break;
case '\t': os << "\t"; break;
case '\"': os << "\""; break;
default: LOG(FATAL) << "unknown string escape \\" << sch;
}
} else {
if (ch == '\"') break;
os << static_cast<char>(ch);
}
if (ch == EOF || ch == '\r' || ch == '\n') {
LOG(FATAL)
<< "Error at" << line_info()
<< ", Expect \'\"\' but reach end of line ";
}
}
*out_str = os.str();
}
template<typename ValueType>
inline void JSONReader::ReadNumber(ValueType *out_value) {
*is_ >> *out_value;
CHECK(!is_->fail())
<< "Error at" << line_info()
<< ", Expect number";
}
inline void JSONReader::BeginObject() {
int ch = NextNonSpace();
CHECK_EQ(ch, '{')
<< "Error at" << line_info()
<< ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\'';
scope_counter_.push_back(0);
}
inline void JSONReader::BeginArray() {
int ch = NextNonSpace();
CHECK_EQ(ch, '[')
<< "Error at" << line_info()
<< ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\'';
scope_counter_.push_back(0);
}
inline bool JSONReader::NextObjectItem(std::string *out_key) {
bool next = true;
if (scope_counter_.back() != 0) {
int ch = NextNonSpace();
if (ch == EOF) {
next = false;
} else if (ch == '}') {
next = false;
} else {
CHECK_EQ(ch, ',')
<< "Error at" << line_info()
<< ", JSON object expect \'}\' or \',\' \'" << static_cast<char>(ch) << '\'';
}
} else {
int ch = PeekNextNonSpace();
if (ch == '}') {
is_->get();
next = false;
}
}
if (!next) {
scope_counter_.pop_back();
return false;
} else {
scope_counter_.back() += 1;
ReadString(out_key);
int ch = NextNonSpace();
CHECK_EQ(ch, ':')
<< "Error at" << line_info()
<< ", Expect \':\' but get \'" << static_cast<char>(ch) << '\'';
return true;
}
}
inline bool JSONReader::NextArrayItem() {
bool next = true;
if (scope_counter_.back() != 0) {
int ch = NextNonSpace();
if (ch == EOF) {
next = false;
} else if (ch == ']') {
next = false;
} else {
CHECK_EQ(ch, ',')
<< "Error at" << line_info()
<< ", JSON array expect \']\' or \',\'. Get \'" << static_cast<char>(ch) << "\' instead";
}
} else {
int ch = PeekNextNonSpace();
if (ch == ']') {
is_->get();
next = false;
}
}
if (!next) {
scope_counter_.pop_back();
return false;
} else {
scope_counter_.back() += 1;
return true;
}
}
template<typename ValueType>
inline void JSONReader::Read(ValueType *out_value) {
json::Handler<ValueType>::Read(this, out_value);
}
inline void JSONWriter::WriteNoEscape(const std::string &s) {
*os_ << '\"' << s << '\"';
}
inline void JSONWriter::WriteString(const std::string &s) {
std::ostream &os = *os_;
os << '\"';
for (size_t i = 0; i < s.length(); ++i) {
char ch = s[i];
switch (ch) {
case '\r': os << "\\r"; break;
case '\n': os << "\\n"; break;
case '\\': os << "\\\\"; break;
case '\t': os << "\\t"; break;
case '\"': os << "\\\""; break;
default: os << ch;
}
}
os << '\"';
}
template<typename ValueType>
inline void JSONWriter::WriteNumber(const ValueType &v) {
*os_ << v;
}
inline void JSONWriter::BeginArray(bool multi_line) {
*os_ << '[';
scope_multi_line_.push_back(multi_line);
scope_counter_.push_back(0);
}
inline void JSONWriter::EndArray() {
CHECK_NE(scope_multi_line_.size(), 0);
CHECK_NE(scope_counter_.size(), 0);
bool newline = scope_multi_line_.back();
size_t nelem = scope_counter_.back();
scope_multi_line_.pop_back();
scope_counter_.pop_back();
if (newline && nelem != 0) WriteSeperator();
*os_ << ']';
}
inline void JSONWriter::BeginObject(bool multi_line) {
*os_ << "{";
scope_multi_line_.push_back(multi_line);
scope_counter_.push_back(0);
}
inline void JSONWriter::EndObject() {
CHECK_NE(scope_multi_line_.size(), 0);
CHECK_NE(scope_counter_.size(), 0);
bool newline = scope_multi_line_.back();
size_t nelem = scope_counter_.back();
scope_multi_line_.pop_back();
scope_counter_.pop_back();
if (newline && nelem != 0) WriteSeperator();
*os_ << '}';
}
template<typename ValueType>
inline void JSONWriter::WriteObjectKeyValue(const std::string &key,
const ValueType &value) {
std::ostream &os = *os_;
if (scope_counter_.back() == 0) {
WriteSeperator();
os << '\"' << key << "\": ";
} else {
os << ", ";
WriteSeperator();
os << '\"' << key << "\": ";
}
scope_counter_.back() += 1;
json::Handler<ValueType>::Write(this, value);
}
inline void JSONWriter::WriteArraySeperator() {
std::ostream &os = *os_;
if (scope_counter_.back() != 0) {
os << ", ";
}
scope_counter_.back() += 1;
WriteSeperator();
}
template<typename ValueType>
inline void JSONWriter::WriteArrayItem(const ValueType &value) {
this->WriteArraySeperator();
json::Handler<ValueType>::Write(this, value);
}
template<typename ValueType>
inline void JSONWriter::Write(const ValueType &value) {
size_t nscope = scope_multi_line_.size();
json::Handler<ValueType>::Write(this, value);
CHECK_EQ(nscope, scope_multi_line_.size())
<< "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?";
}
inline void JSONWriter::WriteSeperator() {
if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) {
*os_ << '\n' << std::string(scope_multi_line_.size() * 2, ' ');
}
}
inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) {
reader->BeginObject();
std::map<std::string, int> visited;
std::string key;
while (reader->NextObjectItem(&key)) {
if (map_.count(key) != 0) {
Entry e = map_[key];
(*e.func)(reader, e.addr);
visited[key] = 0;
} else {
std::ostringstream os;
os << "JSONReader: Unknown field " << key << ", candidates are: \n";
for (std::map<std::string, Entry>::iterator
it = map_.begin(); it != map_.end(); ++it) {
os << '\"' <<it->first << "\"\n";
}
LOG(FATAL) << os.str();
}
}
if (visited.size() != map_.size()) {
for (std::map<std::string, Entry>::iterator
it = map_.begin(); it != map_.end(); ++it) {
if (it->second.optional) continue;
CHECK_NE(visited.count(it->first), 0)
<< "JSONReader: Missing field \"" << it->first << "\"\n At "
<< reader->line_info();
}
}
}
template<typename T>
inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) {
json::Handler<T>::Read(reader, static_cast<T*>(addr));
}
template<typename T>
inline void JSONObjectReadHelper::
DeclareFieldInternal(const std::string &key, T *addr, bool optional) {
CHECK_EQ(map_.count(key), 0)
<< "Adding duplicate field " << key;
Entry e;
e.func = ReaderFunction<T>;
e.addr = static_cast<void*>(addr);
e.optional = optional;
map_[key] = e;
}
//! \endcond
} // namespace dmlc
#endif // DMLC_JSON_H_
/*!
* Copyright (c) 2015 by Contributors
* \file logging.h
* \brief defines logging macros of dmlc
* allows use of GLOG, fall back to internal
* implementation when disabled
*/
#ifndef DMLC_LOGGING_H_
#define DMLC_LOGGING_H_
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
#include <stdexcept>
#include "./base.h"
namespace dmlc {
/*!
* \brief exception class that will be thrown by
* default logger if DMLC_LOG_FATAL_THROW == 1
*/
struct Error : public std::runtime_error {
/*!
* \brief constructor
* \param s the error message
*/
explicit Error(const std::string &s) : std::runtime_error(s) {}
};
} // namespace dmlc
#if DMLC_USE_GLOG
#include <glog/logging.h>
namespace dmlc {
/*!
* \brief optionally redirect to google's init log
* \param argv0 The arguments.
*/
inline void InitLogging(const char* argv0) {
google::InitGoogleLogging(argv0);
}
} // namespace dmlc
#else
// use a light version of glog
#include <assert.h>
#include <iostream>
#include <sstream>
#include <ctime>
#if defined(_MSC_VER)
#pragma warning(disable : 4722)
#endif
namespace dmlc {
inline void InitLogging(const char* argv0) {
// DO NOTHING
}
// Always-on checking
#define CHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \
"failed: " #x << ' '
#define CHECK_LT(x, y) CHECK((x) < (y))
#define CHECK_GT(x, y) CHECK((x) > (y))
#define CHECK_LE(x, y) CHECK((x) <= (y))
#define CHECK_GE(x, y) CHECK((x) >= (y))
#define CHECK_EQ(x, y) CHECK((x) == (y))
#define CHECK_NE(x, y) CHECK((x) != (y))
#define CHECK_NOTNULL(x) \
((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*)
// Debug-only checking.
#ifdef NDEBUG
#define DCHECK(x) \
while (false) CHECK(x)
#define DCHECK_LT(x, y) \
while (false) CHECK((x) < (y))
#define DCHECK_GT(x, y) \
while (false) CHECK((x) > (y))
#define DCHECK_LE(x, y) \
while (false) CHECK((x) <= (y))
#define DCHECK_GE(x, y) \
while (false) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) \
while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
while (false) CHECK((x) != (y))
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif // NDEBUG
#if DMLC_LOG_CUSTOMIZE
#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__)
#else
#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__)
#endif
#define LOG_ERROR LOG_INFO
#define LOG_WARNING LOG_INFO
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__)
#define LOG_QFATAL LOG_FATAL
// Poor man version of VLOG
#define VLOG(x) LOG_INFO.stream()
#define LOG(severity) LOG_##severity.stream()
#define LG LOG_INFO.stream()
#define LOG_IF(severity, condition) \
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#ifdef NDEBUG
#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#else
#define LOG_DFATAL LOG_FATAL
#define DFATAL FATAL
#define DLOG(severity) LOG(severity)
#define DLOG_IF(severity, condition) LOG_IF(severity, condition)
#endif
// Poor man version of LOG_EVERY_N
#define LOG_EVERY_N(severity, n) LOG(severity)
class DateLogger {
public:
DateLogger() {
#if defined(_MSC_VER)
_tzset();
#endif
}
const char* HumanDate() {
#if defined(_MSC_VER)
_strtime_s(buffer_, sizeof(buffer_));
#else
time_t time_value = time(NULL);
struct tm *pnow;
#if !defined(_WIN32)
struct tm now;
pnow = localtime_r(&time_value, &now);
#else
pnow = localtime(&time_value); // NOLINT(*)
#endif
snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d",
pnow->tm_hour, pnow->tm_min, pnow->tm_sec);
#endif
return buffer_;
}
private:
char buffer_[9];
};
class LogMessage {
public:
LogMessage(const char* file, int line)
:
#ifdef __ANDROID__
log_stream_(std::cout)
#else
log_stream_(std::cerr)
#endif
{
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
~LogMessage() { log_stream_ << '\n'; }
std::ostream& stream() { return log_stream_; }
protected:
std::ostream& log_stream_;
private:
DateLogger pretty_date_;
LogMessage(const LogMessage&);
void operator=(const LogMessage&);
};
// customized logger that can allow user to define where to log the message.
class CustomLogMessage {
public:
CustomLogMessage(const char* file, int line) {
log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":"
<< line << ": ";
}
~CustomLogMessage() {
Log(log_stream_.str());
}
std::ostream& stream() { return log_stream_; }
/*!
* \brief customized logging of the message.
* This function won't be implemented by libdmlc
* \param msg The message to be logged.
*/
static void Log(const std::string& msg);
private:
std::ostringstream log_stream_;
};
#if DMLC_LOG_FATAL_THROW == 0
class LogMessageFatal : public LogMessage {
public:
LogMessageFatal(const char* file, int line) : LogMessage(file, line) {}
~LogMessageFatal() {
log_stream_ << "\n";
abort();
}
private:
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#else
class LogMessageFatal {
public:
LogMessageFatal(const char* file, int line) {
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
std::ostringstream &stream() { return log_stream_; }
~LogMessageFatal() DMLC_THROW_EXCEPTION {
// throwing out of destructor is evil
// hopefully we can do it here
// also log the message before throw
#if DMLC_LOG_BEFORE_THROW
LOG(ERROR) << log_stream_.str();
#endif
throw Error(log_stream_.str());
}
private:
std::ostringstream log_stream_;
DateLogger pretty_date_;
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#endif
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class LogMessageVoidify {
public:
LogMessageVoidify() {}
// This has to be an operator with a precedence lower than << but
// higher than "?:". See its usage.
void operator&(std::ostream&) {}
};
} // namespace dmlc
#endif
#endif // DMLC_LOGGING_H_
/*!
* Copyright (c) 2015 by Contributors
* \file memory.h
* \brief Additional memory hanlding utilities.
*/
#ifndef DMLC_MEMORY_H_
#define DMLC_MEMORY_H_
#include <vector>
#include "./base.h"
#include "./logging.h"
#include "./thread_local.h"
namespace dmlc {
/*!
* \brief A memory pool that allocate memory of fixed size and alignment.
* \tparam size The size of each piece.
* \tparam align The alignment requirement of the memory.
*/
template<size_t size, size_t align>
class MemoryPool {
public:
/*! \brief constructor */
MemoryPool() {
static_assert(align % alignof(LinkedList) == 0,
"alignment requirement failed.");
curr_page_.reset(new Page());
}
/*! \brief allocate a new memory of size */
inline void* allocate() {
if (head_ != nullptr) {
LinkedList* ret = head_;
head_ = head_->next;
return ret;
} else {
if (page_ptr_ < kPageSize) {
return &(curr_page_->data[page_ptr_++]);
} else {
allocated_.push_back(std::move(curr_page_));
curr_page_.reset(new Page());
page_ptr_ = 1;
return &(curr_page_->data[0]);
}
}
}
/*!
* \brief deallocate a piece of memory
* \param p The pointer to the memory to be de-allocated.
*/
inline void deallocate(void* p) {
LinkedList* ptr = static_cast<LinkedList*>(p);
ptr->next = head_;
head_ = ptr;
}
private:
// page size of each member
static const int kPageSize = ((1 << 22) / size);
// page to be requested.
struct Page {
typename std::aligned_storage<size, align>::type data[kPageSize];
};
// internal linked list structure.
struct LinkedList {
LinkedList* next{nullptr};
};
// head of free list
LinkedList* head_{nullptr};
// current free page
std::unique_ptr<Page> curr_page_;
// pointer to the current free page position.
size_t page_ptr_{0};
// allocated pages.
std::vector<std::unique_ptr<Page> > allocated_;
};
/*!
* \brief A thread local allocator that get memory from a threadlocal memory pool.
* This is suitable to allocate objects that do not cross thread.
* \tparam T the type of the data to be allocated.
*/
template<typename T>
class ThreadlocalAllocator {
public:
/*! \brief pointer type */
typedef T* pointer;
/*! \brief const pointer type */
typedef const T* const_ptr;
/*! \brief value type */
typedef T value_type;
/*! \brief default constructor */
ThreadlocalAllocator() {}
/*!
* \brief constructor from another allocator
* \param other another allocator
* \tparam U another type
*/
template<typename U>
ThreadlocalAllocator(const ThreadlocalAllocator<U>& other) {}
/*!
* \brief allocate memory
* \param n number of blocks
* \return an uninitialized memory of type T.
*/
inline T* allocate(size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
return static_cast<T*>(Store::Get()->allocate());
}
/*!
* \brief deallocate memory
* \param p a memory to be returned.
* \param n number of blocks
*/
inline void deallocate(T* p, size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
Store::Get()->deallocate(p);
}
};
/*!
* \brief a shared pointer like type that allocate object
* from a threadlocal object pool. This object is not thread-safe
* but can be faster than shared_ptr in certain usecases.
* \tparam T the data type.
*/
template<typename T>
struct ThreadlocalSharedPtr {
public:
/*! \brief default constructor */
ThreadlocalSharedPtr() : block_(nullptr) {}
/*!
* \brief constructor from nullptr
* \param other the nullptr type
*/
ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*)
/*!
* \brief copy constructor
* \param other another pointer.
*/
ThreadlocalSharedPtr(const ThreadlocalSharedPtr<T>& other)
: block_(other.block_) {
IncRef(block_);
}
/*!
* \brief move constructor
* \param other another pointer.
*/
ThreadlocalSharedPtr(ThreadlocalSharedPtr<T>&& other)
: block_(other.block_) {
other.block_ = nullptr;
}
/*!
* \brief destructor
*/
~ThreadlocalSharedPtr() {
DecRef(block_);
}
/*!
* \brief move assignment
* \param other another object to be assigned.
* \return self.
*/
inline ThreadlocalSharedPtr<T>& operator=(ThreadlocalSharedPtr<T>&& other) {
DecRef(block_);
block_ = other.block_;
other.block_ = nullptr;
return *this;
}
/*!
* \brief copy assignment
* \param other another object to be assigned.
* \return self.
*/
inline ThreadlocalSharedPtr<T> &operator=(const ThreadlocalSharedPtr<T>& other) {
DecRef(block_);
block_ = other.block_;
IncRef(block_);
return *this;
}
/*! \brief check if nullptr */
inline bool operator==(std::nullptr_t other) const {
return block_ == nullptr;
}
/*!
* \return get the pointer content.
*/
inline T* get() const {
if (block_ == nullptr) return nullptr;
return reinterpret_cast<T*>(&(block_->data));
}
/*!
* \brief reset the pointer to nullptr.
*/
inline void reset() {
DecRef(block_);
block_ = nullptr;
}
/*! \return if use_count == 1*/
inline bool unique() const {
if (block_ == nullptr) return false;
return block_->use_count_ == 1;
}
/*! \return dereference pointer */
inline T* operator*() const {
return reinterpret_cast<T*>(&(block_->data));
}
/*! \return dereference pointer */
inline T* operator->() const {
return reinterpret_cast<T*>(&(block_->data));
}
/*!
* \brief create a new space from threadlocal storage and return it.
* \tparam Args the arguments.
* \param args The input argument
* \return the allocated pointer.
*/
template <typename... Args>
inline static ThreadlocalSharedPtr<T> Create(Args&&... args) {
ThreadlocalAllocator<RefBlock> arena;
ThreadlocalSharedPtr<T> p;
p.block_ = arena.allocate(1);
p.block_->use_count_ = 1;
new (&(p.block_->data)) T(std::forward<Args>(args)...);
return p;
}
private:
// internal reference block
struct RefBlock {
typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
unsigned use_count_;
};
// decrease ref counter
inline static void DecRef(RefBlock* block) {
if (block != nullptr) {
if (--block->use_count_ == 0) {
ThreadlocalAllocator<RefBlock> arena;
T* dptr = reinterpret_cast<T*>(&(block->data));
dptr->~T();
arena.deallocate(block, 1);
}
}
}
// increase ref counter
inline static void IncRef(RefBlock* block) {
if (block != nullptr) {
++block->use_count_;
}
}
// internal block
RefBlock *block_;
};
} // namespace dmlc
#endif // DMLC_MEMORY_H_
/*!
* Copyright (c) 2015 by Contributors
* \file parameter.h
* \brief Provide lightweight util to do parameter setup and checking.
*/
#ifndef DMLC_PARAMETER_H_
#define DMLC_PARAMETER_H_
#include <cstddef>
#include <cstdlib>
#include <sstream>
#include <limits>
#include <map>
#include <set>
#include <typeinfo>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <iostream>
#include "./base.h"
#include "./json.h"
#include "./logging.h"
#include "./type_traits.h"
namespace dmlc {
// this file is backward compatible with non-c++11
/*! \brief Error throwed by parameter checking */
struct ParamError : public dmlc::Error {
/*!
* \brief constructor
* \param msg error message
*/
explicit ParamError(const std::string &msg)
: dmlc::Error(msg) {}
};
/*!
* \brief Get environment variable with default.
* \param key the name of environment variable.
* \param default_value the default value of environment vriable.
* \return The value received
*/
template<typename ValueType>
inline ValueType GetEnv(const char *key,
ValueType default_value);
/*! \brief internal namespace for parameter manangement */
namespace parameter {
// forward declare ParamManager
class ParamManager;
// forward declare FieldAccessEntry
class FieldAccessEntry;
// forward declare FieldEntry
template<typename DType>
class FieldEntry;
// forward declare ParamManagerSingleton
template<typename PType>
struct ParamManagerSingleton;
} // namespace parameter
/*!
* \brief Information about a parameter field in string representations.
*/
struct ParamFieldInfo {
/*! \brief name of the field */
std::string name;
/*! \brief type of the field in string format */
std::string type;
/*!
* \brief detailed type information string
* This include the default value, enum constran and typename.
*/
std::string type_info_str;
/*! \brief detailed description of the type */
std::string description;
};
/*!
* \brief Parameter is the base type every parameter struct should inheritate from
* The following code is a complete example to setup parameters.
* \code
* struct Param : public dmlc::Parameter<Param> {
* float learning_rate;
* int num_hidden;
* std::string name;
* // declare parameters in header file
* DMLC_DECLARE_PARAMETER(Param) {
* DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000);
* DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f);
* DMLC_DECLARE_FIELD(name).set_default("hello");
* }
* };
* // register it in cc file
* DMLC_REGISTER_PARAMETER(Param);
* \endcode
*
* After that, the Param struct will get all the functions defined in Parameter.
* \tparam PType the type of parameter struct
*
* \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER
*/
template<typename PType>
struct Parameter {
public:
/*!
* \brief initialize the parameter by keyword arguments.
* This function will initialize the parameter struct, check consistency
* and throw error if something wrong happens.
*
* \param kwargs map of keyword arguments, or vector of pairs
* \tparam Container container type
* \throw ParamError when something go wrong.
*/
template<typename Container>
inline void Init(const Container &kwargs) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), NULL);
}
/*!
* \brief initialize the parameter by keyword arguments.
* This is same as Init, but allow unknown arguments.
*
* \param kwargs map of keyword arguments, or vector of pairs
* \tparam Container container type
* \throw ParamError when something go wrong.
* \return vector of pairs of unknown arguments.
*/
template<typename Container>
inline std::vector<std::pair<std::string, std::string> >
InitAllowUnknown(const Container &kwargs) {
std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), &unknown);
return unknown;
}
/*!
* \brief Return a dictionary representation of the parameters
* \return A dictionary that maps key -> value
*/
inline std::map<std::string, std::string> __DICT__() const {
std::vector<std::pair<std::string, std::string> > vec
= PType::__MANAGER__()->GetDict(this->head());
return std::map<std::string, std::string>(vec.begin(), vec.end());
}
/*!
* \brief Write the parameters in JSON format.
* \param writer JSONWriter used for writing.
*/
inline void Save(dmlc::JSONWriter *writer) const {
writer->Write(this->__DICT__());
}
/*!
* \brief Load the parameters from JSON.
* \param reader JSONReader used for loading.
* \throw ParamError when something go wrong.
*/
inline void Load(dmlc::JSONReader *reader) {
std::map<std::string, std::string> kwargs;
reader->Read(&kwargs);
this->Init(kwargs);
}
/*!
* \brief Get the fields of the parameters.
* \return List of ParamFieldInfo of each field.
*/
inline static std::vector<ParamFieldInfo> __FIELDS__() {
return PType::__MANAGER__()->GetFieldInfo();
}
/*!
* \brief Print docstring of the parameter
* \return the printed docstring
*/
inline static std::string __DOC__() {
std::ostringstream os;
PType::__MANAGER__()->PrintDocString(os);
return os.str();
}
protected:
/*!
* \brief internal function to allow declare of a parameter memember
* \param manager the parameter manager
* \param key the key name of the parameter
* \param ref the reference to the parameter in the struct.
*/
template<typename DType>
inline parameter::FieldEntry<DType>& DECLARE(
parameter::ParamManagerSingleton<PType> *manager,
const std::string &key, DType &ref) { // NOLINT(*)
parameter::FieldEntry<DType> *e =
new parameter::FieldEntry<DType>();
e->Init(key, this->head(), ref);
manager->manager.AddEntry(key, e);
return *e;
}
private:
/*! \return Get head pointer of child structure */
inline PType *head() const {
return static_cast<PType*>(const_cast<Parameter<PType>*>(this));
}
};
//! \cond Doxygen_Suppress
/*!
* \brief macro used to declare parameter
*
* Example:
* \code
* struct Param : public dmlc::Parameter<Param> {
* // declare parameters in header file
* DMLC_DECLARE_PARAMETER(Param) {
* // details of declarations
* }
* };
* \endcode
*
* This macro need to be put in a source file so that registeration only happens once.
* Refer to example code in Parameter for details
*
* \param PType the name of parameter struct.
* \sa Parameter
*/
#define DMLC_DECLARE_PARAMETER(PType) \
static ::dmlc::parameter::ParamManager *__MANAGER__(); \
inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \
/*!
* \brief macro to declare fields
* \param FieldName the name of the field.
*/
#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName)
/*!
* \brief macro to declare alias of a fields
* \param FieldName the name of the field.
* \param AliasName the name of the alias, must be declared after the field is declared.
*/
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName)
/*!
* \brief Macro used to register parameter.
*
* This macro need to be put in a source file so that registeration only happens once.
* Refer to example code in Parameter for details
* \param PType the type of parameter struct.
* \sa Parameter
*/
#define DMLC_REGISTER_PARAMETER(PType) \
::dmlc::parameter::ParamManager *PType::__MANAGER__() { \
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
//! \endcond
/*!
* \brief internal namespace for parameter manangement
* There is no need to use it directly in normal case
*/
namespace parameter {
/*!
* \brief FieldAccessEntry interface to help manage the parameters
* Each entry can be used to access one parameter in the Parameter struct.
*
* This is an internal interface used that is used to manage parameters
*/
class FieldAccessEntry {
public:
FieldAccessEntry()
: has_default_(false) {}
/*! \brief destructor */
virtual ~FieldAccessEntry() {}
/*!
* \brief set the default value.
* \param head the pointer to the head of the struct
* \throw error if no default is presented
*/
virtual void SetDefault(void *head) const = 0;
/*!
* \brief set the parameter by string value
* \param head the pointer to the head of the struct
* \param value the value to be set
*/
virtual void Set(void *head, const std::string &value) const = 0;
// check if value is OK
virtual void Check(void *head) const {}
/*!
* \brief get the string representation of value.
* \param head the pointer to the head of the struct
*/
virtual std::string GetStringValue(void *head) const = 0;
/*!
* \brief Get field information
* \return the corresponding field information
*/
virtual ParamFieldInfo GetFieldInfo() const = 0;
protected:
/*! \brief whether this parameter have default value */
bool has_default_;
/*! \brief positional index of parameter in struct */
size_t index_;
/*! \brief parameter key name */
std::string key_;
/*! \brief parameter type */
std::string type_;
/*! \brief description of the parameter */
std::string description_;
/*!
* \brief print string representation of default value
* \parma os the stream to print the docstring to.
*/
virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*)
// allow ParamManager to modify self
friend class ParamManager;
};
/*!
* \brief manager class to handle parameter setting for each type
* An manager will be created for each parameter types.
*/
class ParamManager {
public:
/*! \brief destructor */
~ParamManager() {
for (size_t i = 0; i < entry_.size(); ++i) {
delete entry_[i];
}
}
/*!
* \brief find the access entry by parameter key
* \param key the key of the parameter.
* \return pointer to FieldAccessEntry, NULL if nothing is found.
*/
inline FieldAccessEntry *Find(const std::string &key) const {
std::map<std::string, FieldAccessEntry*>::const_iterator it =
entry_map_.find(key);
if (it == entry_map_.end()) return NULL;
return it->second;
}
/*!
* \brief set parameter by keyword arguments.
* \param head head to the parameter field.
* \param begin begin iterator of original kwargs
* \param end end iterator of original kwargs
* \param unknown_args optional, used to hold unknown arguments
* When it is specified, unknown arguments will be stored into here, instead of raise an error
* \tparam RandomAccessIterator iterator type
* \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing.
*/
template<typename RandomAccessIterator>
inline void RunInit(void *head,
RandomAccessIterator begin,
RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args) const {
std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first);
if (e != NULL) {
e->Set(head, it->second);
e->Check(head);
selected_args.insert(e);
} else {
if (unknown_args != NULL) {
unknown_args->push_back(*it);
} else {
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
PrintDocString(os);
throw dmlc::ParamError(os.str());
}
}
}
for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
it != entry_map_.end(); ++it) {
if (selected_args.count(it->second) == 0) {
it->second->SetDefault(head);
}
}
}
/*!
* \brief internal function to add entry to manager,
* The manager will take ownership of the entry.
* \param key the key to the parameters
* \param e the pointer to the new entry.
*/
inline void AddEntry(const std::string &key, FieldAccessEntry *e) {
e->index_ = entry_.size();
// TODO(bing) better error message
if (entry_map_.count(key) != 0) {
LOG(FATAL) << "key " << key << " has already been registered in " << name_;
}
entry_.push_back(e);
entry_map_[key] = e;
}
/*!
* \brief internal function to add entry to manager,
* The manager will take ownership of the entry.
* \param key the key to the parameters
* \param e the pointer to the new entry.
*/
inline void AddAlias(const std::string& field, const std::string& alias) {
if (entry_map_.count(field) == 0) {
LOG(FATAL) << "key " << field << " has not been registered in " << name_;
}
if (entry_map_.count(alias) != 0) {
LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_;
}
entry_map_[alias] = entry_map_[field];
}
/*!
* \brief set the name of parameter manager
* \param name the name to set
*/
inline void set_name(const std::string &name) {
name_ = name;
}
/*!
* \brief get field information of each field.
* \return field information
*/
inline std::vector<ParamFieldInfo> GetFieldInfo() const {
std::vector<ParamFieldInfo> ret(entry_.size());
for (size_t i = 0; i < entry_.size(); ++i) {
ret[i] = entry_[i]->GetFieldInfo();
}
return ret;
}
/*!
* \brief Print readible docstring to ostream, add newline.
* \parma os the stream to print the docstring to.
*/
inline void PrintDocString(std::ostream &os) const { // NOLINT(*)
for (size_t i = 0; i < entry_.size(); ++i) {
ParamFieldInfo info = entry_[i]->GetFieldInfo();
os << info.name << " : " << info.type_info_str << '\n';
if (info.description.length() != 0) {
os << " " << info.description << '\n';
}
}
}
/*!
* \brief Get internal parameters in vector of pairs.
* \param head the head of the struct.
* \param skip_default skip the values that equals default value.
* \return the parameter dictionary.
*/
inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const {
std::vector<std::pair<std::string, std::string> > ret;
for (std::map<std::string, FieldAccessEntry*>::const_iterator
it = entry_map_.begin(); it != entry_map_.end(); ++it) {
ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
}
return ret;
}
private:
/*! \brief parameter struct name */
std::string name_;
/*! \brief positional list of entries */
std::vector<FieldAccessEntry*> entry_;
/*! \brief map of key to entry */
std::map<std::string, FieldAccessEntry*> entry_map_;
};
//! \cond Doxygen_Suppress
// The following piece of code will be template heavy and less documented
// singleton parameter manager for certain type, used for initialization
template<typename PType>
struct ParamManagerSingleton {
ParamManager manager;
explicit ParamManagerSingleton(const std::string &param_name) {
PType param;
param.__DECLARE__(this);
manager.set_name(param_name);
}
};
// Base class of FieldEntry
// implement set_default
template<typename TEntry, typename DType>
class FieldEntryBase : public FieldAccessEntry {
public:
// entry type
typedef TEntry EntryType;
// implement set value
virtual void Set(void *head, const std::string &value) const {
std::istringstream is(value);
is >> this->Get(head);
if (!is.fail()) {
while (!is.eof()) {
int ch = is.get();
if (ch == EOF) {
is.clear(); break;
}
if (!isspace(ch)) {
is.setstate(std::ios::failbit); break;
}
}
}
if (is.fail()) {
std::ostringstream os;
os << "Invalid Parameter format for " << key_
<< " expect " << type_ << " but value=\'" << value<< '\'';
throw dmlc::ParamError(os.str());
}
}
virtual std::string GetStringValue(void *head) const {
std::ostringstream os;
PrintValue(os, this->Get(head));
return os.str();
}
virtual ParamFieldInfo GetFieldInfo() const {
ParamFieldInfo info;
std::ostringstream os;
info.name = key_;
info.type = type_;
os << type_;
if (has_default_) {
os << ',' << " optional, default=";
PrintDefaultValueString(os);
} else {
os << ", required";
}
info.type_info_str = os.str();
info.description = description_;
return info;
}
// implement set head to default value
virtual void SetDefault(void *head) const {
if (!has_default_) {
std::ostringstream os;
os << "Required parameter " << key_
<< " of " << type_ << " is not presented";
throw dmlc::ParamError(os.str());
} else {
this->Get(head) = default_value_;
}
}
// return reference of self as derived type
inline TEntry &self() {
return *(static_cast<TEntry*>(this));
}
// implement set_default
inline TEntry &set_default(const DType &default_value) {
default_value_ = default_value;
has_default_ = true;
// return self to allow chaining
return this->self();
}
// implement describe
inline TEntry &describe(const std::string &description) {
description_ = description;
// return self to allow chaining
return this->self();
}
// initialization function
inline void Init(const std::string &key,
void *head, DType &ref) { // NOLINT(*)
this->key_ = key;
if (this->type_.length() == 0) {
this->type_ = dmlc::type_name<DType>();
}
this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*)
}
protected:
// print the value
virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*)
os << value;
}
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
PrintValue(os, default_value_);
}
// get the internal representation of parameter
// for example if this entry corresponds field param.learning_rate
// then Get(&param) will return reference to param.learning_rate
inline DType &Get(void *head) const {
return *(DType*)((char*)(head) + offset_); // NOLINT(*)
}
// internal offset of the field
ptrdiff_t offset_;
// default value of field
DType default_value_;
};
// parameter base for numeric types that have range
template<typename TEntry, typename DType>
class FieldEntryNumeric
: public FieldEntryBase<TEntry, DType> {
public:
FieldEntryNumeric()
: has_begin_(false), has_end_(false) {}
// implement set_range
virtual TEntry &set_range(DType begin, DType end) {
begin_ = begin; end_ = end;
has_begin_ = true; has_end_ = true;
return this->self();
}
// implement set_range
virtual TEntry &set_lower_bound(DType begin) {
begin_ = begin; has_begin_ = true;
return this->self();
}
// consistency check for numeric ranges
virtual void Check(void *head) const {
FieldEntryBase<TEntry, DType>::Check(head);
DType v = this->Get(head);
if (has_begin_ && has_end_) {
if (v < begin_ || v > end_) {
std::ostringstream os;
os << "value " << v << "for Parameter " << this->key_
<< " exceed bound [" << begin_ << ',' << end_ <<']';
throw dmlc::ParamError(os.str());
}
} else if (has_begin_ && v < begin_) {
std::ostringstream os;
os << "value " << v << "for Parameter " << this->key_
<< " should be greater equal to " << begin_;
throw dmlc::ParamError(os.str());
} else if (has_end_ && v > end_) {
std::ostringstream os;
os << "value " << v << "for Parameter " << this->key_
<< " should be smaller equal to " << end_;
throw dmlc::ParamError(os.str());
}
}
protected:
// whether it have begin and end range
bool has_begin_, has_end_;
// data bound
DType begin_, end_;
};
/*!
* \brief FieldEntry defines parsing and checking behavior of DType.
* This class can be specialized to implement specific behavior of more settings.
* \tparam DType the data type of the entry.
*/
template<typename DType>
class FieldEntry :
public IfThenElseType<dmlc::is_arithmetic<DType>::value,
FieldEntryNumeric<FieldEntry<DType>, DType>,
FieldEntryBase<FieldEntry<DType>, DType> >::Type {
};
// specialize define for int(enum)
template<>
class FieldEntry<int>
: public FieldEntryNumeric<FieldEntry<int>, int> {
public:
// construct
FieldEntry<int>() : is_enum_(false) {}
// parent
typedef FieldEntryNumeric<FieldEntry<int>, int> Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
if (is_enum_) {
std::map<std::string, int>::const_iterator it = enum_map_.find(value);
std::ostringstream os;
if (it == enum_map_.end()) {
os << "Invalid Input: \'" << value;
os << "\', valid values are: ";
PrintEnums(os);
throw dmlc::ParamError(os.str());
} else {
os << it->second;
Parent::Set(head, os.str());
}
} else {
Parent::Set(head, value);
}
}
virtual ParamFieldInfo GetFieldInfo() const {
if (is_enum_) {
ParamFieldInfo info;
std::ostringstream os;
info.name = key_;
info.type = type_;
PrintEnums(os);
if (has_default_) {
os << ',' << "optional, default=";
PrintDefaultValueString(os);
} else {
os << ", required";
}
info.type_info_str = os.str();
info.description = description_;
return info;
} else {
return Parent::GetFieldInfo();
}
}
// add enum
inline FieldEntry<int> &add_enum(const std::string &key, int value) {
if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
enum_back_map_.count(value) != 0) {
std::ostringstream os;
os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
os << "Enums: ";
for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
it != enum_map_.end(); ++it) {
os << "(" << it->first << ": " << it->second << "), ";
}
throw dmlc::ParamError(os.str());
}
enum_map_[key] = value;
enum_back_map_[value] = key;
is_enum_ = true;
return this->self();
}
protected:
// enum flag
bool is_enum_;
// enum map
std::map<std::string, int> enum_map_;
// enum map
std::map<int, std::string> enum_back_map_;
// override print behavior
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
os << '\'';
PrintValue(os, default_value_);
os << '\'';
}
// override print default
virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*)
if (is_enum_) {
CHECK_NE(enum_back_map_.count(value), 0)
<< "Value not found in enum declared";
os << enum_back_map_.at(value);
} else {
os << value;
}
}
private:
inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
os << '{';
for (std::map<std::string, int>::const_iterator
it = enum_map_.begin(); it != enum_map_.end(); ++it) {
if (it != enum_map_.begin()) {
os << ", ";
}
os << "\'" << it->first << '\'';
}
os << '}';
}
};
// specialize define for string
template<>
class FieldEntry<std::string>
: public FieldEntryBase<FieldEntry<std::string>, std::string> {
public:
// parent class
typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
this->Get(head) = value;
}
// override print default
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
os << '\'' << default_value_ << '\'';
}
};
// specialize define for bool
template<>
class FieldEntry<bool>
: public FieldEntryBase<FieldEntry<bool>, bool> {
public:
// parent class
typedef FieldEntryBase<FieldEntry<bool>, bool> Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
std::string lower_case; lower_case.resize(value.length());
std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
bool &ref = this->Get(head);
if (lower_case == "true") {
ref = true;
} else if (lower_case == "false") {
ref = false;
} else if (lower_case == "1") {
ref = true;
} else if (lower_case == "0") {
ref = false;
} else {
std::ostringstream os;
os << "Invalid Parameter format for " << key_
<< " expect " << type_ << " but value=\'" << value<< '\'';
throw dmlc::ParamError(os.str());
}
}
protected:
// print default string
virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*)
if (value) {
os << "True";
} else {
os << "False";
}
}
};
} // namespace parameter
//! \endcond
// implement GetEnv
template<typename ValueType>
inline ValueType GetEnv(const char *key,
ValueType default_value) {
const char *val = getenv(key);
if (val == NULL) return default_value;
ValueType ret;
parameter::FieldEntry<ValueType> e;
e.Init(key, &ret, ret);
e.Set(&ret, val);
return ret;
}
} // namespace dmlc
#endif // DMLC_PARAMETER_H_
/*!
* Copyright (c) 2015 by Contributors
* \file registry.h
* \brief Registry utility that helps to build registry singletons.
*/
#ifndef DMLC_REGISTRY_H_
#define DMLC_REGISTRY_H_
#include <map>
#include <string>
#include <vector>
#include "./base.h"
#include "./logging.h"
#include "./parameter.h"
#include "./type_traits.h"
namespace dmlc {
/*!
* \brief Registry class.
* Registry can be used to register global singletons.
* The most commonly use case are factory functions.
*
* \tparam EntryType Type of Registry entries,
* EntryType need to name a name field.
*/
template<typename EntryType>
class Registry {
public:
/*! \return list of functions in the registry */
inline static const std::vector<const EntryType*> &List() {
return Get()->entry_list_;
}
/*!
* \brief Find the entry with corresponding name.
* \param name name of the function
* \return the corresponding function, can be NULL
*/
inline static const EntryType *Find(const std::string &name) {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p = fmap.find(name);
if (p != fmap.end()) {
return p->second;
} else {
return NULL;
}
}
/*!
* \brief Internal function to register a name function under name.
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
inline EntryType &__REGISTER__(const std::string& name) {
CHECK_EQ(fmap_.count(name), 0)
<< name << " already registered";
EntryType *e = new EntryType();
e->name = name;
fmap_[name] = e;
entry_list_.push_back(e);
return *e;
}
/*!
* \brief Internal function to either register or get registered entry
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
inline EntryType &__REGISTER_OR_GET__(const std::string& name) {
if (fmap_.count(name) == 0) {
return __REGISTER__(name);
} else {
return *fmap_.at(name);
}
}
/*!
* \brief get a singleton of the Registry.
* This function can be defined by DMLC_ENABLE_REGISTRY.
* \return get a singleton
*/
static Registry *Get();
private:
/*! \brief list of entry types */
std::vector<const EntryType*> entry_list_;
/*! \brief map of name->function */
std::map<std::string, EntryType*> fmap_;
/*! \brief constructor */
Registry() {}
/*! \brief destructor */
~Registry() {
for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin();
p != fmap_.end(); ++p) {
delete p->second;
}
}
};
/*!
* \brief Common base class for function registry.
*
* \code
* // This example demonstrates how to use Registry to create a factory of trees.
* struct TreeFactory :
* public FunctionRegEntryBase<TreeFactory, std::function<Tree*()> > {
* };
*
* // in a independent cc file
* namespace dmlc {
* DMLC_REGISTRY_ENABLE(TreeFactory);
* }
* // register binary tree constructor into the registry.
* DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree)
* .describe("Constructor of BinaryTree")
* .set_body([]() { return new BinaryTree(); });
* \endcode
*
* \tparam EntryType The type of subclass that inheritate the base.
* \tparam FunctionType The function type this registry is registerd.
*/
template<typename EntryType, typename FunctionType>
class FunctionRegEntryBase {
public:
/*! \brief name of the entry */
std::string name;
/*! \brief description of the entry */
std::string description;
/*! \brief additional arguments to the factory function */
std::vector<ParamFieldInfo> arguments;
/*! \brief Function body to create ProductType */
FunctionType body;
/*! \brief Return type of the function */
std::string return_type;
/*!
* \brief Set the function body.
* \param body Function body to set.
* \return reference to self.
*/
inline EntryType &set_body(FunctionType body) {
this->body = body;
return this->self();
}
/*!
* \brief Describe the function.
* \param description The description of the factory function.
* \return reference to self.
*/
inline EntryType &describe(const std::string &description) {
this->description = description;
return this->self();
}
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline EntryType &add_argument(const std::string &name,
const std::string &type,
const std::string &description) {
ParamFieldInfo info;
info.name = name;
info.type = type;
info.type_info_str = info.type;
info.description = description;
arguments.push_back(info);
return this->self();
}
/*!
* \brief Append list if arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
inline EntryType &add_arguments(const std::vector<ParamFieldInfo> &args) {
arguments.insert(arguments.end(), args.begin(), args.end());
return this->self();
}
/*!
* \brief Set the return type.
* \param type Return type of the function, could be Symbol or Symbol[]
* \return reference to self.
*/
inline EntryType &set_return_type(const std::string &type) {
return_type = type;
return this->self();
}
protected:
/*!
* \return reference of self as derived type
*/
inline EntryType &self() {
return *(static_cast<EntryType*>(this));
}
};
/*!
* \def DMLC_REGISTRY_ENABLE
* \brief Macro to enable the registry of EntryType.
* This macro must be used under namespace dmlc, and only used once in cc file.
* \param EntryType Type of registry entry
*/
#define DMLC_REGISTRY_ENABLE(EntryType) \
template<> \
Registry<EntryType > *Registry<EntryType >::Get() { \
static Registry<EntryType > inst; \
return &inst; \
} \
/*!
* \brief Generic macro to register an EntryType
* There is a complete example in FactoryRegistryEntryBase.
*
* \param EntryType The type of registry entry.
* \param EntryTypeName The typename of EntryType, must do not contain namespace :: .
* \param Name The name to be registered.
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
* \brief (Optional) Declare a file tag to current file that contains object registrations.
*
* This will declare a dummy function that will be called by register file to
* incur a link dependency.
*
* \param UniqueTag The unique tag used to represent.
* \sa DMLC_REGISTRY_LINK_TAG
*/
#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; }
/*!
* \brief (Optional) Force link to all the objects registered in file tag.
*
* This macro must be used in the same file as DMLC_REGISTRY_ENABLE and
* in the same namespace as DMLC_REGISTRY_FILE_TAG
*
* DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration.
* They are used to encforce link of certain file into during static linking.
*
* This is mainly used to solve problem during statically link a library which contains backward registration.
* Specifically, this avoids the objects in these file tags to be ignored by compiler.
*
* For dynamic linking, this problem won't occur as everything is loaded by default.
*
* Use of this is optional as it will create an error when a file tag do not exist.
* An alternative solution is always ask user to enable --whole-archieve during static link.
*
* \begincode
* // in file objective_registry.cc
* DMLC_REGISTRY_ENABLE(MyObjective);
* DMLC_REGISTRY_LINK_TAG(regression_op);
* DMLC_REGISTRY_LINK_TAG(rank_op);
*
* // in file regression_op.cc
* // declare tag of this file.
* DMLC_REGISTRY_FILE_TAG(regression_op);
* DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg);
* // ...
*
* // in file rank_op.cc
* // declare tag of this file.
* DMLC_REGISTRY_FILE_TAG(rank_op);
* DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank);
*
* \endcode
*
* \param UniqueTag The unique tag used to represent.
* \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Portable thread local storage.
*/
#ifndef DMLC_THREAD_LOCAL_H_
#define DMLC_THREAD_LOCAL_H_
#include <mutex>
#include <memory>
#include <vector>
namespace dmlc {
// macro hanlding for threadlocal variables
#ifdef __GNUC__
#define MX_TREAD_LOCAL __thread
#elif __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#endif
#ifndef MX_TREAD_LOCAL
#message("Warning: Threadlocal is not enabled");
#endif
/*!
* \brief A threadlocal store to store threadlocal variables.
* Will return a thread local singleton of type T
* \tparam T the type we like to store
*/
template<typename T>
class ThreadLocalStore {
public:
/*! \return get a thread local singleton */
static T* Get() {
static MX_TREAD_LOCAL T* ptr = nullptr;
if (ptr == nullptr) {
ptr = new T();
Singleton()->RegisterDelete(ptr);
}
return ptr;
}
private:
/*! \brief constructor */
ThreadLocalStore() {}
/*! \brief destructor */
~ThreadLocalStore() {
for (size_t i = 0; i < data_.size(); ++i) {
delete data_[i];
}
}
/*! \return singleton of the store */
static ThreadLocalStore<T> *Singleton() {
static ThreadLocalStore<T> inst;
return &inst;
}
/*!
* \brief register str for internal deletion
* \param str the string pointer
*/
void RegisterDelete(T *str) {
std::unique_lock<std::mutex> lock(mutex_);
data_.push_back(str);
lock.unlock();
}
/*! \brief internal mutex */
std::mutex mutex_;
/*!\brief internal data */
std::vector<T*> data_;
};
} // namespace dmlc
#endif // DMLC_THREAD_LOCAL_H_
/*!
* Copyright (c) 2015 by Contributors
* \file timer.h
* \brief cross platform timer for timing
* \author Tianqi Chen
*/
#ifndef DMLC_TIMER_H_
#define DMLC_TIMER_H_
#include "base.h"
#if DMLC_USE_CXX11
#include <chrono>
#endif
#include <time.h>
#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
#endif
#include "./logging.h"
namespace dmlc {
/*!
* \brief return time in seconds
*/
inline double GetTime(void) {
#if DMLC_USE_CXX11
return std::chrono::duration<double>(
std::chrono::high_resolution_clock::now().time_since_epoch()).count();
#elif defined __MACH__
clock_serv_t cclock;
mach_timespec_t mts;
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time";
mach_port_deallocate(mach_task_self(), cclock);
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
#else
#if defined(__unix__) || defined(__linux__)
timespec ts;
CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time";
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
#else
return static_cast<double>(time(NULL));
#endif
#endif
}
} // namespace dmlc
#endif // DMLC_TIMER_H_
/*!
* Copyright (c) 2015 by Contributors
* \file type_traits.h
* \brief type traits information header
*/
#ifndef DMLC_TYPE_TRAITS_H_
#define DMLC_TYPE_TRAITS_H_
#include "./base.h"
#if DMLC_USE_CXX11
#include <type_traits>
#endif
#include <string>
namespace dmlc {
/*!
* \brief whether a type is pod type
* \tparam T the type to query
*/
template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_pod<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is integer type
* \tparam T the type to query
*/
template<typename T>
struct is_integral {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_integral<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is floating point type
* \tparam T the type to query
*/
template<typename T>
struct is_floating_point {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_floating_point<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is arithemetic type
* \tparam T the type to query
*/
template<typename T>
struct is_arithmetic {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_arithmetic<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = (dmlc::is_integral<T>::value ||
dmlc::is_floating_point<T>::value);
#endif
};
/*!
* \brief the string representation of type name
* \tparam T the type to query
* \return a const string of typename.
*/
template<typename T>
inline const char* type_name() {
return "";
}
/*!
* \brief whether a type have save/load function
* \tparam T the type to query
*/
template<typename T>
struct has_saveload {
/*! \brief the value of the traits */
static const bool value = false;
};
/*!
* \brief template to select type based on condition
* For example, IfThenElseType<true, int, float>::Type will give int
* \tparam cond the condition
* \tparam Then the typename to be returned if cond is true
* \tparam The typename to be returned if cond is false
*/
template<bool cond, typename Then, typename Else>
struct IfThenElseType;
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \
template<> \
struct Trait<Type> { \
static const bool value = Value; \
}
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TYPE_NAME(Type, Name) \
template<> \
inline const char* type_name<Type>() { \
return Name; \
}
//! \cond Doxygen_Suppress
// declare special traits when C++11 is not available
#if DMLC_USE_CXX11 == 0
DMLC_DECLARE_TRAITS(is_pod, char, true);
DMLC_DECLARE_TRAITS(is_pod, int8_t, true);
DMLC_DECLARE_TRAITS(is_pod, int16_t, true);
DMLC_DECLARE_TRAITS(is_pod, int32_t, true);
DMLC_DECLARE_TRAITS(is_pod, int64_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint8_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint16_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint32_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);
DMLC_DECLARE_TRAITS(is_pod, float, true);
DMLC_DECLARE_TRAITS(is_pod, double, true);
DMLC_DECLARE_TRAITS(is_integral, char, true);
DMLC_DECLARE_TRAITS(is_integral, int8_t, true);
DMLC_DECLARE_TRAITS(is_integral, int16_t, true);
DMLC_DECLARE_TRAITS(is_integral, int32_t, true);
DMLC_DECLARE_TRAITS(is_integral, int64_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint8_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint16_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint32_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint64_t, true);
DMLC_DECLARE_TRAITS(is_floating_point, float, true);
DMLC_DECLARE_TRAITS(is_floating_point, double, true);
#endif
DMLC_DECLARE_TYPE_NAME(float, "float");
DMLC_DECLARE_TYPE_NAME(double, "double");
DMLC_DECLARE_TYPE_NAME(int, "int");
DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)");
DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)");
DMLC_DECLARE_TYPE_NAME(std::string, "string");
DMLC_DECLARE_TYPE_NAME(bool, "boolean");
template<typename Then, typename Else>
struct IfThenElseType<true, Then, Else> {
typedef Then Type;
};
template<typename Then, typename Else>
struct IfThenElseType<false, Then, Else> {
typedef Else Type;
};
//! \endcond
} // namespace dmlc
#endif // DMLC_TYPE_TRAITS_H_
......@@ -9,6 +9,7 @@
#include <vector>
#include <type_traits>
#include <algorithm>
#include <utility>
#include <iostream>
#include "./base.h"
......
......@@ -98,12 +98,12 @@ class SymbolBase(object):
**kwargs
The attributes to set
"""
keys = _base.c_array(_ctypes.c_char_p,
[_base.c_str(key) for key in kwargs.keys()])
vals = _base.c_array(_ctypes.c_char_p,
[_base.c_str(str(val)) for val in kwargs.values()])
num_args = _base.nn_uint(len(kwargs))
_check_call(_LIB.NNSymbolSetAttrs(
keys = c_array(ctypes.c_char_p,
[c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p,
[c_str(str(val)) for val in kwargs.values()])
num_args = nn_uint(len(kwargs))
check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals))
......
......@@ -27,9 +27,9 @@ def find_lib_path():
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt':
dll_path = [os.path.join(p, 'libnnvm.dll') for p in dll_path]
dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path]
else:
dll_path = [os.path.join(p, 'libnnvm.so') for p in dll_path]
dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' +
......
import os
import sys
from distutils.core import setup
from Cython.Build import cythonize
from distutils.extension import Extension
def config_cython():
try:
from Cython.Build import cythonize
from distutils.extension import Extension
if sys.version_info >= (3, 0):
subdir = "_cy3"
else:
subdir = "_cy2"
ret = []
path = "nnvm/cython"
def config():
if sys.version_info >= (3, 0):
subdir = "_cy3"
else:
subdir = "_cy2"
ret = []
path = "nnvm/cython"
for fn in os.listdir(path):
if not fn.endswith(".pyx"):
continue
ret.append(Extension(
"nnvm/%s/%s" % (subdir, fn[:-4]),
["nnvm/cython/%s" % fn],
include_dirs=["../include/"],
language="c++"))
return ret
for fn in os.listdir(path):
if not fn.endswith(".pyx"):
continue
ret.append(Extension(
"nnvm/%s/%s" % (subdir, fn[:-4]),
["nnvm/cython/%s" % fn],
include_dirs=["../include/"],
language="c++"))
return cythonize(ret)
except:
print("Cython is not installed, will compile without cython module")
return []
setup(
name='nnvm',
ext_modules = cythonize(config())
ext_modules = config_cython()
)
// Copyright (c) 2016 by Contributors
#include <nnvm/op.h>
#include <nnvm/graph.h>
#include <nnvm/tuple.h>
#include <nnvm/c_api.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <dmlc/timer.h>
#include <string>
void test_speed() {
auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 1000;
std::unordered_map<std::string, const nnvm::Symbol*> tmp;
std::unordered_map<std::string, std::string> kwargs;
std::vector<const nnvm::Symbol*> vec{2};
std::string name = "xx";
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, kwargs);
vec[0] = &s;
vec[1] =&s;
tmp.clear();
nw.Compose(vec, tmp, name);
s = nw;
}
}
double tend = dmlc::GetTime();
LOG(INFO) << "compose speed = " << n * rep / (tend - tstart) << " ops/sec";
}
void test_node_speed() {
using namespace nnvm;
auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 1000;
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
auto xx = NodeEntry{Node::Create(), 0, 0};
NodeEntry x = s.outputs[0];
xx.node->op = add;
xx.node->inputs.emplace_back(x);
xx.node->inputs.emplace_back(x);
Symbol ss;
ss.outputs.push_back(xx);
s = ss;
}
}
double tend = dmlc::GetTime();
LOG(INFO) << "test_node_speed speed = " << n * rep / (tend - tstart) << " ops/sec";
}
void test_api_speed() {
auto add = (void*)nnvm::Op::Get("add"); // NOLINT(*)
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 1000;
std::unordered_map<std::string, const nnvm::Symbol*> tmp;
std::vector<const nnvm::Symbol*> vec{2};
std::string name = "xx";
for (size_t t = 0; t < rep; ++t) {
SymbolHandle s;
NNSymbolCreateVariable("xx", &s);
for (size_t i = 0; i < n; ++i) {
SymbolHandle arg[2];
SymbolHandle ss;
NNSymbolCreateAtomicSymbol(add, 0, nullptr, nullptr, &ss);
arg[0] = s;
arg[1] = s;
NNSymbolCompose(ss, "nn", 2, nullptr, arg);
s = ss;
}
}
double tend = dmlc::GetTime();
LOG(INFO) << "API compose speed = " << n * rep / (tend - tstart) << " ops/sec";
}
int main() {
test_speed();
test_node_speed();
test_api_speed();
return 0;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment