Commit 5cf08d6c by Tianqi Chen

[REFACTOR] copy DMLC headers, move operator to example (#20)

parent 0538a9fc
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\ export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC -Iinclude -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all test lint doc cython cython3 cyclean .PHONY: clean all test lint doc cython cython3 cyclean
all: lib/libnnvm.so lib/libnnvm.a cli_test all: lib/libnnvm.a lib/libnnvm_example.so
SRC = $(wildcard src/*.cc src/*/*.cc example/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(filter-out build/test_main.o, $(ALL_OBJ)) ALL_DEP = $(ALL_OBJ)
include tests/cpp/unittest.mk include tests/cpp/unittest.mk
...@@ -20,16 +20,14 @@ build/%.o: src/%.cc ...@@ -20,16 +20,14 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@ $(CXX) -c $(CFLAGS) -c $< -o $@
lib/libnnvm.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libnnvm.a: $(ALL_DEP) lib/libnnvm.a: $(ALL_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
ar crv $@ $(filter %.o, $?) ar crv $@ $(filter %.o, $?)
cli_test: $(ALL_DEP) build/test_main.o lib/libnnvm_example.so: example/src/operator.cc lib/libnnvm.a
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cc, $^) $(LDFLAGS) -Wl,--whole-archive lib/libnnvm.a -Wl,--no-whole-archive
cython: cython:
cd python; python setup.py build_ext --inplace cd python; python setup.py build_ext --inplace
......
...@@ -96,7 +96,7 @@ NNVM_REGISTER_OP(__add_symbol__) ...@@ -96,7 +96,7 @@ NNVM_REGISTER_OP(__add_symbol__)
.set_num_inputs(2); .set_num_inputs(2);
NNVM_REGISTER_OP(exp) NNVM_REGISTER_OP(exp)
.describe("take exponmential") .describe("take exponential")
.set_num_inputs(1) .set_num_inputs(1)
.attr("inplace_pair", std::make_pair(0, 0)) .attr("inplace_pair", std::make_pair(0, 0))
.attr<FInferShape>("FInferShape", SameShape); .attr<FInferShape>("FInferShape", SameShape);
......
This folder is synced from dmlc-core/include/dmlc
Contains useful utility headers for the project.
/*!
* Copyright (c) 2016 by Contributors
* \file any.h
* \brief Container to hold any data type.
*/
#ifndef DMLC_ANY_H_
#define DMLC_ANY_H_
// This code need c++11 to compile
#include <typeinfo>
#include <type_traits>
#include <utility>
#include <algorithm>
#include "./base.h"
#include "./logging.h"
namespace dmlc {
// forward declare any;
class any;
/*!
* Get a reference to content stored in the any as type T.
* This will cause an error if
* T does not match the type stored.
* This function is not part of std::any standard.
*
* \param src The source source any container.
* \return The reference of content
* \tparam T The type of the value to be fetched.
*/
template<typename T>
inline T& get(any& src); // NOLINT(*)
/*!
* Get the const reference content stored in the any as type T.
* This will cause an error if
* T does not match the type stored.
* This function is not part of std::any standard.
*
* \param src The source source any container.
* \return The reference of content
* \tparam T The type of the value to be fetched.
*/
template<typename T>
inline const T& get(const any& src);
/*!
* \brief An any class that is compatible to std::any in c++17.
*
* \code
* dmlc::any a = std::string("mydear"), b = 1;
* // get reference out and add it
* dmlc::get<int>(b) += 1;
* // a is now string
* LOG(INFO) << dmlc::get<std::string>(a);
* // a is now 2, the string stored will be properly destructed
* a = std::move(b);
* LOG(INFO) << dmlc::get<int>(a);
* \endcode
* \sa get
*/
class any {
public:
/*! \brief default constructor */
inline any() = default;
/*!
* \brief move constructor from another any
* \param other The other any to be moved
*/
inline any(any&& other); // NOLINT(*)
/*!
* \brief copy constructor
* \param other The other any to be copied
*/
inline any(const any& other); // NOLINT(*)
/*!
* \brief constructor from any types
* \param other The other types to be constructed into any.
* \tparam T The value type of other.
*/
template<typename T>
inline any(T&& other); // NOLINT(*)
/*! \brief destructor */
inline ~any();
/*!
* \brief assign operator from other
* \param other The other any to be copy or moved.
* \return self
*/
inline any& operator=(any&& other);
/*!
* \brief assign operator from other
* \param other The other any to be copy or moved.
* \return self
*/
inline any& operator=(const any& other);
/*!
* \brief assign operator from any type.
* \param other The other any to be copy or moved.
* \tparam T The value type of other.
* \return self
*/
template<typename T>
inline any& operator=(T&& other);
/*!
* \return whether the container is empty.
*/
inline bool empty() const;
/*!
* \return clear the content of container
*/
inline void clear();
/*!
* swap current content with other
* \param other The other data to be swapped.
*/
inline void swap(any& other); // NOLINT(*)
/*!
* \return The type_info about the stored type.
*/
inline const std::type_info& type() const;
private:
//! \cond Doxygen_Suppress
// declare of helper class
template<typename T>
class TypeOnHeap;
template<typename T>
class TypeOnStack;
template<typename T>
class TypeInfo;
// size of stack space, it takes 32 bytes for one any type.
static const size_t kStack = sizeof(void*) * 3;
static const size_t kAlign = sizeof(void*);
// container use dynamic storage only when space runs lager
union Data {
// stack space
std::aligned_storage<kStack, kAlign>::type stack;
// pointer to heap space
void* pheap;
};
// type specific information
struct Type {
// destructor function
void (*destroy)(Data* data);
// copy constructor
void (*create_from_data)(Data* dst, const Data& src);
// the type info function
const std::type_info* ptype_info;
};
// constant to check if data can be stored on heap.
template<typename T>
struct data_on_stack {
static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack;
};
// declare friend with
template<typename T>
friend T& get(any& src); // NOLINT(*)
template<typename T>
friend const T& get(const any& src);
// internal construct function
inline void construct(any&& other);
// internal construct function
inline void construct(const any& other);
// internal function to check if type is correct.
template<typename T>
inline void check_type() const;
// internal type specific information
const Type* type_{nullptr};
// internal data
Data data_;
};
template<typename T>
inline any::any(T&& other) {
typedef typename std::decay<T>::type DT;
if (std::is_same<DT, any>::value) {
this->construct(std::forward<T>(other));
} else {
static_assert(std::is_copy_constructible<DT>::value,
"Any can only hold value that is copy constructable");
type_ = TypeInfo<DT>::get_type();
if (data_on_stack<DT>::value) {
new (&(data_.stack)) DT(std::forward<T>(other));
} else {
data_.pheap = new DT(std::forward<T>(other));
}
}
}
inline any::any(any&& other) {
this->construct(std::move(other));
}
inline any::any(const any& other) {
this->construct(other);
}
inline void any::construct(any&& other) {
type_ = other.type_;
data_ = other.data_;
other.type_ = nullptr;
}
inline void any::construct(const any& other) {
type_ = other.type_;
if (type_ != nullptr) {
type_->create_from_data(&data_, other.data_);
}
}
inline any::~any() {
this->clear();
}
inline any& any::operator=(any&& other) {
any(std::move(other)).swap(*this);
return *this;
}
inline any& any::operator=(const any& other) {
any(other).swap(*this);
return *this;
}
template<typename T>
inline any& any::operator=(T&& other) {
any(std::forward<T>(other)).swap(*this);
return *this;
}
inline void any::swap(any& other) { // NOLINT(*)
std::swap(type_, other.type_);
std::swap(data_, other.data_);
}
inline void any::clear() {
if (type_ != nullptr) {
if (type_->destroy != nullptr) {
type_->destroy(&data_);
}
type_ = nullptr;
}
}
inline bool any::empty() const {
return type_ == nullptr;
}
inline const std::type_info& any::type() const {
if (type_ != nullptr) {
return *(type_->ptype_info);
} else {
return typeid(void);
}
}
template<typename T>
inline void any::check_type() const {
CHECK(type_ != nullptr)
<< "The any container is empty";
CHECK(type_->ptype_info == &typeid(T))
<< "The stored type mismatch"
<< " stored=" << type_->ptype_info->name()
<< " requested=" << typeid(T).name();
}
template<typename T>
inline const T& get(const any& src) {
src.check_type<T>();
return *any::TypeInfo<T>::get_ptr(&(src.data_));
}
template<typename T>
inline T& get(any& src) { // NOLINT(*)
src.check_type<T>();
return *any::TypeInfo<T>::get_ptr(&(src.data_));
}
template<typename T>
class any::TypeOnHeap {
public:
inline static T* get_ptr(any::Data* data) {
return static_cast<T*>(data->pheap);
}
inline static const T* get_ptr(const any::Data* data) {
return static_cast<const T*>(data->pheap);
}
inline static void create_from_data(any::Data* dst, const any::Data& data) {
dst->pheap = new T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
delete static_cast<T*>(data->pheap);
}
};
template<typename T>
class any::TypeOnStack {
public:
inline static T* get_ptr(any::Data* data) {
return reinterpret_cast<T*>(&(data->stack));
}
inline static const T* get_ptr(const any::Data* data) {
return reinterpret_cast<const T*>(&(data->stack));
}
inline static void create_from_data(any::Data* dst, const any::Data& data) {
new (&(dst->stack)) T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
T* dptr = reinterpret_cast<T*>(&(data->stack));
dptr->~T();
}
};
template<typename T>
class any::TypeInfo
: public std::conditional<any::data_on_stack<T>::value,
any::TypeOnStack<T>,
any::TypeOnHeap<T> >::type {
public:
inline static const Type* get_type() {
static TypeInfo<T> tp;
return &(tp.type_);
}
private:
// local type
Type type_;
// constructor
TypeInfo() {
if (std::is_pod<T>::value) {
type_.destroy = nullptr;
} else {
type_.destroy = TypeInfo<T>::destroy;
}
type_.create_from_data = TypeInfo<T>::create_from_data;
type_.ptype_info = &typeid(T);
}
};
//! \endcond
} // namespace dmlc
#endif // DMLC_ANY_H_
/*!
* Copyright (c) 2016 by Contributors
* \file array_view.h
* \brief Read only data structure to reference array
*/
#ifndef DMLC_ARRAY_VIEW_H_
#define DMLC_ARRAY_VIEW_H_
#include <vector>
#include <array>
namespace dmlc {
/*!
* \brief Read only data structure to reference continuous memory region of array.
* Provide unified view for vector, array and C style array.
* This data structure do not guarantee aliveness of referenced array.
*
* Make sure do not use array_view to record data in async function closures.
* Also do not use array_view to create reference to temporary data structure.
*
* \tparam ValueType The value
*
* \code
* std::vector<int> myvec{1,2,3};
* dmlc::array_view<int> view(myvec);
* // indexed visit to the view.
* LOG(INFO) << view[0];
*
* for (int v : view) {
* // visit each element in the view
* }
* \endcode
*/
template<typename ValueType>
class array_view {
public:
/*! \brief default constructor */
array_view() = default;
/*!
* \brief default copy constructor
* \param other another array view.
*/
array_view(const array_view<ValueType> &other) = default; // NOLINT(*)
/*!
* \brief default move constructor
* \param other another array view.
*/
array_view(array_view<ValueType>&& other) = default; // NOLINT(*)
/*!
* \brief default assign constructor
* \param other another array view.
* \return self.
*/
array_view<ValueType>& operator=(const array_view<ValueType>& other) = default; // NOLINT(*)
/*!
* \brief construct array view std::vector
* \param other vector container
*/
array_view(const std::vector<ValueType>& other) { // NOLINT(*)
if (other.size() != 0) {
begin_ = &other[0]; size_ = other.size();
}
}
/*!
* \brief construct array std::array
* \param other another array view.
*/
template<std::size_t size>
array_view(const std::array<ValueType, size>& other) { // NOLINT(*)
if (size != 0) {
begin_ = &other[0]; size_ = size;
}
}
/*!
* \brief construct array view from continuous segment
* \param begin beginning pointre
* \param end end pointer
*/
array_view(const ValueType* begin, const ValueType* end) {
if (begin < end) {
begin_ = begin;
size_ = end - begin;
}
}
/*! \return size of the array */
inline size_t size() const {
return size_;
}
/*! \return begin of the array */
inline const ValueType* begin() const {
return begin_;
}
/*! \return end point of the array */
inline const ValueType* end() const {
return begin_ + size_;
}
/*!
* \brief get i-th element from the view
* \param i The index.
* \return const reference to i-th element.
*/
inline const ValueType& operator[](size_t i) const {
return begin_[i];
}
private:
/*! \brief the begin of the view */
const ValueType* begin_{nullptr};
/*! \brief The size of the view */
size_t size_{0};
};
} // namespace dmlc
#endif // DMLC_ARRAY_VIEW_H_
/*!
* Copyright (c) 2015 by Contributors
* \file base.h
* \brief defines configuration macros
*/
#ifndef DMLC_BASE_H_
#define DMLC_BASE_H_
/*! \brief whether use glog for logging */
#ifndef DMLC_USE_GLOG
#define DMLC_USE_GLOG 0
#endif
/*!
* \brief whether throw dmlc::Error instead of
* directly calling abort when FATAL error occured
* NOTE: this may still not be perfect.
* do not use FATAL and CHECK in destructors
*/
#ifndef DMLC_LOG_FATAL_THROW
#define DMLC_LOG_FATAL_THROW 1
#endif
/*!
* \brief whether always log a message before throw
* This can help identify the error that cannot be catched.
*/
#ifndef DMLC_LOG_BEFORE_THROW
#define DMLC_LOG_BEFORE_THROW 1
#endif
/*!
* \brief Whether to use customized logger,
* whose output can be decided by other libraries.
*/
#ifndef DMLC_LOG_CUSTOMIZE
#define DMLC_LOG_CUSTOMIZE 0
#endif
/*! \brief whether compile with hdfs support */
#ifndef DMLC_USE_HDFS
#define DMLC_USE_HDFS 0
#endif
/*! \brief whether compile with s3 support */
#ifndef DMLC_USE_S3
#define DMLC_USE_S3 0
#endif
/*! \brief whether or not use parameter server */
#ifndef DMLC_USE_PS
#define DMLC_USE_PS 0
#endif
/*! \brief whether or not use c++11 support */
#ifndef DMLC_USE_CXX11
#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
#pragma message("Will need g++-4.6 or higher to compile all" \
"the features in dmlc-core, " \
"compile without c++0x, some features may be disabled")
#undef DMLC_USE_CXX11
#define DMLC_USE_CXX11 0
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
*/
#ifndef DMLC_ENABLE_STD_THREAD
#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11
#endif
/*! \brief whether enable regex support, actually need g++-4.9 or higher*/
#ifndef DMLC_USE_REGEX
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
/*!
* \brief Disable copy constructor and assignment operator.
*
* If C++11 is supported, both copy and move constructors and
* assignment operators are deleted explicitly. Otherwise, they are
* only declared but not implemented. Place this macro in private
* section if C++11 is not available.
*/
#ifndef DISALLOW_COPY_AND_ASSIGN
# if DMLC_USE_CXX11
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&) = delete; \
T(T&&) = delete; \
T& operator=(T const&) = delete; \
T& operator=(T&&) = delete
# else
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&); \
T& operator=(T const&)
# endif
#endif
///
/// code block to handle optionally loading
///
#if !defined(__GNUC__)
#define fopen64 std::fopen
#endif
#if (defined __MINGW32__) && !(defined __MINGW64__)
#define fopen64 std::fopen
#endif
#ifdef _MSC_VER
#if _MSC_VER < 1900
// NOTE: sprintf_s is not equivalent to snprintf,
// they are equivalent when success, which is sufficient for our case
#define snprintf sprintf_s
#define vsnprintf vsprintf_s
#endif
#else
#ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif
#endif
#ifdef __APPLE__
#define off64_t off_t
#define fopen64 std::fopen
#endif
extern "C" {
#include <sys/types.h>
}
#endif
#ifdef _MSC_VER
//! \cond Doxygen_Suppress
typedef signed char int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
//! \endcond
#else
#include <inttypes.h>
#endif
#include <string>
#include <vector>
#if defined(_MSC_VER) && _MSC_VER < 1900
#define noexcept_true throw ()
#define noexcept_false
#define noexcept(a) noexcept_##a
#endif
#if DMLC_USE_CXX11
#define DMLC_THROW_EXCEPTION noexcept(false)
#define DMLC_NO_EXCEPTION noexcept(true)
#else
#define DMLC_THROW_EXCEPTION
#define DMLC_NO_EXCEPTION
#endif
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
* \brief safely get the beginning address of a vector
* \param vec input vector
* \return beginning address of a vector
*/
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*!
* \brief get the beginning address of a vector
* \param vec input vector
* \return beginning address of a vector
*/
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*!
* \brief get the beginning address of a vector
* \param str input string
* \return beginning address of a string
*/
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return NULL;
return &str[0];
}
/*!
* \brief get the beginning address of a vector
* \param str input string
* \return beginning address of a string
*/
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL;
return &str[0];
}
} // namespace dmlc
#if defined(_MSC_VER) && _MSC_VER < 1900
#define constexpr const
#define alignof __alignof
#endif
#endif // DMLC_BASE_H_
/*!
* Copyright (c) 2015 by Contributors
* \file logging.h
* \brief defines logging macros of dmlc
* allows use of GLOG, fall back to internal
* implementation when disabled
*/
#ifndef DMLC_LOGGING_H_
#define DMLC_LOGGING_H_
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
#include <stdexcept>
#include "./base.h"
namespace dmlc {
/*!
* \brief exception class that will be thrown by
* default logger if DMLC_LOG_FATAL_THROW == 1
*/
struct Error : public std::runtime_error {
/*!
* \brief constructor
* \param s the error message
*/
explicit Error(const std::string &s) : std::runtime_error(s) {}
};
} // namespace dmlc
#if DMLC_USE_GLOG
#include <glog/logging.h>
namespace dmlc {
/*!
* \brief optionally redirect to google's init log
* \param argv0 The arguments.
*/
inline void InitLogging(const char* argv0) {
google::InitGoogleLogging(argv0);
}
} // namespace dmlc
#else
// use a light version of glog
#include <assert.h>
#include <iostream>
#include <sstream>
#include <ctime>
#if defined(_MSC_VER)
#pragma warning(disable : 4722)
#endif
namespace dmlc {
inline void InitLogging(const char* argv0) {
// DO NOTHING
}
// Always-on checking
#define CHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \
"failed: " #x << ' '
#define CHECK_LT(x, y) CHECK((x) < (y))
#define CHECK_GT(x, y) CHECK((x) > (y))
#define CHECK_LE(x, y) CHECK((x) <= (y))
#define CHECK_GE(x, y) CHECK((x) >= (y))
#define CHECK_EQ(x, y) CHECK((x) == (y))
#define CHECK_NE(x, y) CHECK((x) != (y))
#define CHECK_NOTNULL(x) \
((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*)
// Debug-only checking.
#ifdef NDEBUG
#define DCHECK(x) \
while (false) CHECK(x)
#define DCHECK_LT(x, y) \
while (false) CHECK((x) < (y))
#define DCHECK_GT(x, y) \
while (false) CHECK((x) > (y))
#define DCHECK_LE(x, y) \
while (false) CHECK((x) <= (y))
#define DCHECK_GE(x, y) \
while (false) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) \
while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
while (false) CHECK((x) != (y))
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif // NDEBUG
#if DMLC_LOG_CUSTOMIZE
#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__)
#else
#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__)
#endif
#define LOG_ERROR LOG_INFO
#define LOG_WARNING LOG_INFO
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__)
#define LOG_QFATAL LOG_FATAL
// Poor man version of VLOG
#define VLOG(x) LOG_INFO.stream()
#define LOG(severity) LOG_##severity.stream()
#define LG LOG_INFO.stream()
#define LOG_IF(severity, condition) \
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#ifdef NDEBUG
#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#else
#define LOG_DFATAL LOG_FATAL
#define DFATAL FATAL
#define DLOG(severity) LOG(severity)
#define DLOG_IF(severity, condition) LOG_IF(severity, condition)
#endif
// Poor man version of LOG_EVERY_N
#define LOG_EVERY_N(severity, n) LOG(severity)
class DateLogger {
public:
DateLogger() {
#if defined(_MSC_VER)
_tzset();
#endif
}
const char* HumanDate() {
#if defined(_MSC_VER)
_strtime_s(buffer_, sizeof(buffer_));
#else
time_t time_value = time(NULL);
struct tm *pnow;
#if !defined(_WIN32)
struct tm now;
pnow = localtime_r(&time_value, &now);
#else
pnow = localtime(&time_value); // NOLINT(*)
#endif
snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d",
pnow->tm_hour, pnow->tm_min, pnow->tm_sec);
#endif
return buffer_;
}
private:
char buffer_[9];
};
class LogMessage {
public:
LogMessage(const char* file, int line)
:
#ifdef __ANDROID__
log_stream_(std::cout)
#else
log_stream_(std::cerr)
#endif
{
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
~LogMessage() { log_stream_ << '\n'; }
std::ostream& stream() { return log_stream_; }
protected:
std::ostream& log_stream_;
private:
DateLogger pretty_date_;
LogMessage(const LogMessage&);
void operator=(const LogMessage&);
};
// customized logger that can allow user to define where to log the message.
class CustomLogMessage {
public:
CustomLogMessage(const char* file, int line) {
log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":"
<< line << ": ";
}
~CustomLogMessage() {
Log(log_stream_.str());
}
std::ostream& stream() { return log_stream_; }
/*!
* \brief customized logging of the message.
* This function won't be implemented by libdmlc
* \param msg The message to be logged.
*/
static void Log(const std::string& msg);
private:
std::ostringstream log_stream_;
};
#if DMLC_LOG_FATAL_THROW == 0
class LogMessageFatal : public LogMessage {
public:
LogMessageFatal(const char* file, int line) : LogMessage(file, line) {}
~LogMessageFatal() {
log_stream_ << "\n";
abort();
}
private:
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#else
class LogMessageFatal {
public:
LogMessageFatal(const char* file, int line) {
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
std::ostringstream &stream() { return log_stream_; }
~LogMessageFatal() DMLC_THROW_EXCEPTION {
// throwing out of destructor is evil
// hopefully we can do it here
// also log the message before throw
#if DMLC_LOG_BEFORE_THROW
LOG(ERROR) << log_stream_.str();
#endif
throw Error(log_stream_.str());
}
private:
std::ostringstream log_stream_;
DateLogger pretty_date_;
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#endif
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class LogMessageVoidify {
public:
LogMessageVoidify() {}
// This has to be an operator with a precedence lower than << but
// higher than "?:". See its usage.
void operator&(std::ostream&) {}
};
} // namespace dmlc
#endif
#endif // DMLC_LOGGING_H_
/*!
* Copyright (c) 2015 by Contributors
* \file memory.h
* \brief Additional memory hanlding utilities.
*/
#ifndef DMLC_MEMORY_H_
#define DMLC_MEMORY_H_
#include <vector>
#include "./base.h"
#include "./logging.h"
#include "./thread_local.h"
namespace dmlc {
/*!
* \brief A memory pool that allocate memory of fixed size and alignment.
* \tparam size The size of each piece.
* \tparam align The alignment requirement of the memory.
*/
template<size_t size, size_t align>
class MemoryPool {
public:
/*! \brief constructor */
MemoryPool() {
static_assert(align % alignof(LinkedList) == 0,
"alignment requirement failed.");
curr_page_.reset(new Page());
}
/*! \brief allocate a new memory of size */
inline void* allocate() {
if (head_ != nullptr) {
LinkedList* ret = head_;
head_ = head_->next;
return ret;
} else {
if (page_ptr_ < kPageSize) {
return &(curr_page_->data[page_ptr_++]);
} else {
allocated_.push_back(std::move(curr_page_));
curr_page_.reset(new Page());
page_ptr_ = 1;
return &(curr_page_->data[0]);
}
}
}
/*!
* \brief deallocate a piece of memory
* \param p The pointer to the memory to be de-allocated.
*/
inline void deallocate(void* p) {
LinkedList* ptr = static_cast<LinkedList*>(p);
ptr->next = head_;
head_ = ptr;
}
private:
// page size of each member
static const int kPageSize = ((1 << 22) / size);
// page to be requested.
struct Page {
typename std::aligned_storage<size, align>::type data[kPageSize];
};
// internal linked list structure.
struct LinkedList {
LinkedList* next{nullptr};
};
// head of free list
LinkedList* head_{nullptr};
// current free page
std::unique_ptr<Page> curr_page_;
// pointer to the current free page position.
size_t page_ptr_{0};
// allocated pages.
std::vector<std::unique_ptr<Page> > allocated_;
};
/*!
* \brief A thread local allocator that get memory from a threadlocal memory pool.
* This is suitable to allocate objects that do not cross thread.
* \tparam T the type of the data to be allocated.
*/
template<typename T>
class ThreadlocalAllocator {
public:
/*! \brief pointer type */
typedef T* pointer;
/*! \brief const pointer type */
typedef const T* const_ptr;
/*! \brief value type */
typedef T value_type;
/*! \brief default constructor */
ThreadlocalAllocator() {}
/*!
* \brief constructor from another allocator
* \param other another allocator
* \tparam U another type
*/
template<typename U>
ThreadlocalAllocator(const ThreadlocalAllocator<U>& other) {}
/*!
* \brief allocate memory
* \param n number of blocks
* \return an uninitialized memory of type T.
*/
inline T* allocate(size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
return static_cast<T*>(Store::Get()->allocate());
}
/*!
* \brief deallocate memory
* \param p a memory to be returned.
* \param n number of blocks
*/
inline void deallocate(T* p, size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
Store::Get()->deallocate(p);
}
};
/*!
* \brief a shared pointer like type that allocate object
* from a threadlocal object pool. This object is not thread-safe
* but can be faster than shared_ptr in certain usecases.
* \tparam T the data type.
*/
template<typename T>
struct ThreadlocalSharedPtr {
public:
/*! \brief default constructor */
ThreadlocalSharedPtr() : block_(nullptr) {}
/*!
* \brief constructor from nullptr
* \param other the nullptr type
*/
ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*)
/*!
* \brief copy constructor
* \param other another pointer.
*/
ThreadlocalSharedPtr(const ThreadlocalSharedPtr<T>& other)
: block_(other.block_) {
IncRef(block_);
}
/*!
* \brief move constructor
* \param other another pointer.
*/
ThreadlocalSharedPtr(ThreadlocalSharedPtr<T>&& other)
: block_(other.block_) {
other.block_ = nullptr;
}
/*!
* \brief destructor
*/
~ThreadlocalSharedPtr() {
DecRef(block_);
}
/*!
* \brief move assignment
* \param other another object to be assigned.
* \return self.
*/
inline ThreadlocalSharedPtr<T>& operator=(ThreadlocalSharedPtr<T>&& other) {
DecRef(block_);
block_ = other.block_;
other.block_ = nullptr;
return *this;
}
/*!
* \brief copy assignment
* \param other another object to be assigned.
* \return self.
*/
inline ThreadlocalSharedPtr<T> &operator=(const ThreadlocalSharedPtr<T>& other) {
DecRef(block_);
block_ = other.block_;
IncRef(block_);
return *this;
}
/*! \brief check if nullptr */
inline bool operator==(std::nullptr_t other) const {
return block_ == nullptr;
}
/*!
* \return get the pointer content.
*/
inline T* get() const {
if (block_ == nullptr) return nullptr;
return reinterpret_cast<T*>(&(block_->data));
}
/*!
* \brief reset the pointer to nullptr.
*/
inline void reset() {
DecRef(block_);
block_ = nullptr;
}
/*! \return if use_count == 1*/
inline bool unique() const {
if (block_ == nullptr) return false;
return block_->use_count_ == 1;
}
/*! \return dereference pointer */
inline T* operator*() const {
return reinterpret_cast<T*>(&(block_->data));
}
/*! \return dereference pointer */
inline T* operator->() const {
return reinterpret_cast<T*>(&(block_->data));
}
/*!
* \brief create a new space from threadlocal storage and return it.
* \tparam Args the arguments.
* \param args The input argument
* \return the allocated pointer.
*/
template <typename... Args>
inline static ThreadlocalSharedPtr<T> Create(Args&&... args) {
ThreadlocalAllocator<RefBlock> arena;
ThreadlocalSharedPtr<T> p;
p.block_ = arena.allocate(1);
p.block_->use_count_ = 1;
new (&(p.block_->data)) T(std::forward<Args>(args)...);
return p;
}
private:
// internal reference block
struct RefBlock {
typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
unsigned use_count_;
};
// decrease ref counter
inline static void DecRef(RefBlock* block) {
if (block != nullptr) {
if (--block->use_count_ == 0) {
ThreadlocalAllocator<RefBlock> arena;
T* dptr = reinterpret_cast<T*>(&(block->data));
dptr->~T();
arena.deallocate(block, 1);
}
}
}
// increase ref counter
inline static void IncRef(RefBlock* block) {
if (block != nullptr) {
++block->use_count_;
}
}
// internal block
RefBlock *block_;
};
} // namespace dmlc
#endif // DMLC_MEMORY_H_
/*!
* Copyright (c) 2015 by Contributors
* \file registry.h
* \brief Registry utility that helps to build registry singletons.
*/
#ifndef DMLC_REGISTRY_H_
#define DMLC_REGISTRY_H_
#include <map>
#include <string>
#include <vector>
#include "./base.h"
#include "./logging.h"
#include "./parameter.h"
#include "./type_traits.h"
namespace dmlc {
/*!
* \brief Registry class.
* Registry can be used to register global singletons.
* The most commonly use case are factory functions.
*
* \tparam EntryType Type of Registry entries,
* EntryType need to name a name field.
*/
template<typename EntryType>
class Registry {
public:
/*! \return list of functions in the registry */
inline static const std::vector<const EntryType*> &List() {
return Get()->entry_list_;
}
/*!
* \brief Find the entry with corresponding name.
* \param name name of the function
* \return the corresponding function, can be NULL
*/
inline static const EntryType *Find(const std::string &name) {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p = fmap.find(name);
if (p != fmap.end()) {
return p->second;
} else {
return NULL;
}
}
/*!
* \brief Internal function to register a name function under name.
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
inline EntryType &__REGISTER__(const std::string& name) {
CHECK_EQ(fmap_.count(name), 0)
<< name << " already registered";
EntryType *e = new EntryType();
e->name = name;
fmap_[name] = e;
entry_list_.push_back(e);
return *e;
}
/*!
* \brief Internal function to either register or get registered entry
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
inline EntryType &__REGISTER_OR_GET__(const std::string& name) {
if (fmap_.count(name) == 0) {
return __REGISTER__(name);
} else {
return *fmap_.at(name);
}
}
/*!
* \brief get a singleton of the Registry.
* This function can be defined by DMLC_ENABLE_REGISTRY.
* \return get a singleton
*/
static Registry *Get();
private:
/*! \brief list of entry types */
std::vector<const EntryType*> entry_list_;
/*! \brief map of name->function */
std::map<std::string, EntryType*> fmap_;
/*! \brief constructor */
Registry() {}
/*! \brief destructor */
~Registry() {
for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin();
p != fmap_.end(); ++p) {
delete p->second;
}
}
};
/*!
* \brief Common base class for function registry.
*
* \code
* // This example demonstrates how to use Registry to create a factory of trees.
* struct TreeFactory :
* public FunctionRegEntryBase<TreeFactory, std::function<Tree*()> > {
* };
*
* // in a independent cc file
* namespace dmlc {
* DMLC_REGISTRY_ENABLE(TreeFactory);
* }
* // register binary tree constructor into the registry.
* DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree)
* .describe("Constructor of BinaryTree")
* .set_body([]() { return new BinaryTree(); });
* \endcode
*
* \tparam EntryType The type of subclass that inheritate the base.
* \tparam FunctionType The function type this registry is registerd.
*/
template<typename EntryType, typename FunctionType>
class FunctionRegEntryBase {
public:
/*! \brief name of the entry */
std::string name;
/*! \brief description of the entry */
std::string description;
/*! \brief additional arguments to the factory function */
std::vector<ParamFieldInfo> arguments;
/*! \brief Function body to create ProductType */
FunctionType body;
/*! \brief Return type of the function */
std::string return_type;
/*!
* \brief Set the function body.
* \param body Function body to set.
* \return reference to self.
*/
inline EntryType &set_body(FunctionType body) {
this->body = body;
return this->self();
}
/*!
* \brief Describe the function.
* \param description The description of the factory function.
* \return reference to self.
*/
inline EntryType &describe(const std::string &description) {
this->description = description;
return this->self();
}
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline EntryType &add_argument(const std::string &name,
const std::string &type,
const std::string &description) {
ParamFieldInfo info;
info.name = name;
info.type = type;
info.type_info_str = info.type;
info.description = description;
arguments.push_back(info);
return this->self();
}
/*!
* \brief Append list if arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
inline EntryType &add_arguments(const std::vector<ParamFieldInfo> &args) {
arguments.insert(arguments.end(), args.begin(), args.end());
return this->self();
}
/*!
* \brief Set the return type.
* \param type Return type of the function, could be Symbol or Symbol[]
* \return reference to self.
*/
inline EntryType &set_return_type(const std::string &type) {
return_type = type;
return this->self();
}
protected:
/*!
* \return reference of self as derived type
*/
inline EntryType &self() {
return *(static_cast<EntryType*>(this));
}
};
/*!
* \def DMLC_REGISTRY_ENABLE
* \brief Macro to enable the registry of EntryType.
* This macro must be used under namespace dmlc, and only used once in cc file.
* \param EntryType Type of registry entry
*/
#define DMLC_REGISTRY_ENABLE(EntryType) \
template<> \
Registry<EntryType > *Registry<EntryType >::Get() { \
static Registry<EntryType > inst; \
return &inst; \
} \
/*!
* \brief Generic macro to register an EntryType
* There is a complete example in FactoryRegistryEntryBase.
*
* \param EntryType The type of registry entry.
* \param EntryTypeName The typename of EntryType, must do not contain namespace :: .
* \param Name The name to be registered.
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
* \brief (Optional) Declare a file tag to current file that contains object registrations.
*
* This will declare a dummy function that will be called by register file to
* incur a link dependency.
*
* \param UniqueTag The unique tag used to represent.
* \sa DMLC_REGISTRY_LINK_TAG
*/
#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; }
/*!
* \brief (Optional) Force link to all the objects registered in file tag.
*
* This macro must be used in the same file as DMLC_REGISTRY_ENABLE and
* in the same namespace as DMLC_REGISTRY_FILE_TAG
*
* DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration.
* They are used to encforce link of certain file into during static linking.
*
* This is mainly used to solve problem during statically link a library which contains backward registration.
* Specifically, this avoids the objects in these file tags to be ignored by compiler.
*
* For dynamic linking, this problem won't occur as everything is loaded by default.
*
* Use of this is optional as it will create an error when a file tag do not exist.
* An alternative solution is always ask user to enable --whole-archieve during static link.
*
* \begincode
* // in file objective_registry.cc
* DMLC_REGISTRY_ENABLE(MyObjective);
* DMLC_REGISTRY_LINK_TAG(regression_op);
* DMLC_REGISTRY_LINK_TAG(rank_op);
*
* // in file regression_op.cc
* // declare tag of this file.
* DMLC_REGISTRY_FILE_TAG(regression_op);
* DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg);
* // ...
*
* // in file rank_op.cc
* // declare tag of this file.
* DMLC_REGISTRY_FILE_TAG(rank_op);
* DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank);
*
* \endcode
*
* \param UniqueTag The unique tag used to represent.
* \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Portable thread local storage.
*/
#ifndef DMLC_THREAD_LOCAL_H_
#define DMLC_THREAD_LOCAL_H_
#include <mutex>
#include <memory>
#include <vector>
namespace dmlc {
// macro hanlding for threadlocal variables
#ifdef __GNUC__
#define MX_TREAD_LOCAL __thread
#elif __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#endif
#ifndef MX_TREAD_LOCAL
#message("Warning: Threadlocal is not enabled");
#endif
/*!
* \brief A threadlocal store to store threadlocal variables.
* Will return a thread local singleton of type T
* \tparam T the type we like to store
*/
template<typename T>
class ThreadLocalStore {
public:
/*! \return get a thread local singleton */
static T* Get() {
static MX_TREAD_LOCAL T* ptr = nullptr;
if (ptr == nullptr) {
ptr = new T();
Singleton()->RegisterDelete(ptr);
}
return ptr;
}
private:
/*! \brief constructor */
ThreadLocalStore() {}
/*! \brief destructor */
~ThreadLocalStore() {
for (size_t i = 0; i < data_.size(); ++i) {
delete data_[i];
}
}
/*! \return singleton of the store */
static ThreadLocalStore<T> *Singleton() {
static ThreadLocalStore<T> inst;
return &inst;
}
/*!
* \brief register str for internal deletion
* \param str the string pointer
*/
void RegisterDelete(T *str) {
std::unique_lock<std::mutex> lock(mutex_);
data_.push_back(str);
lock.unlock();
}
/*! \brief internal mutex */
std::mutex mutex_;
/*!\brief internal data */
std::vector<T*> data_;
};
} // namespace dmlc
#endif // DMLC_THREAD_LOCAL_H_
/*!
* Copyright (c) 2015 by Contributors
* \file timer.h
* \brief cross platform timer for timing
* \author Tianqi Chen
*/
#ifndef DMLC_TIMER_H_
#define DMLC_TIMER_H_
#include "base.h"
#if DMLC_USE_CXX11
#include <chrono>
#endif
#include <time.h>
#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
#endif
#include "./logging.h"
namespace dmlc {
/*!
* \brief return time in seconds
*/
inline double GetTime(void) {
#if DMLC_USE_CXX11
return std::chrono::duration<double>(
std::chrono::high_resolution_clock::now().time_since_epoch()).count();
#elif defined __MACH__
clock_serv_t cclock;
mach_timespec_t mts;
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time";
mach_port_deallocate(mach_task_self(), cclock);
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
#else
#if defined(__unix__) || defined(__linux__)
timespec ts;
CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time";
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
#else
return static_cast<double>(time(NULL));
#endif
#endif
}
} // namespace dmlc
#endif // DMLC_TIMER_H_
/*!
* Copyright (c) 2015 by Contributors
* \file type_traits.h
* \brief type traits information header
*/
#ifndef DMLC_TYPE_TRAITS_H_
#define DMLC_TYPE_TRAITS_H_
#include "./base.h"
#if DMLC_USE_CXX11
#include <type_traits>
#endif
#include <string>
namespace dmlc {
/*!
* \brief whether a type is pod type
* \tparam T the type to query
*/
template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_pod<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is integer type
* \tparam T the type to query
*/
template<typename T>
struct is_integral {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_integral<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is floating point type
* \tparam T the type to query
*/
template<typename T>
struct is_floating_point {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_floating_point<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is arithemetic type
* \tparam T the type to query
*/
template<typename T>
struct is_arithmetic {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_arithmetic<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = (dmlc::is_integral<T>::value ||
dmlc::is_floating_point<T>::value);
#endif
};
/*!
* \brief the string representation of type name
* \tparam T the type to query
* \return a const string of typename.
*/
template<typename T>
inline const char* type_name() {
return "";
}
/*!
* \brief whether a type have save/load function
* \tparam T the type to query
*/
template<typename T>
struct has_saveload {
/*! \brief the value of the traits */
static const bool value = false;
};
/*!
* \brief template to select type based on condition
* For example, IfThenElseType<true, int, float>::Type will give int
* \tparam cond the condition
* \tparam Then the typename to be returned if cond is true
* \tparam The typename to be returned if cond is false
*/
template<bool cond, typename Then, typename Else>
struct IfThenElseType;
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \
template<> \
struct Trait<Type> { \
static const bool value = Value; \
}
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TYPE_NAME(Type, Name) \
template<> \
inline const char* type_name<Type>() { \
return Name; \
}
//! \cond Doxygen_Suppress
// declare special traits when C++11 is not available
#if DMLC_USE_CXX11 == 0
DMLC_DECLARE_TRAITS(is_pod, char, true);
DMLC_DECLARE_TRAITS(is_pod, int8_t, true);
DMLC_DECLARE_TRAITS(is_pod, int16_t, true);
DMLC_DECLARE_TRAITS(is_pod, int32_t, true);
DMLC_DECLARE_TRAITS(is_pod, int64_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint8_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint16_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint32_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);
DMLC_DECLARE_TRAITS(is_pod, float, true);
DMLC_DECLARE_TRAITS(is_pod, double, true);
DMLC_DECLARE_TRAITS(is_integral, char, true);
DMLC_DECLARE_TRAITS(is_integral, int8_t, true);
DMLC_DECLARE_TRAITS(is_integral, int16_t, true);
DMLC_DECLARE_TRAITS(is_integral, int32_t, true);
DMLC_DECLARE_TRAITS(is_integral, int64_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint8_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint16_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint32_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint64_t, true);
DMLC_DECLARE_TRAITS(is_floating_point, float, true);
DMLC_DECLARE_TRAITS(is_floating_point, double, true);
#endif
DMLC_DECLARE_TYPE_NAME(float, "float");
DMLC_DECLARE_TYPE_NAME(double, "double");
DMLC_DECLARE_TYPE_NAME(int, "int");
DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)");
DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)");
DMLC_DECLARE_TYPE_NAME(std::string, "string");
DMLC_DECLARE_TYPE_NAME(bool, "boolean");
template<typename Then, typename Else>
struct IfThenElseType<true, Then, Else> {
typedef Then Type;
};
template<typename Then, typename Else>
struct IfThenElseType<false, Then, Else> {
typedef Else Type;
};
//! \endcond
} // namespace dmlc
#endif // DMLC_TYPE_TRAITS_H_
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <vector> #include <vector>
#include <type_traits> #include <type_traits>
#include <algorithm> #include <algorithm>
#include <utility>
#include <iostream> #include <iostream>
#include "./base.h" #include "./base.h"
......
...@@ -98,12 +98,12 @@ class SymbolBase(object): ...@@ -98,12 +98,12 @@ class SymbolBase(object):
**kwargs **kwargs
The attributes to set The attributes to set
""" """
keys = _base.c_array(_ctypes.c_char_p, keys = c_array(ctypes.c_char_p,
[_base.c_str(key) for key in kwargs.keys()]) [c_str(key) for key in kwargs.keys()])
vals = _base.c_array(_ctypes.c_char_p, vals = c_array(ctypes.c_char_p,
[_base.c_str(str(val)) for val in kwargs.values()]) [c_str(str(val)) for val in kwargs.values()])
num_args = _base.nn_uint(len(kwargs)) num_args = nn_uint(len(kwargs))
_check_call(_LIB.NNSymbolSetAttrs( check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals)) self.handle, num_args, keys, vals))
......
...@@ -27,9 +27,9 @@ def find_lib_path(): ...@@ -27,9 +27,9 @@ def find_lib_path():
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None): elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt': if os.name == 'nt':
dll_path = [os.path.join(p, 'libnnvm.dll') for p in dll_path] dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path]
else: else:
dll_path = [os.path.join(p, 'libnnvm.so') for p in dll_path] dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0: if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' + raise RuntimeError('Cannot find the files.\n' +
......
import os import os
import sys import sys
from distutils.core import setup from distutils.core import setup
from Cython.Build import cythonize
from distutils.extension import Extension
def config_cython():
try:
from Cython.Build import cythonize
from distutils.extension import Extension
if sys.version_info >= (3, 0):
subdir = "_cy3"
else:
subdir = "_cy2"
ret = []
path = "nnvm/cython"
def config(): for fn in os.listdir(path):
if sys.version_info >= (3, 0): if not fn.endswith(".pyx"):
subdir = "_cy3" continue
else: ret.append(Extension(
subdir = "_cy2" "nnvm/%s/%s" % (subdir, fn[:-4]),
ret = [] ["nnvm/cython/%s" % fn],
path = "nnvm/cython" include_dirs=["../include/"],
language="c++"))
for fn in os.listdir(path): return cythonize(ret)
if not fn.endswith(".pyx"): except:
continue print("Cython is not installed, will compile without cython module")
ret.append(Extension( return []
"nnvm/%s/%s" % (subdir, fn[:-4]),
["nnvm/cython/%s" % fn],
include_dirs=["../include/"],
language="c++"))
return ret
setup( setup(
name='nnvm', name='nnvm',
ext_modules = cythonize(config()) ext_modules = config_cython()
) )
// Copyright (c) 2016 by Contributors
#include <nnvm/op.h>
#include <nnvm/graph.h>
#include <nnvm/tuple.h>
#include <nnvm/c_api.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <dmlc/timer.h>
#include <string>
void test_speed() {
auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 1000;
std::unordered_map<std::string, const nnvm::Symbol*> tmp;
std::unordered_map<std::string, std::string> kwargs;
std::vector<const nnvm::Symbol*> vec{2};
std::string name = "xx";
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, kwargs);
vec[0] = &s;
vec[1] =&s;
tmp.clear();
nw.Compose(vec, tmp, name);
s = nw;
}
}
double tend = dmlc::GetTime();
LOG(INFO) << "compose speed = " << n * rep / (tend - tstart) << " ops/sec";
}
void test_node_speed() {
using namespace nnvm;
auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 1000;
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
auto xx = NodeEntry{Node::Create(), 0, 0};
NodeEntry x = s.outputs[0];
xx.node->op = add;
xx.node->inputs.emplace_back(x);
xx.node->inputs.emplace_back(x);
Symbol ss;
ss.outputs.push_back(xx);
s = ss;
}
}
double tend = dmlc::GetTime();
LOG(INFO) << "test_node_speed speed = " << n * rep / (tend - tstart) << " ops/sec";
}
void test_api_speed() {
auto add = (void*)nnvm::Op::Get("add"); // NOLINT(*)
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 1000;
std::unordered_map<std::string, const nnvm::Symbol*> tmp;
std::vector<const nnvm::Symbol*> vec{2};
std::string name = "xx";
for (size_t t = 0; t < rep; ++t) {
SymbolHandle s;
NNSymbolCreateVariable("xx", &s);
for (size_t i = 0; i < n; ++i) {
SymbolHandle arg[2];
SymbolHandle ss;
NNSymbolCreateAtomicSymbol(add, 0, nullptr, nullptr, &ss);
arg[0] = s;
arg[1] = s;
NNSymbolCompose(ss, "nn", 2, nullptr, arg);
s = ss;
}
}
double tend = dmlc::GetTime();
LOG(INFO) << "API compose speed = " << n * rep / (tend - tstart) << " ops/sec";
}
int main() {
test_speed();
test_node_speed();
test_api_speed();
return 0;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment