From 5cf08d6c9bc3f55807447b8b330075898b44fd2a Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Tue, 2 Aug 2016 11:09:01 -0700 Subject: [PATCH] [REFACTOR] copy DMLC headers, move operator to example (#20) --- nnvm/Makefile | 16 +++++++--------- nnvm/example/src/operator.cc | 127 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/README | 2 ++ nnvm/include/dmlc/any.h | 345 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/array_view.h | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/base.h | 228 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/json.h | 868 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/logging.h | 262 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/memory.h | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/parameter.h | 831 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/registry.h | 277 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/thread_local.h | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/timer.h | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/dmlc/type_traits.h | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nnvm/include/nnvm/tuple.h | 1 + nnvm/python/nnvm/ctypes/symbol.py | 12 ++++++------ nnvm/python/nnvm/libinfo.py | 4 ++-- nnvm/python/setup.py | 43 +++++++++++++++++++++++-------------------- nnvm/src/example/operator.cc | 127 ------------------------------------------------------------------------------------------------------------------------------- nnvm/src/test_main.cc | 88 ---------------------------------------------------------------------------------------- 20 files changed, 3653 insertions(+), 252 deletions(-) create mode 100644 nnvm/example/src/operator.cc create mode 100644 nnvm/include/dmlc/README create mode 100644 nnvm/include/dmlc/any.h create mode 100644 nnvm/include/dmlc/array_view.h create mode 100644 nnvm/include/dmlc/base.h create mode 100644 nnvm/include/dmlc/json.h create mode 100644 nnvm/include/dmlc/logging.h create mode 100644 nnvm/include/dmlc/memory.h create mode 100644 nnvm/include/dmlc/parameter.h create mode 100644 nnvm/include/dmlc/registry.h create mode 100644 nnvm/include/dmlc/thread_local.h create mode 100644 nnvm/include/dmlc/timer.h create mode 100644 nnvm/include/dmlc/type_traits.h delete mode 100644 nnvm/src/example/operator.cc delete mode 100644 nnvm/src/test_main.cc diff --git a/nnvm/Makefile b/nnvm/Makefile index 86edfae..07f46ee 100644 --- a/nnvm/Makefile +++ b/nnvm/Makefile @@ -1,15 +1,15 @@ export LDFLAGS = -pthread -lm export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\ - -Iinclude -Idmlc-core/include -fPIC + -Iinclude -fPIC # specify tensor path .PHONY: clean all test lint doc cython cython3 cyclean -all: lib/libnnvm.so lib/libnnvm.a cli_test +all: lib/libnnvm.a lib/libnnvm_example.so -SRC = $(wildcard src/*.cc src/*/*.cc example/*.cc) +SRC = $(wildcard src/*.cc src/*/*.cc) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) -ALL_DEP = $(filter-out build/test_main.o, $(ALL_OBJ)) +ALL_DEP = $(ALL_OBJ) include tests/cpp/unittest.mk @@ -20,16 +20,14 @@ build/%.o: src/%.cc $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) -c $(CFLAGS) -c $< -o $@ -lib/libnnvm.so: $(ALL_DEP) - @mkdir -p $(@D) - $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) lib/libnnvm.a: $(ALL_DEP) @mkdir -p $(@D) ar crv $@ $(filter %.o, $?) -cli_test: $(ALL_DEP) build/test_main.o - $(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) +lib/libnnvm_example.so: example/src/operator.cc lib/libnnvm.a + @mkdir -p $(@D) + $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cc, $^) $(LDFLAGS) -Wl,--whole-archive lib/libnnvm.a -Wl,--no-whole-archive cython: cd python; python setup.py build_ext --inplace diff --git a/nnvm/example/src/operator.cc b/nnvm/example/src/operator.cc new file mode 100644 index 0000000..70063e4 --- /dev/null +++ b/nnvm/example/src/operator.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2016 by Contributors +// This is an example on how we can register operator information to NNVM + +#include <nnvm/base.h> +#include <nnvm/op.h> +#include <nnvm/op_attr_types.h> +#include <nnvm/node.h> +#include <nnvm/graph_attr_types.h> +#include <utility> + +namespace myproject { + +using nnvm::FListInputNames; +using nnvm::FMutateInput; +using nnvm::FInferShape; +using nnvm::FInferType; +using nnvm::FInplaceOption; +using nnvm::NodeAttrs; +using nnvm::TShape; +using nnvm::array_view; + +// simply return the shape as same +inline bool SameShape(const NodeAttrs& attrs, + array_view<TShape*> ishape, + array_view<TShape*> oshape) { + if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false; + for (TShape* pshape : oshape) { + *pshape = *ishape[0]; + } + for (TShape* pshape : ishape) { + *pshape = *ishape[0]; + } + return true; +} + +inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) { + return {{0, 0}}; +} + +// simple demonstration of reshape. +NNVM_REGISTER_OP(reshape) +.describe("reshape source to target shape") +.set_num_inputs(1) +.set_attr_parser( + [](NodeAttrs* attrs) { + // parse attr parser to get target attribute + TShape target; + std::istringstream is(attrs->dict.at("target")); + CHECK(is >> target); + attrs->parsed = std::move(target); + }) +.attr<FInferShape>( + "FInferShape", [] (const NodeAttrs& attrs, + array_view<TShape*> ishape, + array_view<TShape*> oshape) { + // get parsed attribute + const TShape& target = nnvm::get<TShape>(attrs.parsed); + *oshape[0] = target; + if (ishape[0]->ndim() == 0) return false; + CHECK_EQ(ishape[0]->Size(), target.Size()) + << "Reshape op: source target shape mismatch"; + return true; + }) +.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0); + + +NNVM_REGISTER_OP(cast) +.describe("cast source type to target") +.set_num_inputs(1) +.set_attr_parser( + [](NodeAttrs* attrs) { + // parse attr parser to get target attribute + int dtype; + std::istringstream is(attrs->dict.at("dtype")); + CHECK(is >> dtype); + attrs->parsed = std::move(dtype); + }) +.attr<FInferShape>("FInferShape", SameShape) +.attr<FInferType>( + "FInferType", [](const NodeAttrs& attrs, + array_view<int*> itype, + array_view<int*> otype) { + *otype[0] = nnvm::get<int>(attrs.parsed); + return true; + }); + + +NNVM_REGISTER_OP(add) +.describe("add two data together") +.set_num_inputs(2) +.attr<FInferShape>("FInferShape", SameShape) +.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0); + +NNVM_REGISTER_OP(__add_symbol__) +.describe("Alias of add") +.set_num_inputs(2); + +NNVM_REGISTER_OP(exp) +.describe("take exponential") +.set_num_inputs(1) +.attr("inplace_pair", std::make_pair(0, 0)) +.attr<FInferShape>("FInferShape", SameShape); + +NNVM_REGISTER_OP(cross_device_copy) +.describe("Copy data across device.") +.set_num_inputs(1) +.attr<FInferShape>("FInferShape", SameShape); + + +NNVM_REGISTER_OP(conv2d) +.describe("take conv of input") +.set_num_inputs(2) +.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector<std::string>{"data", "weight"}; + }); + +NNVM_REGISTER_OP(add) +.attr<std::string>("nick_name", "plus"); + +NNVM_REGISTER_OP(assign) +.set_num_inputs(2) +.set_num_outputs(1) +.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) { + return index == 0; + }); + +} // namespace myproject diff --git a/nnvm/include/dmlc/README b/nnvm/include/dmlc/README new file mode 100644 index 0000000..bfda492 --- /dev/null +++ b/nnvm/include/dmlc/README @@ -0,0 +1,2 @@ +This folder is synced from dmlc-core/include/dmlc +Contains useful utility headers for the project. diff --git a/nnvm/include/dmlc/any.h b/nnvm/include/dmlc/any.h new file mode 100644 index 0000000..5707a36 --- /dev/null +++ b/nnvm/include/dmlc/any.h @@ -0,0 +1,345 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file any.h + * \brief Container to hold any data type. + */ +#ifndef DMLC_ANY_H_ +#define DMLC_ANY_H_ + +// This code need c++11 to compile +#include <typeinfo> +#include <type_traits> +#include <utility> +#include <algorithm> + +#include "./base.h" +#include "./logging.h" + +namespace dmlc { +// forward declare any; +class any; + +/*! + * Get a reference to content stored in the any as type T. + * This will cause an error if + * T does not match the type stored. + * This function is not part of std::any standard. + * + * \param src The source source any container. + * \return The reference of content + * \tparam T The type of the value to be fetched. + */ +template<typename T> +inline T& get(any& src); // NOLINT(*) + +/*! + * Get the const reference content stored in the any as type T. + * This will cause an error if + * T does not match the type stored. + * This function is not part of std::any standard. + * + * \param src The source source any container. + * \return The reference of content + * \tparam T The type of the value to be fetched. + */ +template<typename T> +inline const T& get(const any& src); + +/*! + * \brief An any class that is compatible to std::any in c++17. + * + * \code + * dmlc::any a = std::string("mydear"), b = 1; + * // get reference out and add it + * dmlc::get<int>(b) += 1; + * // a is now string + * LOG(INFO) << dmlc::get<std::string>(a); + * // a is now 2, the string stored will be properly destructed + * a = std::move(b); + * LOG(INFO) << dmlc::get<int>(a); + * \endcode + * \sa get + */ +class any { + public: + /*! \brief default constructor */ + inline any() = default; + /*! + * \brief move constructor from another any + * \param other The other any to be moved + */ + inline any(any&& other); // NOLINT(*) + /*! + * \brief copy constructor + * \param other The other any to be copied + */ + inline any(const any& other); // NOLINT(*) + /*! + * \brief constructor from any types + * \param other The other types to be constructed into any. + * \tparam T The value type of other. + */ + template<typename T> + inline any(T&& other); // NOLINT(*) + /*! \brief destructor */ + inline ~any(); + /*! + * \brief assign operator from other + * \param other The other any to be copy or moved. + * \return self + */ + inline any& operator=(any&& other); + /*! + * \brief assign operator from other + * \param other The other any to be copy or moved. + * \return self + */ + inline any& operator=(const any& other); + /*! + * \brief assign operator from any type. + * \param other The other any to be copy or moved. + * \tparam T The value type of other. + * \return self + */ + template<typename T> + inline any& operator=(T&& other); + /*! + * \return whether the container is empty. + */ + inline bool empty() const; + /*! + * \return clear the content of container + */ + inline void clear(); + /*! + * swap current content with other + * \param other The other data to be swapped. + */ + inline void swap(any& other); // NOLINT(*) + /*! + * \return The type_info about the stored type. + */ + inline const std::type_info& type() const; + + private: + //! \cond Doxygen_Suppress + // declare of helper class + template<typename T> + class TypeOnHeap; + template<typename T> + class TypeOnStack; + template<typename T> + class TypeInfo; + // size of stack space, it takes 32 bytes for one any type. + static const size_t kStack = sizeof(void*) * 3; + static const size_t kAlign = sizeof(void*); + // container use dynamic storage only when space runs lager + union Data { + // stack space + std::aligned_storage<kStack, kAlign>::type stack; + // pointer to heap space + void* pheap; + }; + // type specific information + struct Type { + // destructor function + void (*destroy)(Data* data); + // copy constructor + void (*create_from_data)(Data* dst, const Data& src); + // the type info function + const std::type_info* ptype_info; + }; + // constant to check if data can be stored on heap. + template<typename T> + struct data_on_stack { + static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack; + }; + // declare friend with + template<typename T> + friend T& get(any& src); // NOLINT(*) + template<typename T> + friend const T& get(const any& src); + // internal construct function + inline void construct(any&& other); + // internal construct function + inline void construct(const any& other); + // internal function to check if type is correct. + template<typename T> + inline void check_type() const; + // internal type specific information + const Type* type_{nullptr}; + // internal data + Data data_; +}; + +template<typename T> +inline any::any(T&& other) { + typedef typename std::decay<T>::type DT; + if (std::is_same<DT, any>::value) { + this->construct(std::forward<T>(other)); + } else { + static_assert(std::is_copy_constructible<DT>::value, + "Any can only hold value that is copy constructable"); + type_ = TypeInfo<DT>::get_type(); + if (data_on_stack<DT>::value) { + new (&(data_.stack)) DT(std::forward<T>(other)); + } else { + data_.pheap = new DT(std::forward<T>(other)); + } + } +} + +inline any::any(any&& other) { + this->construct(std::move(other)); +} + +inline any::any(const any& other) { + this->construct(other); +} + +inline void any::construct(any&& other) { + type_ = other.type_; + data_ = other.data_; + other.type_ = nullptr; +} + +inline void any::construct(const any& other) { + type_ = other.type_; + if (type_ != nullptr) { + type_->create_from_data(&data_, other.data_); + } +} + +inline any::~any() { + this->clear(); +} + +inline any& any::operator=(any&& other) { + any(std::move(other)).swap(*this); + return *this; +} + +inline any& any::operator=(const any& other) { + any(other).swap(*this); + return *this; +} + +template<typename T> +inline any& any::operator=(T&& other) { + any(std::forward<T>(other)).swap(*this); + return *this; +} + +inline void any::swap(any& other) { // NOLINT(*) + std::swap(type_, other.type_); + std::swap(data_, other.data_); +} + +inline void any::clear() { + if (type_ != nullptr) { + if (type_->destroy != nullptr) { + type_->destroy(&data_); + } + type_ = nullptr; + } +} + +inline bool any::empty() const { + return type_ == nullptr; +} + +inline const std::type_info& any::type() const { + if (type_ != nullptr) { + return *(type_->ptype_info); + } else { + return typeid(void); + } +} + +template<typename T> +inline void any::check_type() const { + CHECK(type_ != nullptr) + << "The any container is empty"; + CHECK(type_->ptype_info == &typeid(T)) + << "The stored type mismatch" + << " stored=" << type_->ptype_info->name() + << " requested=" << typeid(T).name(); +} + +template<typename T> +inline const T& get(const any& src) { + src.check_type<T>(); + return *any::TypeInfo<T>::get_ptr(&(src.data_)); +} + +template<typename T> +inline T& get(any& src) { // NOLINT(*) + src.check_type<T>(); + return *any::TypeInfo<T>::get_ptr(&(src.data_)); +} + +template<typename T> +class any::TypeOnHeap { + public: + inline static T* get_ptr(any::Data* data) { + return static_cast<T*>(data->pheap); + } + inline static const T* get_ptr(const any::Data* data) { + return static_cast<const T*>(data->pheap); + } + inline static void create_from_data(any::Data* dst, const any::Data& data) { + dst->pheap = new T(*get_ptr(&data)); + } + inline static void destroy(Data* data) { + delete static_cast<T*>(data->pheap); + } +}; + +template<typename T> +class any::TypeOnStack { + public: + inline static T* get_ptr(any::Data* data) { + return reinterpret_cast<T*>(&(data->stack)); + } + inline static const T* get_ptr(const any::Data* data) { + return reinterpret_cast<const T*>(&(data->stack)); + } + inline static void create_from_data(any::Data* dst, const any::Data& data) { + new (&(dst->stack)) T(*get_ptr(&data)); + } + inline static void destroy(Data* data) { + T* dptr = reinterpret_cast<T*>(&(data->stack)); + dptr->~T(); + } +}; + +template<typename T> +class any::TypeInfo + : public std::conditional<any::data_on_stack<T>::value, + any::TypeOnStack<T>, + any::TypeOnHeap<T> >::type { + public: + inline static const Type* get_type() { + static TypeInfo<T> tp; + return &(tp.type_); + } + + private: + // local type + Type type_; + // constructor + TypeInfo() { + if (std::is_pod<T>::value) { + type_.destroy = nullptr; + } else { + type_.destroy = TypeInfo<T>::destroy; + } + type_.create_from_data = TypeInfo<T>::create_from_data; + type_.ptype_info = &typeid(T); + } +}; +//! \endcond + +} // namespace dmlc + +#endif // DMLC_ANY_H_ diff --git a/nnvm/include/dmlc/array_view.h b/nnvm/include/dmlc/array_view.h new file mode 100644 index 0000000..c9f19ef --- /dev/null +++ b/nnvm/include/dmlc/array_view.h @@ -0,0 +1,116 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file array_view.h + * \brief Read only data structure to reference array + */ +#ifndef DMLC_ARRAY_VIEW_H_ +#define DMLC_ARRAY_VIEW_H_ + +#include <vector> +#include <array> + +namespace dmlc { + +/*! + * \brief Read only data structure to reference continuous memory region of array. + * Provide unified view for vector, array and C style array. + * This data structure do not guarantee aliveness of referenced array. + * + * Make sure do not use array_view to record data in async function closures. + * Also do not use array_view to create reference to temporary data structure. + * + * \tparam ValueType The value + * + * \code + * std::vector<int> myvec{1,2,3}; + * dmlc::array_view<int> view(myvec); + * // indexed visit to the view. + * LOG(INFO) << view[0]; + * + * for (int v : view) { + * // visit each element in the view + * } + * \endcode + */ +template<typename ValueType> +class array_view { + public: + /*! \brief default constructor */ + array_view() = default; + /*! + * \brief default copy constructor + * \param other another array view. + */ + array_view(const array_view<ValueType> &other) = default; // NOLINT(*) + /*! + * \brief default move constructor + * \param other another array view. + */ + array_view(array_view<ValueType>&& other) = default; // NOLINT(*) + /*! + * \brief default assign constructor + * \param other another array view. + * \return self. + */ + array_view<ValueType>& operator=(const array_view<ValueType>& other) = default; // NOLINT(*) + /*! + * \brief construct array view std::vector + * \param other vector container + */ + array_view(const std::vector<ValueType>& other) { // NOLINT(*) + if (other.size() != 0) { + begin_ = &other[0]; size_ = other.size(); + } + } + /*! + * \brief construct array std::array + * \param other another array view. + */ + template<std::size_t size> + array_view(const std::array<ValueType, size>& other) { // NOLINT(*) + if (size != 0) { + begin_ = &other[0]; size_ = size; + } + } + /*! + * \brief construct array view from continuous segment + * \param begin beginning pointre + * \param end end pointer + */ + array_view(const ValueType* begin, const ValueType* end) { + if (begin < end) { + begin_ = begin; + size_ = end - begin; + } + } + /*! \return size of the array */ + inline size_t size() const { + return size_; + } + /*! \return begin of the array */ + inline const ValueType* begin() const { + return begin_; + } + /*! \return end point of the array */ + inline const ValueType* end() const { + return begin_ + size_; + } + /*! + * \brief get i-th element from the view + * \param i The index. + * \return const reference to i-th element. + */ + inline const ValueType& operator[](size_t i) const { + return begin_[i]; + } + + private: + /*! \brief the begin of the view */ + const ValueType* begin_{nullptr}; + /*! \brief The size of the view */ + size_t size_{0}; +}; + +} // namespace dmlc + +#endif // DMLC_ARRAY_VIEW_H_ diff --git a/nnvm/include/dmlc/base.h b/nnvm/include/dmlc/base.h new file mode 100644 index 0000000..5b34fd6 --- /dev/null +++ b/nnvm/include/dmlc/base.h @@ -0,0 +1,228 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file base.h + * \brief defines configuration macros + */ +#ifndef DMLC_BASE_H_ +#define DMLC_BASE_H_ + +/*! \brief whether use glog for logging */ +#ifndef DMLC_USE_GLOG +#define DMLC_USE_GLOG 0 +#endif + +/*! + * \brief whether throw dmlc::Error instead of + * directly calling abort when FATAL error occured + * NOTE: this may still not be perfect. + * do not use FATAL and CHECK in destructors + */ +#ifndef DMLC_LOG_FATAL_THROW +#define DMLC_LOG_FATAL_THROW 1 +#endif + +/*! + * \brief whether always log a message before throw + * This can help identify the error that cannot be catched. + */ +#ifndef DMLC_LOG_BEFORE_THROW +#define DMLC_LOG_BEFORE_THROW 1 +#endif + +/*! + * \brief Whether to use customized logger, + * whose output can be decided by other libraries. + */ +#ifndef DMLC_LOG_CUSTOMIZE +#define DMLC_LOG_CUSTOMIZE 0 +#endif + +/*! \brief whether compile with hdfs support */ +#ifndef DMLC_USE_HDFS +#define DMLC_USE_HDFS 0 +#endif + +/*! \brief whether compile with s3 support */ +#ifndef DMLC_USE_S3 +#define DMLC_USE_S3 0 +#endif + +/*! \brief whether or not use parameter server */ +#ifndef DMLC_USE_PS +#define DMLC_USE_PS 0 +#endif + +/*! \brief whether or not use c++11 support */ +#ifndef DMLC_USE_CXX11 +#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ + __cplusplus >= 201103L || defined(_MSC_VER)) +#endif + +/// check if g++ is before 4.6 +#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) +#if __GNUC__ == 4 && __GNUC_MINOR__ < 6 +#pragma message("Will need g++-4.6 or higher to compile all" \ + "the features in dmlc-core, " \ + "compile without c++0x, some features may be disabled") +#undef DMLC_USE_CXX11 +#define DMLC_USE_CXX11 0 +#endif +#endif + +/*! + * \brief Enable std::thread related modules, + * Used to disable some module in mingw compile. + */ +#ifndef DMLC_ENABLE_STD_THREAD +#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11 +#endif + +/*! \brief whether enable regex support, actually need g++-4.9 or higher*/ +#ifndef DMLC_USE_REGEX +#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER)) +#endif + +/*! \brief helper macro to generate string concat */ +#define DMLC_STR_CONCAT_(__x, __y) __x##__y +#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) + +/*! + * \brief Disable copy constructor and assignment operator. + * + * If C++11 is supported, both copy and move constructors and + * assignment operators are deleted explicitly. Otherwise, they are + * only declared but not implemented. Place this macro in private + * section if C++11 is not available. + */ +#ifndef DISALLOW_COPY_AND_ASSIGN +# if DMLC_USE_CXX11 +# define DISALLOW_COPY_AND_ASSIGN(T) \ + T(T const&) = delete; \ + T(T&&) = delete; \ + T& operator=(T const&) = delete; \ + T& operator=(T&&) = delete +# else +# define DISALLOW_COPY_AND_ASSIGN(T) \ + T(T const&); \ + T& operator=(T const&) +# endif +#endif + +/// +/// code block to handle optionally loading +/// +#if !defined(__GNUC__) +#define fopen64 std::fopen +#endif +#if (defined __MINGW32__) && !(defined __MINGW64__) +#define fopen64 std::fopen +#endif +#ifdef _MSC_VER +#if _MSC_VER < 1900 +// NOTE: sprintf_s is not equivalent to snprintf, +// they are equivalent when success, which is sufficient for our case +#define snprintf sprintf_s +#define vsnprintf vsprintf_s +#endif +#else +#ifdef _FILE_OFFSET_BITS +#if _FILE_OFFSET_BITS == 32 +#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") +#endif +#endif + +#ifdef __APPLE__ +#define off64_t off_t +#define fopen64 std::fopen +#endif + +extern "C" { +#include <sys/types.h> +} +#endif + +#ifdef _MSC_VER +//! \cond Doxygen_Suppress +typedef signed char int8_t; +typedef __int16 int16_t; +typedef __int32 int32_t; +typedef __int64 int64_t; +typedef unsigned char uint8_t; +typedef unsigned __int16 uint16_t; +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +//! \endcond +#else +#include <inttypes.h> +#endif +#include <string> +#include <vector> + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define noexcept_true throw () +#define noexcept_false +#define noexcept(a) noexcept_##a +#endif + +#if DMLC_USE_CXX11 +#define DMLC_THROW_EXCEPTION noexcept(false) +#define DMLC_NO_EXCEPTION noexcept(true) +#else +#define DMLC_THROW_EXCEPTION +#define DMLC_NO_EXCEPTION +#endif + +/*! \brief namespace for dmlc */ +namespace dmlc { +/*! + * \brief safely get the beginning address of a vector + * \param vec input vector + * \return beginning address of a vector + */ +template<typename T> +inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*) + if (vec.size() == 0) { + return NULL; + } else { + return &vec[0]; + } +} +/*! + * \brief get the beginning address of a vector + * \param vec input vector + * \return beginning address of a vector + */ +template<typename T> +inline const T *BeginPtr(const std::vector<T> &vec) { + if (vec.size() == 0) { + return NULL; + } else { + return &vec[0]; + } +} +/*! + * \brief get the beginning address of a vector + * \param str input string + * \return beginning address of a string + */ +inline char* BeginPtr(std::string &str) { // NOLINT(*) + if (str.length() == 0) return NULL; + return &str[0]; +} +/*! + * \brief get the beginning address of a vector + * \param str input string + * \return beginning address of a string + */ +inline const char* BeginPtr(const std::string &str) { + if (str.length() == 0) return NULL; + return &str[0]; +} +} // namespace dmlc + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define constexpr const +#define alignof __alignof +#endif + +#endif // DMLC_BASE_H_ diff --git a/nnvm/include/dmlc/json.h b/nnvm/include/dmlc/json.h new file mode 100644 index 0000000..2daa0aa --- /dev/null +++ b/nnvm/include/dmlc/json.h @@ -0,0 +1,868 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file json.h + * \brief Lightweight JSON Reader/Writer that read save into C++ data structs. + * This includes STL composites and structures. + */ +#ifndef DMLC_JSON_H_ +#define DMLC_JSON_H_ + +// This code requires C++11 to compile +#include <vector> +#include <iostream> +#include <cctype> +#include <string> +#include <algorithm> +#include <map> +#include <list> +#include <utility> + +#include "./base.h" +#include "./logging.h" +#include "./type_traits.h" + +#if DMLC_USE_CXX11 +#include <typeindex> +#include <typeinfo> +#include <unordered_map> +#include "./any.h" +#endif // DMLC_USE_CXX11 + +namespace dmlc { +/*! + * \brief Lightweight JSON Reader to read any STL compositions and structs. + * The user need to know the schema of the + * + */ +class JSONReader { + public: + /*! + * \brief Constructor. + * \param is the input stream. + */ + explicit JSONReader(std::istream *is) + : is_(is), + line_count_r_(0), + line_count_n_(0) {} + /*! + * \brief Parse next JSON string. + * \param out_str the output string. + * \throw dmlc::Error when next token is not string + */ + inline void ReadString(std::string *out_str); + /*! + * \brief Read Number. + * \param out_value output value; + * \throw dmlc::Error when next token is not number of ValueType. + * \tparam ValueType type of the number + */ + template<typename ValueType> + inline void ReadNumber(ValueType *out_value); + /*! + * \brief Begin parsing an object. + * \code + * std::string key; + * // value can be any type that is json serializable. + * std::string value; + * reader->BeginObject(); + * while (reader->NextObjectItem(&key)) { + * // do somthing to key value + * reader->Read(&value); + * } + * \endcode + */ + inline void BeginObject(); + /*! + * \brief Begin parsing an array. + * \code + * // value can be any type that is json serializable. + * std::string value; + * reader->BeginArray(); + * while (reader->NextObjectArrayItem(&value)) { + * // do somthing to value + * } + * \endcode + */ + inline void BeginArray(); + /*! + * \brief Try to move to next object item. + * If this call is successful, user can proceed to call + * reader->Read to read in the value. + * \param out_key the key to the next object. + * \return true if the read is successful, false if we are at end of the object. + */ + inline bool NextObjectItem(std::string *out_key); + /*! + * \brief Try to read the next element in the array. + * If this call is successful, user can proceed to call + * reader->Read to read in the value. + * \return true if the read is successful, false if we are at end of the array. + */ + inline bool NextArrayItem(); + /*! + * \brief Read next ValueType. + * \param out_value any STL or json readable type to be read + * \throw dmlc::Error when the read of ValueType is not successful. + * \tparam ValueType the data type to be read. + */ + template<typename ValueType> + inline void Read(ValueType *out_value); + + /*! \return current line count */ + inline std::string line_info() const { + char temp[64]; + std::ostringstream os; + os << " Line " << std::max(line_count_r_, line_count_n_); + is_->getline(temp, 64); + os << ", around ^`" << temp << "`"; + return os.str(); + } + + private: + /*! \brief internal reader stream */ + std::istream *is_; + /*! \brief "\\r" counter */ + size_t line_count_r_; + /*! \brief "\\n" counter */ + size_t line_count_n_; + /*! + * \brief record how many element processed in + * current array/object scope. + */ + std::vector<size_t> scope_counter_; + /*! + * \brief Read next nonspace character. + * \return the next nonspace character. + */ + inline int NextNonSpace(); + /*! + * \brief Read just before next nonspace but not read that. + * \return the next nonspace character. + */ + inline int PeekNextNonSpace(); +}; + +/*! + * \brief Lightweight json to write any STL compositions. + */ +class JSONWriter { + public: + /*! + * \brief Constructor. + * \param os the output stream. + */ + explicit JSONWriter(std::ostream *os) + : os_(os) {} + /*! + * \brief Write a string that do not contain escape characters. + * \param s the string to be written. + */ + inline void WriteNoEscape(const std::string &s); + /*! + * \brief Write a string that can contain escape characters. + * \param s the string to be written. + */ + inline void WriteString(const std::string &s); + /*! + * \brief Write a string that can contain escape characters. + * \param v the value to be written. + * \tparam ValueType The value type to be written. + */ + template<typename ValueType> + inline void WriteNumber(const ValueType &v); + /*! + * \brief Start beginning of array. + * \param multi_line whether to start an multi_line array. + * \code + * writer->BeginArray(); + * for (auto& v : vdata) { + * writer->WriteArrayItem(v); + * } + * writer->EndArray(); + * \endcode + */ + inline void BeginArray(bool multi_line = true); + /*! \brief Finish writing an array. */ + inline void EndArray(); + /*! + * \brief Start beginning of array. + * \param multi_line whether to start an multi_line array. + * \code + * writer->BeginObject(); + * for (auto& kv : vmap) { + * writer->WriteObjectKeyValue(kv.first, kv.second); + * } + * writer->EndObject(); + * \endcode + */ + inline void BeginObject(bool multi_line = true); + /*! \brief Finish writing object. */ + inline void EndObject(); + /*! + * \brief Write key value pair in the object. + * \param key the key of the object. + * \param value the value of to be written. + * \tparam ValueType The value type to be written. + */ + template<typename ValueType> + inline void WriteObjectKeyValue(const std::string &key, + const ValueType &value); + /*! + * \brief Write seperator of array, before writing next element. + * User can proceed to call writer->Write to write next item + */ + inline void WriteArraySeperator(); + /*! + * \brief Write value into array. + * \param value The value of to be written. + * \tparam ValueType The value type to be written. + */ + template<typename ValueType> + inline void WriteArrayItem(const ValueType &value); + /*! + * \brief Write value to json. + * \param value any STL or json readable that can be written. + * \tparam ValueType the data type to be write. + */ + template<typename ValueType> + inline void Write(const ValueType &value); + + private: + /*! \brief Output stream */ + std::ostream *os_; + /*! + * \brief record how many element processed in + * current array/object scope. + */ + std::vector<size_t> scope_counter_; + /*! \brief Record whether current is a multiline scope */ + std::vector<bool> scope_multi_line_; + /*! + * \brief Write seperating space and newlines + */ + inline void WriteSeperator(); +}; + +/*! + * \brief Helper class to read JSON into a class or struct object. + * \code + * struct Param { + * std::string name; + * int value; + * // define load function from JSON + * inline void Load(dmlc::JSONReader *reader) { + * dmlc::JSONStructReadHelper helper; + * helper.DeclareField("name", &name); + * helper.DeclareField("value", &value); + * helper.ReadAllFields(reader); + * } + * }; + * \endcode + */ +class JSONObjectReadHelper { + public: + /*! + * \brief Declare field of type T + * \param key the key of the of field. + * \param addr address of the data type. + * \tparam T the data type to be read, must be STL composition of JSON serializable. + */ + template<typename T> + inline void DeclareField(const std::string &key, T *addr) { + DeclareFieldInternal(key, addr, false); + } + /*! + * \brief Declare optional field of type T + * \param key the key of the of field. + * \param addr address of the data type. + * \tparam T the data type to be read, must be STL composition of JSON serializable. + */ + template<typename T> + inline void DeclareOptionalField(const std::string &key, T *addr) { + DeclareFieldInternal(key, addr, true); + } + /*! + * \brief Read in all the declared fields. + * \param reader the JSONReader to read the json. + */ + inline void ReadAllFields(JSONReader *reader); + + private: + /*! + * \brief Internal function to declare field. + * \param key the key of the of field. + * \param addr address of the data type. + * \param optional if set to true, no error will be reported if the key is not presented. + * \tparam T the data type to be read, must be STL composition of JSON serializable. + */ + template<typename T> + inline void DeclareFieldInternal(const std::string &key, T *addr, bool optional); + /*! + * \brief The internal reader function. + * \param reader The reader to read. + * \param addr The memory address to read. + */ + template<typename T> + inline static void ReaderFunction(JSONReader *reader, void *addr); + /*! \brief callback type to reader function */ + typedef void (*ReadFunction)(JSONReader *reader, void *addr); + /*! \brief internal data entry */ + struct Entry { + /*! \brief the reader function */ + ReadFunction func; + /*! \brief the address to read */ + void *addr; + /*! \brief whether it is optional */ + bool optional; + }; + /*! \brief the internal map of reader callbacks */ + std::map<std::string, Entry> map_; +}; + +#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ + static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __ + +/*! + * \def DMLC_JSON_ENABLE_ANY + * \brief Macro to enable save/load JSON of dmlc:: whose actual type is Type. + * Any type will be saved as json array [KeyName, content] + * + * \param Type The type to be registered. + * \param KeyName The Type key assigned to the type, must be same during load. + */ +#define DMLC_JSON_ENABLE_ANY(Type, KeyName) \ + DMLC_STR_CONCAT(DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName), __COUNTER__) = \ + ::dmlc::json::AnyJSONManager::Global()->EnableType<Type>(#KeyName) \ + +//! \cond Doxygen_Suppress +namespace json { + +/*! + * \brief generic serialization handler + * \tparam T the type to be serialized + */ +template<typename T> +struct Handler; + +template<typename ValueType> +struct NumericHandler { + inline static void Write(JSONWriter *writer, const ValueType &value) { + writer->WriteNumber<ValueType>(value); + } + inline static void Read(JSONReader *reader, ValueType *value) { + reader->ReadNumber<ValueType>(value); + } +}; + +template<typename ContainerType> +struct ArrayHandler { + inline static void Write(JSONWriter *writer, const ContainerType &array) { + typedef typename ContainerType::value_type ElemType; + writer->BeginArray(array.size() > 10 || !dmlc::is_pod<ElemType>::value); + for (typename ContainerType::const_iterator it = array.begin(); + it != array.end(); ++it) { + writer->WriteArrayItem(*it); + } + writer->EndArray(); + } + inline static void Read(JSONReader *reader, ContainerType *array) { + typedef typename ContainerType::value_type ElemType; + array->clear(); + reader->BeginArray(); + while (reader->NextArrayItem()) { + ElemType value; + Handler<ElemType>::Read(reader, &value); + array->insert(array->end(), value); + } + } +}; + +template<typename ContainerType> +struct MapHandler{ + inline static void Write(JSONWriter *writer, const ContainerType &map) { + writer->BeginObject(map.size() > 1); + for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) { + writer->WriteObjectKeyValue(it->first, it->second); + } + writer->EndObject(); + } + inline static void Read(JSONReader *reader, ContainerType *map) { + typedef typename ContainerType::mapped_type ElemType; + map->clear(); + reader->BeginObject(); + std::string key; + while (reader->NextObjectItem(&key)) { + ElemType value; + reader->Read(&value); + (*map)[key] = value; + } + } +}; + +template<typename T> +struct CommonJSONSerializer { + inline static void Write(JSONWriter *writer, const T &value) { + value.Save(writer); + } + inline static void Read(JSONReader *reader, T *value) { + value->Load(reader); + } +}; + +template<> +struct Handler<std::string> { + inline static void Write(JSONWriter *writer, const std::string &value) { + writer->WriteString(value); + } + inline static void Read(JSONReader *reader, std::string *str) { + reader->ReadString(str); + } +}; + +template<typename T> +struct Handler<std::vector<T> > : public ArrayHandler<std::vector<T> > { +}; + +template<typename K, typename V> +struct Handler<std::pair<K, V> > { + inline static void Write(JSONWriter *writer, const std::pair<K, V> &kv) { + writer->BeginArray(); + writer->WriteArrayItem(kv.first); + writer->WriteArrayItem(kv.second); + writer->EndArray(); + } + inline static void Read(JSONReader *reader, std::pair<K, V> *kv) { + reader->BeginArray(); + CHECK(reader->NextArrayItem()) + << "Expect array of length 2"; + Handler<K>::Read(reader, &(kv->first)); + CHECK(reader->NextArrayItem()) + << "Expect array of length 2"; + Handler<V>::Read(reader, &(kv->second)); + CHECK(!reader->NextArrayItem()) + << "Expect array of length 2"; + } +}; + +template<typename T> +struct Handler<std::list<T> > : public ArrayHandler<std::list<T> > { +}; + +template<typename V> +struct Handler<std::map<std::string, V> > : public MapHandler<std::map<std::string, V> > { +}; + +#if DMLC_USE_CXX11 +template<typename V> +struct Handler<std::unordered_map<std::string, V> > + : public MapHandler<std::unordered_map<std::string, V> > { +}; +#endif // DMLC_USE_CXX11 + +template<typename T> +struct Handler { + inline static void Write(JSONWriter *writer, const T &data) { + typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value, + NumericHandler<T>, + CommonJSONSerializer<T> >::Type THandler; + THandler::Write(writer, data); + } + inline static void Read(JSONReader *reader, T *data) { + typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value, + NumericHandler<T>, + CommonJSONSerializer<T> >::Type THandler; + THandler::Read(reader, data); + } +}; + +#if DMLC_USE_CXX11 +// Manager to store json serialization strategy. +class AnyJSONManager { + public: + template<typename T> + inline AnyJSONManager& EnableType(const std::string& type_name) { // NOLINT(*) + std::type_index tp = std::type_index(typeid(T)); + if (type_name_.count(tp) != 0) { + CHECK(type_name_.at(tp) == type_name) + << "Type has already been registered as another typename " << type_name_.at(tp); + return *this; + } + CHECK(type_map_.count(type_name) == 0) + << "Type name " << type_name << " already registered in registry"; + Entry e; + e.read = ReadAny<T>; + e.write = WriteAny<T>; + type_name_[tp] = type_name; + type_map_[type_name] = e; + return *this; + } + // return global singleton + inline static AnyJSONManager* Global() { + static AnyJSONManager inst; + return &inst; + } + + private: + AnyJSONManager() {} + + template<typename T> + inline static void WriteAny(JSONWriter *writer, const any &data) { + writer->Write(dmlc::get<T>(data)); + } + template<typename T> + inline static void ReadAny(JSONReader *reader, any* data) { + T temp; + reader->Read(&temp); + *data = std::move(temp); + } + // data entry to store vtable for any type + struct Entry { + void (*read)(JSONReader* reader, any *data); + void (*write)(JSONWriter* reader, const any& data); + }; + + template<typename T> + friend struct Handler; + + std::unordered_map<std::type_index, std::string> type_name_; + std::unordered_map<std::string, Entry> type_map_; +}; + +template<> +struct Handler<any> { + inline static void Write(JSONWriter *writer, const any &data) { + std::unordered_map<std::type_index, std::string>& + nmap = AnyJSONManager::Global()->type_name_; + std::type_index id = std::type_index(data.type()); + auto it = nmap.find(id); + CHECK(it != nmap.end() && it->first == id) + << "Type " << id.name() << " has not been registered via DMLC_JSON_ENABLE_ANY"; + std::string type_name = it->second; + AnyJSONManager::Entry e = AnyJSONManager::Global()->type_map_.at(type_name); + writer->BeginArray(false); + writer->WriteArrayItem(type_name); + writer->WriteArraySeperator(); + e.write(writer, data); + writer->EndArray(); + } + inline static void Read(JSONReader *reader, any *data) { + std::string type_name; + reader->BeginArray(); + CHECK(reader->NextArrayItem()) << "invalid any json format"; + Handler<std::string>::Read(reader, &type_name); + std::unordered_map<std::string, AnyJSONManager::Entry>& + tmap = AnyJSONManager::Global()->type_map_; + auto it = tmap.find(type_name); + CHECK(it != tmap.end() && it->first == type_name) + << "Typename " << type_name << " has not been registered via DMLC_JSON_ENABLE_ANY"; + AnyJSONManager::Entry e = it->second; + CHECK(reader->NextArrayItem()) << "invalid any json format"; + e.read(reader, data); + CHECK(!reader->NextArrayItem()) << "invalid any json format"; + } +}; +#endif // DMLC_USE_CXX11 + +} // namespace json + +// implementations of JSONReader/Writer +inline int JSONReader::NextNonSpace() { + int ch; + do { + ch = is_->get(); + if (ch == '\n') ++line_count_n_; + if (ch == '\r') ++line_count_r_; + } while (isspace(ch)); + return ch; +} + +inline int JSONReader::PeekNextNonSpace() { + int ch; + while (true) { + ch = is_->peek(); + if (ch == '\n') ++line_count_n_; + if (ch == '\r') ++line_count_r_; + if (!isspace(ch)) break; + is_->get(); + } + return ch; +} + +inline void JSONReader::ReadString(std::string *out_str) { + int ch = NextNonSpace(); + CHECK_EQ(ch, '\"') + << "Error at" << line_info() + << ", Expect \'\"\' but get \'" << static_cast<char>(ch) << '\''; + std::ostringstream os; + while (true) { + ch = is_->get(); + if (ch == '\\') { + char sch = static_cast<char>(is_->get()); + switch (sch) { + case 'r': os << "\r"; break; + case 'n': os << "\n"; break; + case '\\': os << "\\"; break; + case '\t': os << "\t"; break; + case '\"': os << "\""; break; + default: LOG(FATAL) << "unknown string escape \\" << sch; + } + } else { + if (ch == '\"') break; + os << static_cast<char>(ch); + } + if (ch == EOF || ch == '\r' || ch == '\n') { + LOG(FATAL) + << "Error at" << line_info() + << ", Expect \'\"\' but reach end of line "; + } + } + *out_str = os.str(); +} + +template<typename ValueType> +inline void JSONReader::ReadNumber(ValueType *out_value) { + *is_ >> *out_value; + CHECK(!is_->fail()) + << "Error at" << line_info() + << ", Expect number"; +} + +inline void JSONReader::BeginObject() { + int ch = NextNonSpace(); + CHECK_EQ(ch, '{') + << "Error at" << line_info() + << ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\''; + scope_counter_.push_back(0); +} + +inline void JSONReader::BeginArray() { + int ch = NextNonSpace(); + CHECK_EQ(ch, '[') + << "Error at" << line_info() + << ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\''; + scope_counter_.push_back(0); +} + +inline bool JSONReader::NextObjectItem(std::string *out_key) { + bool next = true; + if (scope_counter_.back() != 0) { + int ch = NextNonSpace(); + if (ch == EOF) { + next = false; + } else if (ch == '}') { + next = false; + } else { + CHECK_EQ(ch, ',') + << "Error at" << line_info() + << ", JSON object expect \'}\' or \',\' \'" << static_cast<char>(ch) << '\''; + } + } else { + int ch = PeekNextNonSpace(); + if (ch == '}') { + is_->get(); + next = false; + } + } + if (!next) { + scope_counter_.pop_back(); + return false; + } else { + scope_counter_.back() += 1; + ReadString(out_key); + int ch = NextNonSpace(); + CHECK_EQ(ch, ':') + << "Error at" << line_info() + << ", Expect \':\' but get \'" << static_cast<char>(ch) << '\''; + return true; + } +} + +inline bool JSONReader::NextArrayItem() { + bool next = true; + if (scope_counter_.back() != 0) { + int ch = NextNonSpace(); + if (ch == EOF) { + next = false; + } else if (ch == ']') { + next = false; + } else { + CHECK_EQ(ch, ',') + << "Error at" << line_info() + << ", JSON array expect \']\' or \',\'. Get \'" << static_cast<char>(ch) << "\' instead"; + } + } else { + int ch = PeekNextNonSpace(); + if (ch == ']') { + is_->get(); + next = false; + } + } + if (!next) { + scope_counter_.pop_back(); + return false; + } else { + scope_counter_.back() += 1; + return true; + } +} + +template<typename ValueType> +inline void JSONReader::Read(ValueType *out_value) { + json::Handler<ValueType>::Read(this, out_value); +} + +inline void JSONWriter::WriteNoEscape(const std::string &s) { + *os_ << '\"' << s << '\"'; +} + +inline void JSONWriter::WriteString(const std::string &s) { + std::ostream &os = *os_; + os << '\"'; + for (size_t i = 0; i < s.length(); ++i) { + char ch = s[i]; + switch (ch) { + case '\r': os << "\\r"; break; + case '\n': os << "\\n"; break; + case '\\': os << "\\\\"; break; + case '\t': os << "\\t"; break; + case '\"': os << "\\\""; break; + default: os << ch; + } + } + os << '\"'; +} + +template<typename ValueType> +inline void JSONWriter::WriteNumber(const ValueType &v) { + *os_ << v; +} + +inline void JSONWriter::BeginArray(bool multi_line) { + *os_ << '['; + scope_multi_line_.push_back(multi_line); + scope_counter_.push_back(0); +} + +inline void JSONWriter::EndArray() { + CHECK_NE(scope_multi_line_.size(), 0); + CHECK_NE(scope_counter_.size(), 0); + bool newline = scope_multi_line_.back(); + size_t nelem = scope_counter_.back(); + scope_multi_line_.pop_back(); + scope_counter_.pop_back(); + if (newline && nelem != 0) WriteSeperator(); + *os_ << ']'; +} + +inline void JSONWriter::BeginObject(bool multi_line) { + *os_ << "{"; + scope_multi_line_.push_back(multi_line); + scope_counter_.push_back(0); +} + +inline void JSONWriter::EndObject() { + CHECK_NE(scope_multi_line_.size(), 0); + CHECK_NE(scope_counter_.size(), 0); + bool newline = scope_multi_line_.back(); + size_t nelem = scope_counter_.back(); + scope_multi_line_.pop_back(); + scope_counter_.pop_back(); + if (newline && nelem != 0) WriteSeperator(); + *os_ << '}'; +} + +template<typename ValueType> +inline void JSONWriter::WriteObjectKeyValue(const std::string &key, + const ValueType &value) { + std::ostream &os = *os_; + if (scope_counter_.back() == 0) { + WriteSeperator(); + os << '\"' << key << "\": "; + } else { + os << ", "; + WriteSeperator(); + os << '\"' << key << "\": "; + } + scope_counter_.back() += 1; + json::Handler<ValueType>::Write(this, value); +} + +inline void JSONWriter::WriteArraySeperator() { + std::ostream &os = *os_; + if (scope_counter_.back() != 0) { + os << ", "; + } + scope_counter_.back() += 1; + WriteSeperator(); +} + +template<typename ValueType> +inline void JSONWriter::WriteArrayItem(const ValueType &value) { + this->WriteArraySeperator(); + json::Handler<ValueType>::Write(this, value); +} + +template<typename ValueType> +inline void JSONWriter::Write(const ValueType &value) { + size_t nscope = scope_multi_line_.size(); + json::Handler<ValueType>::Write(this, value); + CHECK_EQ(nscope, scope_multi_line_.size()) + << "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?"; +} + +inline void JSONWriter::WriteSeperator() { + if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) { + *os_ << '\n' << std::string(scope_multi_line_.size() * 2, ' '); + } +} + +inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) { + reader->BeginObject(); + std::map<std::string, int> visited; + std::string key; + while (reader->NextObjectItem(&key)) { + if (map_.count(key) != 0) { + Entry e = map_[key]; + (*e.func)(reader, e.addr); + visited[key] = 0; + } else { + std::ostringstream os; + os << "JSONReader: Unknown field " << key << ", candidates are: \n"; + for (std::map<std::string, Entry>::iterator + it = map_.begin(); it != map_.end(); ++it) { + os << '\"' <<it->first << "\"\n"; + } + LOG(FATAL) << os.str(); + } + } + if (visited.size() != map_.size()) { + for (std::map<std::string, Entry>::iterator + it = map_.begin(); it != map_.end(); ++it) { + if (it->second.optional) continue; + CHECK_NE(visited.count(it->first), 0) + << "JSONReader: Missing field \"" << it->first << "\"\n At " + << reader->line_info(); + } + } +} + +template<typename T> +inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) { + json::Handler<T>::Read(reader, static_cast<T*>(addr)); +} + +template<typename T> +inline void JSONObjectReadHelper:: +DeclareFieldInternal(const std::string &key, T *addr, bool optional) { + CHECK_EQ(map_.count(key), 0) + << "Adding duplicate field " << key; + Entry e; + e.func = ReaderFunction<T>; + e.addr = static_cast<void*>(addr); + e.optional = optional; + map_[key] = e; +} + +//! \endcond +} // namespace dmlc +#endif // DMLC_JSON_H_ diff --git a/nnvm/include/dmlc/logging.h b/nnvm/include/dmlc/logging.h new file mode 100644 index 0000000..afdc639 --- /dev/null +++ b/nnvm/include/dmlc/logging.h @@ -0,0 +1,262 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file logging.h + * \brief defines logging macros of dmlc + * allows use of GLOG, fall back to internal + * implementation when disabled + */ +#ifndef DMLC_LOGGING_H_ +#define DMLC_LOGGING_H_ +#include <cstdio> +#include <cstdlib> +#include <string> +#include <vector> +#include <stdexcept> +#include "./base.h" + +namespace dmlc { +/*! + * \brief exception class that will be thrown by + * default logger if DMLC_LOG_FATAL_THROW == 1 + */ +struct Error : public std::runtime_error { + /*! + * \brief constructor + * \param s the error message + */ + explicit Error(const std::string &s) : std::runtime_error(s) {} +}; +} // namespace dmlc + +#if DMLC_USE_GLOG +#include <glog/logging.h> + +namespace dmlc { +/*! + * \brief optionally redirect to google's init log + * \param argv0 The arguments. + */ +inline void InitLogging(const char* argv0) { + google::InitGoogleLogging(argv0); +} +} // namespace dmlc + +#else +// use a light version of glog +#include <assert.h> +#include <iostream> +#include <sstream> +#include <ctime> + +#if defined(_MSC_VER) +#pragma warning(disable : 4722) +#endif + +namespace dmlc { +inline void InitLogging(const char* argv0) { + // DO NOTHING +} + +// Always-on checking +#define CHECK(x) \ + if (!(x)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ + "failed: " #x << ' ' +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) +#define CHECK_NOTNULL(x) \ + ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) +// Debug-only checking. +#ifdef NDEBUG +#define DCHECK(x) \ + while (false) CHECK(x) +#define DCHECK_LT(x, y) \ + while (false) CHECK((x) < (y)) +#define DCHECK_GT(x, y) \ + while (false) CHECK((x) > (y)) +#define DCHECK_LE(x, y) \ + while (false) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) \ + while (false) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) \ + while (false) CHECK((x) == (y)) +#define DCHECK_NE(x, y) \ + while (false) CHECK((x) != (y)) +#else +#define DCHECK(x) CHECK(x) +#define DCHECK_LT(x, y) CHECK((x) < (y)) +#define DCHECK_GT(x, y) CHECK((x) > (y)) +#define DCHECK_LE(x, y) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) CHECK((x) == (y)) +#define DCHECK_NE(x, y) CHECK((x) != (y)) +#endif // NDEBUG + +#if DMLC_LOG_CUSTOMIZE +#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__) +#else +#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) +#endif +#define LOG_ERROR LOG_INFO +#define LOG_WARNING LOG_INFO +#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) +#define LOG_QFATAL LOG_FATAL + +// Poor man version of VLOG +#define VLOG(x) LOG_INFO.stream() + +#define LOG(severity) LOG_##severity.stream() +#define LG LOG_INFO.stream() +#define LOG_IF(severity, condition) \ + !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) + +#ifdef NDEBUG +#define LOG_DFATAL LOG_ERROR +#define DFATAL ERROR +#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#define DLOG_IF(severity, condition) \ + (true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#else +#define LOG_DFATAL LOG_FATAL +#define DFATAL FATAL +#define DLOG(severity) LOG(severity) +#define DLOG_IF(severity, condition) LOG_IF(severity, condition) +#endif + +// Poor man version of LOG_EVERY_N +#define LOG_EVERY_N(severity, n) LOG(severity) + +class DateLogger { + public: + DateLogger() { +#if defined(_MSC_VER) + _tzset(); +#endif + } + const char* HumanDate() { +#if defined(_MSC_VER) + _strtime_s(buffer_, sizeof(buffer_)); +#else + time_t time_value = time(NULL); + struct tm *pnow; +#if !defined(_WIN32) + struct tm now; + pnow = localtime_r(&time_value, &now); +#else + pnow = localtime(&time_value); // NOLINT(*) +#endif + snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", + pnow->tm_hour, pnow->tm_min, pnow->tm_sec); +#endif + return buffer_; + } + + private: + char buffer_[9]; +}; + +class LogMessage { + public: + LogMessage(const char* file, int line) + : +#ifdef __ANDROID__ + log_stream_(std::cout) +#else + log_stream_(std::cerr) +#endif + { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + ~LogMessage() { log_stream_ << '\n'; } + std::ostream& stream() { return log_stream_; } + + protected: + std::ostream& log_stream_; + + private: + DateLogger pretty_date_; + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; + +// customized logger that can allow user to define where to log the message. +class CustomLogMessage { + public: + CustomLogMessage(const char* file, int line) { + log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":" + << line << ": "; + } + ~CustomLogMessage() { + Log(log_stream_.str()); + } + std::ostream& stream() { return log_stream_; } + /*! + * \brief customized logging of the message. + * This function won't be implemented by libdmlc + * \param msg The message to be logged. + */ + static void Log(const std::string& msg); + + private: + std::ostringstream log_stream_; +}; + +#if DMLC_LOG_FATAL_THROW == 0 +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} + ~LogMessageFatal() { + log_stream_ << "\n"; + abort(); + } + + private: + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#else +class LogMessageFatal { + public: + LogMessageFatal(const char* file, int line) { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + std::ostringstream &stream() { return log_stream_; } + ~LogMessageFatal() DMLC_THROW_EXCEPTION { + // throwing out of destructor is evil + // hopefully we can do it here + // also log the message before throw +#if DMLC_LOG_BEFORE_THROW + LOG(ERROR) << log_stream_.str(); +#endif + throw Error(log_stream_.str()); + } + + private: + std::ostringstream log_stream_; + DateLogger pretty_date_; + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#endif + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class LogMessageVoidify { + public: + LogMessageVoidify() {} + // This has to be an operator with a precedence lower than << but + // higher than "?:". See its usage. + void operator&(std::ostream&) {} +}; + +} // namespace dmlc + +#endif +#endif // DMLC_LOGGING_H_ diff --git a/nnvm/include/dmlc/memory.h b/nnvm/include/dmlc/memory.h new file mode 100644 index 0000000..3a2b9b0 --- /dev/null +++ b/nnvm/include/dmlc/memory.h @@ -0,0 +1,261 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file memory.h + * \brief Additional memory hanlding utilities. + */ +#ifndef DMLC_MEMORY_H_ +#define DMLC_MEMORY_H_ + +#include <vector> +#include "./base.h" +#include "./logging.h" +#include "./thread_local.h" + +namespace dmlc { + +/*! + * \brief A memory pool that allocate memory of fixed size and alignment. + * \tparam size The size of each piece. + * \tparam align The alignment requirement of the memory. + */ +template<size_t size, size_t align> +class MemoryPool { + public: + /*! \brief constructor */ + MemoryPool() { + static_assert(align % alignof(LinkedList) == 0, + "alignment requirement failed."); + curr_page_.reset(new Page()); + } + /*! \brief allocate a new memory of size */ + inline void* allocate() { + if (head_ != nullptr) { + LinkedList* ret = head_; + head_ = head_->next; + return ret; + } else { + if (page_ptr_ < kPageSize) { + return &(curr_page_->data[page_ptr_++]); + } else { + allocated_.push_back(std::move(curr_page_)); + curr_page_.reset(new Page()); + page_ptr_ = 1; + return &(curr_page_->data[0]); + } + } + } + /*! + * \brief deallocate a piece of memory + * \param p The pointer to the memory to be de-allocated. + */ + inline void deallocate(void* p) { + LinkedList* ptr = static_cast<LinkedList*>(p); + ptr->next = head_; + head_ = ptr; + } + + private: + // page size of each member + static const int kPageSize = ((1 << 22) / size); + // page to be requested. + struct Page { + typename std::aligned_storage<size, align>::type data[kPageSize]; + }; + // internal linked list structure. + struct LinkedList { + LinkedList* next{nullptr}; + }; + // head of free list + LinkedList* head_{nullptr}; + // current free page + std::unique_ptr<Page> curr_page_; + // pointer to the current free page position. + size_t page_ptr_{0}; + // allocated pages. + std::vector<std::unique_ptr<Page> > allocated_; +}; + + +/*! + * \brief A thread local allocator that get memory from a threadlocal memory pool. + * This is suitable to allocate objects that do not cross thread. + * \tparam T the type of the data to be allocated. + */ +template<typename T> +class ThreadlocalAllocator { + public: + /*! \brief pointer type */ + typedef T* pointer; + /*! \brief const pointer type */ + typedef const T* const_ptr; + /*! \brief value type */ + typedef T value_type; + /*! \brief default constructor */ + ThreadlocalAllocator() {} + /*! + * \brief constructor from another allocator + * \param other another allocator + * \tparam U another type + */ + template<typename U> + ThreadlocalAllocator(const ThreadlocalAllocator<U>& other) {} + /*! + * \brief allocate memory + * \param n number of blocks + * \return an uninitialized memory of type T. + */ + inline T* allocate(size_t n) { + CHECK_EQ(n, 1); + typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store; + return static_cast<T*>(Store::Get()->allocate()); + } + /*! + * \brief deallocate memory + * \param p a memory to be returned. + * \param n number of blocks + */ + inline void deallocate(T* p, size_t n) { + CHECK_EQ(n, 1); + typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store; + Store::Get()->deallocate(p); + } +}; + + +/*! + * \brief a shared pointer like type that allocate object + * from a threadlocal object pool. This object is not thread-safe + * but can be faster than shared_ptr in certain usecases. + * \tparam T the data type. + */ +template<typename T> +struct ThreadlocalSharedPtr { + public: + /*! \brief default constructor */ + ThreadlocalSharedPtr() : block_(nullptr) {} + /*! + * \brief constructor from nullptr + * \param other the nullptr type + */ + ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other another pointer. + */ + ThreadlocalSharedPtr(const ThreadlocalSharedPtr<T>& other) + : block_(other.block_) { + IncRef(block_); + } + /*! + * \brief move constructor + * \param other another pointer. + */ + ThreadlocalSharedPtr(ThreadlocalSharedPtr<T>&& other) + : block_(other.block_) { + other.block_ = nullptr; + } + /*! + * \brief destructor + */ + ~ThreadlocalSharedPtr() { + DecRef(block_); + } + /*! + * \brief move assignment + * \param other another object to be assigned. + * \return self. + */ + inline ThreadlocalSharedPtr<T>& operator=(ThreadlocalSharedPtr<T>&& other) { + DecRef(block_); + block_ = other.block_; + other.block_ = nullptr; + return *this; + } + /*! + * \brief copy assignment + * \param other another object to be assigned. + * \return self. + */ + inline ThreadlocalSharedPtr<T> &operator=(const ThreadlocalSharedPtr<T>& other) { + DecRef(block_); + block_ = other.block_; + IncRef(block_); + return *this; + } + /*! \brief check if nullptr */ + inline bool operator==(std::nullptr_t other) const { + return block_ == nullptr; + } + /*! + * \return get the pointer content. + */ + inline T* get() const { + if (block_ == nullptr) return nullptr; + return reinterpret_cast<T*>(&(block_->data)); + } + /*! + * \brief reset the pointer to nullptr. + */ + inline void reset() { + DecRef(block_); + block_ = nullptr; + } + /*! \return if use_count == 1*/ + inline bool unique() const { + if (block_ == nullptr) return false; + return block_->use_count_ == 1; + } + /*! \return dereference pointer */ + inline T* operator*() const { + return reinterpret_cast<T*>(&(block_->data)); + } + /*! \return dereference pointer */ + inline T* operator->() const { + return reinterpret_cast<T*>(&(block_->data)); + } + /*! + * \brief create a new space from threadlocal storage and return it. + * \tparam Args the arguments. + * \param args The input argument + * \return the allocated pointer. + */ + template <typename... Args> + inline static ThreadlocalSharedPtr<T> Create(Args&&... args) { + ThreadlocalAllocator<RefBlock> arena; + ThreadlocalSharedPtr<T> p; + p.block_ = arena.allocate(1); + p.block_->use_count_ = 1; + new (&(p.block_->data)) T(std::forward<Args>(args)...); + return p; + } + + private: + // internal reference block + struct RefBlock { + typename std::aligned_storage<sizeof(T), alignof(T)>::type data; + unsigned use_count_; + }; + // decrease ref counter + inline static void DecRef(RefBlock* block) { + if (block != nullptr) { + if (--block->use_count_ == 0) { + ThreadlocalAllocator<RefBlock> arena; + T* dptr = reinterpret_cast<T*>(&(block->data)); + dptr->~T(); + arena.deallocate(block, 1); + } + } + } + // increase ref counter + inline static void IncRef(RefBlock* block) { + if (block != nullptr) { + ++block->use_count_; + } + } + // internal block + RefBlock *block_; +}; + +} // namespace dmlc + +#endif // DMLC_MEMORY_H_ diff --git a/nnvm/include/dmlc/parameter.h b/nnvm/include/dmlc/parameter.h new file mode 100644 index 0000000..4ff99f8 --- /dev/null +++ b/nnvm/include/dmlc/parameter.h @@ -0,0 +1,831 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file parameter.h + * \brief Provide lightweight util to do parameter setup and checking. + */ +#ifndef DMLC_PARAMETER_H_ +#define DMLC_PARAMETER_H_ + +#include <cstddef> +#include <cstdlib> +#include <sstream> +#include <limits> +#include <map> +#include <set> +#include <typeinfo> +#include <string> +#include <vector> +#include <algorithm> +#include <utility> +#include <iostream> +#include "./base.h" +#include "./json.h" +#include "./logging.h" +#include "./type_traits.h" + +namespace dmlc { +// this file is backward compatible with non-c++11 +/*! \brief Error throwed by parameter checking */ +struct ParamError : public dmlc::Error { + /*! + * \brief constructor + * \param msg error message + */ + explicit ParamError(const std::string &msg) + : dmlc::Error(msg) {} +}; + +/*! + * \brief Get environment variable with default. + * \param key the name of environment variable. + * \param default_value the default value of environment vriable. + * \return The value received + */ +template<typename ValueType> +inline ValueType GetEnv(const char *key, + ValueType default_value); + +/*! \brief internal namespace for parameter manangement */ +namespace parameter { +// forward declare ParamManager +class ParamManager; +// forward declare FieldAccessEntry +class FieldAccessEntry; +// forward declare FieldEntry +template<typename DType> +class FieldEntry; +// forward declare ParamManagerSingleton +template<typename PType> +struct ParamManagerSingleton; +} // namespace parameter +/*! + * \brief Information about a parameter field in string representations. + */ +struct ParamFieldInfo { + /*! \brief name of the field */ + std::string name; + /*! \brief type of the field in string format */ + std::string type; + /*! + * \brief detailed type information string + * This include the default value, enum constran and typename. + */ + std::string type_info_str; + /*! \brief detailed description of the type */ + std::string description; +}; + +/*! + * \brief Parameter is the base type every parameter struct should inheritate from + * The following code is a complete example to setup parameters. + * \code + * struct Param : public dmlc::Parameter<Param> { + * float learning_rate; + * int num_hidden; + * std::string name; + * // declare parameters in header file + * DMLC_DECLARE_PARAMETER(Param) { + * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000); + * DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f); + * DMLC_DECLARE_FIELD(name).set_default("hello"); + * } + * }; + * // register it in cc file + * DMLC_REGISTER_PARAMETER(Param); + * \endcode + * + * After that, the Param struct will get all the functions defined in Parameter. + * \tparam PType the type of parameter struct + * + * \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER + */ +template<typename PType> +struct Parameter { + public: + /*! + * \brief initialize the parameter by keyword arguments. + * This function will initialize the parameter struct, check consistency + * and throw error if something wrong happens. + * + * \param kwargs map of keyword arguments, or vector of pairs + * \tparam Container container type + * \throw ParamError when something go wrong. + */ + template<typename Container> + inline void Init(const Container &kwargs) { + PType::__MANAGER__()->RunInit(static_cast<PType*>(this), + kwargs.begin(), kwargs.end(), NULL); + } + /*! + * \brief initialize the parameter by keyword arguments. + * This is same as Init, but allow unknown arguments. + * + * \param kwargs map of keyword arguments, or vector of pairs + * \tparam Container container type + * \throw ParamError when something go wrong. + * \return vector of pairs of unknown arguments. + */ + template<typename Container> + inline std::vector<std::pair<std::string, std::string> > + InitAllowUnknown(const Container &kwargs) { + std::vector<std::pair<std::string, std::string> > unknown; + PType::__MANAGER__()->RunInit(static_cast<PType*>(this), + kwargs.begin(), kwargs.end(), &unknown); + return unknown; + } + /*! + * \brief Return a dictionary representation of the parameters + * \return A dictionary that maps key -> value + */ + inline std::map<std::string, std::string> __DICT__() const { + std::vector<std::pair<std::string, std::string> > vec + = PType::__MANAGER__()->GetDict(this->head()); + return std::map<std::string, std::string>(vec.begin(), vec.end()); + } + /*! + * \brief Write the parameters in JSON format. + * \param writer JSONWriter used for writing. + */ + inline void Save(dmlc::JSONWriter *writer) const { + writer->Write(this->__DICT__()); + } + /*! + * \brief Load the parameters from JSON. + * \param reader JSONReader used for loading. + * \throw ParamError when something go wrong. + */ + inline void Load(dmlc::JSONReader *reader) { + std::map<std::string, std::string> kwargs; + reader->Read(&kwargs); + this->Init(kwargs); + } + /*! + * \brief Get the fields of the parameters. + * \return List of ParamFieldInfo of each field. + */ + inline static std::vector<ParamFieldInfo> __FIELDS__() { + return PType::__MANAGER__()->GetFieldInfo(); + } + /*! + * \brief Print docstring of the parameter + * \return the printed docstring + */ + inline static std::string __DOC__() { + std::ostringstream os; + PType::__MANAGER__()->PrintDocString(os); + return os.str(); + } + + protected: + /*! + * \brief internal function to allow declare of a parameter memember + * \param manager the parameter manager + * \param key the key name of the parameter + * \param ref the reference to the parameter in the struct. + */ + template<typename DType> + inline parameter::FieldEntry<DType>& DECLARE( + parameter::ParamManagerSingleton<PType> *manager, + const std::string &key, DType &ref) { // NOLINT(*) + parameter::FieldEntry<DType> *e = + new parameter::FieldEntry<DType>(); + e->Init(key, this->head(), ref); + manager->manager.AddEntry(key, e); + return *e; + } + + private: + /*! \return Get head pointer of child structure */ + inline PType *head() const { + return static_cast<PType*>(const_cast<Parameter<PType>*>(this)); + } +}; + +//! \cond Doxygen_Suppress +/*! + * \brief macro used to declare parameter + * + * Example: + * \code + * struct Param : public dmlc::Parameter<Param> { + * // declare parameters in header file + * DMLC_DECLARE_PARAMETER(Param) { + * // details of declarations + * } + * }; + * \endcode + * + * This macro need to be put in a source file so that registeration only happens once. + * Refer to example code in Parameter for details + * + * \param PType the name of parameter struct. + * \sa Parameter + */ +#define DMLC_DECLARE_PARAMETER(PType) \ + static ::dmlc::parameter::ParamManager *__MANAGER__(); \ + inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \ + +/*! + * \brief macro to declare fields + * \param FieldName the name of the field. + */ +#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) + +/*! + * \brief macro to declare alias of a fields + * \param FieldName the name of the field. + * \param AliasName the name of the alias, must be declared after the field is declared. + */ +#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName) + +/*! + * \brief Macro used to register parameter. + * + * This macro need to be put in a source file so that registeration only happens once. + * Refer to example code in Parameter for details + * \param PType the type of parameter struct. + * \sa Parameter + */ +#define DMLC_REGISTER_PARAMETER(PType) \ + ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ + static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ + return &inst.manager; \ + } \ + static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \ + (*PType::__MANAGER__()) \ + +//! \endcond +/*! + * \brief internal namespace for parameter manangement + * There is no need to use it directly in normal case + */ +namespace parameter { +/*! + * \brief FieldAccessEntry interface to help manage the parameters + * Each entry can be used to access one parameter in the Parameter struct. + * + * This is an internal interface used that is used to manage parameters + */ +class FieldAccessEntry { + public: + FieldAccessEntry() + : has_default_(false) {} + /*! \brief destructor */ + virtual ~FieldAccessEntry() {} + /*! + * \brief set the default value. + * \param head the pointer to the head of the struct + * \throw error if no default is presented + */ + virtual void SetDefault(void *head) const = 0; + /*! + * \brief set the parameter by string value + * \param head the pointer to the head of the struct + * \param value the value to be set + */ + virtual void Set(void *head, const std::string &value) const = 0; + // check if value is OK + virtual void Check(void *head) const {} + /*! + * \brief get the string representation of value. + * \param head the pointer to the head of the struct + */ + virtual std::string GetStringValue(void *head) const = 0; + /*! + * \brief Get field information + * \return the corresponding field information + */ + virtual ParamFieldInfo GetFieldInfo() const = 0; + + protected: + /*! \brief whether this parameter have default value */ + bool has_default_; + /*! \brief positional index of parameter in struct */ + size_t index_; + /*! \brief parameter key name */ + std::string key_; + /*! \brief parameter type */ + std::string type_; + /*! \brief description of the parameter */ + std::string description_; + /*! + * \brief print string representation of default value + * \parma os the stream to print the docstring to. + */ + virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*) + // allow ParamManager to modify self + friend class ParamManager; +}; + +/*! + * \brief manager class to handle parameter setting for each type + * An manager will be created for each parameter types. + */ +class ParamManager { + public: + /*! \brief destructor */ + ~ParamManager() { + for (size_t i = 0; i < entry_.size(); ++i) { + delete entry_[i]; + } + } + /*! + * \brief find the access entry by parameter key + * \param key the key of the parameter. + * \return pointer to FieldAccessEntry, NULL if nothing is found. + */ + inline FieldAccessEntry *Find(const std::string &key) const { + std::map<std::string, FieldAccessEntry*>::const_iterator it = + entry_map_.find(key); + if (it == entry_map_.end()) return NULL; + return it->second; + } + /*! + * \brief set parameter by keyword arguments. + * \param head head to the parameter field. + * \param begin begin iterator of original kwargs + * \param end end iterator of original kwargs + * \param unknown_args optional, used to hold unknown arguments + * When it is specified, unknown arguments will be stored into here, instead of raise an error + * \tparam RandomAccessIterator iterator type + * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing. + */ + template<typename RandomAccessIterator> + inline void RunInit(void *head, + RandomAccessIterator begin, + RandomAccessIterator end, + std::vector<std::pair<std::string, std::string> > *unknown_args) const { + std::set<FieldAccessEntry*> selected_args; + for (RandomAccessIterator it = begin; it != end; ++it) { + FieldAccessEntry *e = Find(it->first); + if (e != NULL) { + e->Set(head, it->second); + e->Check(head); + selected_args.insert(e); + } else { + if (unknown_args != NULL) { + unknown_args->push_back(*it); + } else { + std::ostringstream os; + os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; + os << "----------------\n"; + PrintDocString(os); + throw dmlc::ParamError(os.str()); + } + } + } + + for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin(); + it != entry_map_.end(); ++it) { + if (selected_args.count(it->second) == 0) { + it->second->SetDefault(head); + } + } + } + /*! + * \brief internal function to add entry to manager, + * The manager will take ownership of the entry. + * \param key the key to the parameters + * \param e the pointer to the new entry. + */ + inline void AddEntry(const std::string &key, FieldAccessEntry *e) { + e->index_ = entry_.size(); + // TODO(bing) better error message + if (entry_map_.count(key) != 0) { + LOG(FATAL) << "key " << key << " has already been registered in " << name_; + } + entry_.push_back(e); + entry_map_[key] = e; + } + /*! + * \brief internal function to add entry to manager, + * The manager will take ownership of the entry. + * \param key the key to the parameters + * \param e the pointer to the new entry. + */ + inline void AddAlias(const std::string& field, const std::string& alias) { + if (entry_map_.count(field) == 0) { + LOG(FATAL) << "key " << field << " has not been registered in " << name_; + } + if (entry_map_.count(alias) != 0) { + LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_; + } + entry_map_[alias] = entry_map_[field]; + } + /*! + * \brief set the name of parameter manager + * \param name the name to set + */ + inline void set_name(const std::string &name) { + name_ = name; + } + /*! + * \brief get field information of each field. + * \return field information + */ + inline std::vector<ParamFieldInfo> GetFieldInfo() const { + std::vector<ParamFieldInfo> ret(entry_.size()); + for (size_t i = 0; i < entry_.size(); ++i) { + ret[i] = entry_[i]->GetFieldInfo(); + } + return ret; + } + /*! + * \brief Print readible docstring to ostream, add newline. + * \parma os the stream to print the docstring to. + */ + inline void PrintDocString(std::ostream &os) const { // NOLINT(*) + for (size_t i = 0; i < entry_.size(); ++i) { + ParamFieldInfo info = entry_[i]->GetFieldInfo(); + os << info.name << " : " << info.type_info_str << '\n'; + if (info.description.length() != 0) { + os << " " << info.description << '\n'; + } + } + } + /*! + * \brief Get internal parameters in vector of pairs. + * \param head the head of the struct. + * \param skip_default skip the values that equals default value. + * \return the parameter dictionary. + */ + inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const { + std::vector<std::pair<std::string, std::string> > ret; + for (std::map<std::string, FieldAccessEntry*>::const_iterator + it = entry_map_.begin(); it != entry_map_.end(); ++it) { + ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head))); + } + return ret; + } + + private: + /*! \brief parameter struct name */ + std::string name_; + /*! \brief positional list of entries */ + std::vector<FieldAccessEntry*> entry_; + /*! \brief map of key to entry */ + std::map<std::string, FieldAccessEntry*> entry_map_; +}; + +//! \cond Doxygen_Suppress + +// The following piece of code will be template heavy and less documented +// singleton parameter manager for certain type, used for initialization +template<typename PType> +struct ParamManagerSingleton { + ParamManager manager; + explicit ParamManagerSingleton(const std::string ¶m_name) { + PType param; + param.__DECLARE__(this); + manager.set_name(param_name); + } +}; + +// Base class of FieldEntry +// implement set_default +template<typename TEntry, typename DType> +class FieldEntryBase : public FieldAccessEntry { + public: + // entry type + typedef TEntry EntryType; + // implement set value + virtual void Set(void *head, const std::string &value) const { + std::istringstream is(value); + is >> this->Get(head); + if (!is.fail()) { + while (!is.eof()) { + int ch = is.get(); + if (ch == EOF) { + is.clear(); break; + } + if (!isspace(ch)) { + is.setstate(std::ios::failbit); break; + } + } + } + + if (is.fail()) { + std::ostringstream os; + os << "Invalid Parameter format for " << key_ + << " expect " << type_ << " but value=\'" << value<< '\''; + throw dmlc::ParamError(os.str()); + } + } + virtual std::string GetStringValue(void *head) const { + std::ostringstream os; + PrintValue(os, this->Get(head)); + return os.str(); + } + virtual ParamFieldInfo GetFieldInfo() const { + ParamFieldInfo info; + std::ostringstream os; + info.name = key_; + info.type = type_; + os << type_; + if (has_default_) { + os << ',' << " optional, default="; + PrintDefaultValueString(os); + } else { + os << ", required"; + } + info.type_info_str = os.str(); + info.description = description_; + return info; + } + // implement set head to default value + virtual void SetDefault(void *head) const { + if (!has_default_) { + std::ostringstream os; + os << "Required parameter " << key_ + << " of " << type_ << " is not presented"; + throw dmlc::ParamError(os.str()); + } else { + this->Get(head) = default_value_; + } + } + // return reference of self as derived type + inline TEntry &self() { + return *(static_cast<TEntry*>(this)); + } + // implement set_default + inline TEntry &set_default(const DType &default_value) { + default_value_ = default_value; + has_default_ = true; + // return self to allow chaining + return this->self(); + } + // implement describe + inline TEntry &describe(const std::string &description) { + description_ = description; + // return self to allow chaining + return this->self(); + } + // initialization function + inline void Init(const std::string &key, + void *head, DType &ref) { // NOLINT(*) + this->key_ = key; + if (this->type_.length() == 0) { + this->type_ = dmlc::type_name<DType>(); + } + this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*) + } + + protected: + // print the value + virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*) + os << value; + } + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + PrintValue(os, default_value_); + } + // get the internal representation of parameter + // for example if this entry corresponds field param.learning_rate + // then Get(¶m) will return reference to param.learning_rate + inline DType &Get(void *head) const { + return *(DType*)((char*)(head) + offset_); // NOLINT(*) + } + // internal offset of the field + ptrdiff_t offset_; + // default value of field + DType default_value_; +}; + +// parameter base for numeric types that have range +template<typename TEntry, typename DType> +class FieldEntryNumeric + : public FieldEntryBase<TEntry, DType> { + public: + FieldEntryNumeric() + : has_begin_(false), has_end_(false) {} + // implement set_range + virtual TEntry &set_range(DType begin, DType end) { + begin_ = begin; end_ = end; + has_begin_ = true; has_end_ = true; + return this->self(); + } + // implement set_range + virtual TEntry &set_lower_bound(DType begin) { + begin_ = begin; has_begin_ = true; + return this->self(); + } + // consistency check for numeric ranges + virtual void Check(void *head) const { + FieldEntryBase<TEntry, DType>::Check(head); + DType v = this->Get(head); + if (has_begin_ && has_end_) { + if (v < begin_ || v > end_) { + std::ostringstream os; + os << "value " << v << "for Parameter " << this->key_ + << " exceed bound [" << begin_ << ',' << end_ <<']'; + throw dmlc::ParamError(os.str()); + } + } else if (has_begin_ && v < begin_) { + std::ostringstream os; + os << "value " << v << "for Parameter " << this->key_ + << " should be greater equal to " << begin_; + throw dmlc::ParamError(os.str()); + } else if (has_end_ && v > end_) { + std::ostringstream os; + os << "value " << v << "for Parameter " << this->key_ + << " should be smaller equal to " << end_; + throw dmlc::ParamError(os.str()); + } + } + + protected: + // whether it have begin and end range + bool has_begin_, has_end_; + // data bound + DType begin_, end_; +}; + +/*! + * \brief FieldEntry defines parsing and checking behavior of DType. + * This class can be specialized to implement specific behavior of more settings. + * \tparam DType the data type of the entry. + */ +template<typename DType> +class FieldEntry : + public IfThenElseType<dmlc::is_arithmetic<DType>::value, + FieldEntryNumeric<FieldEntry<DType>, DType>, + FieldEntryBase<FieldEntry<DType>, DType> >::Type { +}; + +// specialize define for int(enum) +template<> +class FieldEntry<int> + : public FieldEntryNumeric<FieldEntry<int>, int> { + public: + // construct + FieldEntry<int>() : is_enum_(false) {} + // parent + typedef FieldEntryNumeric<FieldEntry<int>, int> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + if (is_enum_) { + std::map<std::string, int>::const_iterator it = enum_map_.find(value); + std::ostringstream os; + if (it == enum_map_.end()) { + os << "Invalid Input: \'" << value; + os << "\', valid values are: "; + PrintEnums(os); + throw dmlc::ParamError(os.str()); + } else { + os << it->second; + Parent::Set(head, os.str()); + } + } else { + Parent::Set(head, value); + } + } + virtual ParamFieldInfo GetFieldInfo() const { + if (is_enum_) { + ParamFieldInfo info; + std::ostringstream os; + info.name = key_; + info.type = type_; + PrintEnums(os); + if (has_default_) { + os << ',' << "optional, default="; + PrintDefaultValueString(os); + } else { + os << ", required"; + } + info.type_info_str = os.str(); + info.description = description_; + return info; + } else { + return Parent::GetFieldInfo(); + } + } + // add enum + inline FieldEntry<int> &add_enum(const std::string &key, int value) { + if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ + enum_back_map_.count(value) != 0) { + std::ostringstream os; + os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; + os << "Enums: "; + for (std::map<std::string, int>::const_iterator it = enum_map_.begin(); + it != enum_map_.end(); ++it) { + os << "(" << it->first << ": " << it->second << "), "; + } + throw dmlc::ParamError(os.str()); + } + enum_map_[key] = value; + enum_back_map_[value] = key; + is_enum_ = true; + return this->self(); + } + + protected: + // enum flag + bool is_enum_; + // enum map + std::map<std::string, int> enum_map_; + // enum map + std::map<int, std::string> enum_back_map_; + // override print behavior + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + os << '\''; + PrintValue(os, default_value_); + os << '\''; + } + // override print default + virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) + if (is_enum_) { + CHECK_NE(enum_back_map_.count(value), 0) + << "Value not found in enum declared"; + os << enum_back_map_.at(value); + } else { + os << value; + } + } + + + private: + inline void PrintEnums(std::ostream &os) const { // NOLINT(*) + os << '{'; + for (std::map<std::string, int>::const_iterator + it = enum_map_.begin(); it != enum_map_.end(); ++it) { + if (it != enum_map_.begin()) { + os << ", "; + } + os << "\'" << it->first << '\''; + } + os << '}'; + } +}; + +// specialize define for string +template<> +class FieldEntry<std::string> + : public FieldEntryBase<FieldEntry<std::string>, std::string> { + public: + // parent class + typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + this->Get(head) = value; + } + // override print default + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + os << '\'' << default_value_ << '\''; + } +}; + +// specialize define for bool +template<> +class FieldEntry<bool> + : public FieldEntryBase<FieldEntry<bool>, bool> { + public: + // parent class + typedef FieldEntryBase<FieldEntry<bool>, bool> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + std::string lower_case; lower_case.resize(value.length()); + std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower); + bool &ref = this->Get(head); + if (lower_case == "true") { + ref = true; + } else if (lower_case == "false") { + ref = false; + } else if (lower_case == "1") { + ref = true; + } else if (lower_case == "0") { + ref = false; + } else { + std::ostringstream os; + os << "Invalid Parameter format for " << key_ + << " expect " << type_ << " but value=\'" << value<< '\''; + throw dmlc::ParamError(os.str()); + } + } + + protected: + // print default string + virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*) + if (value) { + os << "True"; + } else { + os << "False"; + } + } +}; + +} // namespace parameter +//! \endcond + +// implement GetEnv +template<typename ValueType> +inline ValueType GetEnv(const char *key, + ValueType default_value) { + const char *val = getenv(key); + if (val == NULL) return default_value; + ValueType ret; + parameter::FieldEntry<ValueType> e; + e.Init(key, &ret, ret); + e.Set(&ret, val); + return ret; +} +} // namespace dmlc +#endif // DMLC_PARAMETER_H_ diff --git a/nnvm/include/dmlc/registry.h b/nnvm/include/dmlc/registry.h new file mode 100644 index 0000000..67fbc43 --- /dev/null +++ b/nnvm/include/dmlc/registry.h @@ -0,0 +1,277 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file registry.h + * \brief Registry utility that helps to build registry singletons. + */ +#ifndef DMLC_REGISTRY_H_ +#define DMLC_REGISTRY_H_ + +#include <map> +#include <string> +#include <vector> +#include "./base.h" +#include "./logging.h" +#include "./parameter.h" +#include "./type_traits.h" + +namespace dmlc { +/*! + * \brief Registry class. + * Registry can be used to register global singletons. + * The most commonly use case are factory functions. + * + * \tparam EntryType Type of Registry entries, + * EntryType need to name a name field. + */ +template<typename EntryType> +class Registry { + public: + /*! \return list of functions in the registry */ + inline static const std::vector<const EntryType*> &List() { + return Get()->entry_list_; + } + /*! + * \brief Find the entry with corresponding name. + * \param name name of the function + * \return the corresponding function, can be NULL + */ + inline static const EntryType *Find(const std::string &name) { + const std::map<std::string, EntryType*> &fmap = Get()->fmap_; + typename std::map<std::string, EntryType*>::const_iterator p = fmap.find(name); + if (p != fmap.end()) { + return p->second; + } else { + return NULL; + } + } + /*! + * \brief Internal function to register a name function under name. + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + inline EntryType &__REGISTER__(const std::string& name) { + CHECK_EQ(fmap_.count(name), 0) + << name << " already registered"; + EntryType *e = new EntryType(); + e->name = name; + fmap_[name] = e; + entry_list_.push_back(e); + return *e; + } + /*! + * \brief Internal function to either register or get registered entry + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + inline EntryType &__REGISTER_OR_GET__(const std::string& name) { + if (fmap_.count(name) == 0) { + return __REGISTER__(name); + } else { + return *fmap_.at(name); + } + } + /*! + * \brief get a singleton of the Registry. + * This function can be defined by DMLC_ENABLE_REGISTRY. + * \return get a singleton + */ + static Registry *Get(); + + private: + /*! \brief list of entry types */ + std::vector<const EntryType*> entry_list_; + /*! \brief map of name->function */ + std::map<std::string, EntryType*> fmap_; + /*! \brief constructor */ + Registry() {} + /*! \brief destructor */ + ~Registry() { + for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin(); + p != fmap_.end(); ++p) { + delete p->second; + } + } +}; + +/*! + * \brief Common base class for function registry. + * + * \code + * // This example demonstrates how to use Registry to create a factory of trees. + * struct TreeFactory : + * public FunctionRegEntryBase<TreeFactory, std::function<Tree*()> > { + * }; + * + * // in a independent cc file + * namespace dmlc { + * DMLC_REGISTRY_ENABLE(TreeFactory); + * } + * // register binary tree constructor into the registry. + * DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree) + * .describe("Constructor of BinaryTree") + * .set_body([]() { return new BinaryTree(); }); + * \endcode + * + * \tparam EntryType The type of subclass that inheritate the base. + * \tparam FunctionType The function type this registry is registerd. + */ +template<typename EntryType, typename FunctionType> +class FunctionRegEntryBase { + public: + /*! \brief name of the entry */ + std::string name; + /*! \brief description of the entry */ + std::string description; + /*! \brief additional arguments to the factory function */ + std::vector<ParamFieldInfo> arguments; + /*! \brief Function body to create ProductType */ + FunctionType body; + /*! \brief Return type of the function */ + std::string return_type; + + /*! + * \brief Set the function body. + * \param body Function body to set. + * \return reference to self. + */ + inline EntryType &set_body(FunctionType body) { + this->body = body; + return this->self(); + } + /*! + * \brief Describe the function. + * \param description The description of the factory function. + * \return reference to self. + */ + inline EntryType &describe(const std::string &description) { + this->description = description; + return this->self(); + } + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline EntryType &add_argument(const std::string &name, + const std::string &type, + const std::string &description) { + ParamFieldInfo info; + info.name = name; + info.type = type; + info.type_info_str = info.type; + info.description = description; + arguments.push_back(info); + return this->self(); + } + /*! + * \brief Append list if arguments to the end. + * \param args Additional list of arguments. + * \return reference to self. + */ + inline EntryType &add_arguments(const std::vector<ParamFieldInfo> &args) { + arguments.insert(arguments.end(), args.begin(), args.end()); + return this->self(); + } + /*! + * \brief Set the return type. + * \param type Return type of the function, could be Symbol or Symbol[] + * \return reference to self. + */ + inline EntryType &set_return_type(const std::string &type) { + return_type = type; + return this->self(); + } + + protected: + /*! + * \return reference of self as derived type + */ + inline EntryType &self() { + return *(static_cast<EntryType*>(this)); + } +}; + +/*! + * \def DMLC_REGISTRY_ENABLE + * \brief Macro to enable the registry of EntryType. + * This macro must be used under namespace dmlc, and only used once in cc file. + * \param EntryType Type of registry entry + */ +#define DMLC_REGISTRY_ENABLE(EntryType) \ + template<> \ + Registry<EntryType > *Registry<EntryType >::Get() { \ + static Registry<EntryType > inst; \ + return &inst; \ + } \ + +/*! + * \brief Generic macro to register an EntryType + * There is a complete example in FactoryRegistryEntryBase. + * + * \param EntryType The type of registry entry. + * \param EntryTypeName The typename of EntryType, must do not contain namespace :: . + * \param Name The name to be registered. + * \sa FactoryRegistryEntryBase + */ +#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ + static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ + ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \ + +/*! + * \brief (Optional) Declare a file tag to current file that contains object registrations. + * + * This will declare a dummy function that will be called by register file to + * incur a link dependency. + * + * \param UniqueTag The unique tag used to represent. + * \sa DMLC_REGISTRY_LINK_TAG + */ +#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \ + int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; } + +/*! + * \brief (Optional) Force link to all the objects registered in file tag. + * + * This macro must be used in the same file as DMLC_REGISTRY_ENABLE and + * in the same namespace as DMLC_REGISTRY_FILE_TAG + * + * DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration. + * They are used to encforce link of certain file into during static linking. + * + * This is mainly used to solve problem during statically link a library which contains backward registration. + * Specifically, this avoids the objects in these file tags to be ignored by compiler. + * + * For dynamic linking, this problem won't occur as everything is loaded by default. + * + * Use of this is optional as it will create an error when a file tag do not exist. + * An alternative solution is always ask user to enable --whole-archieve during static link. + * + * \begincode + * // in file objective_registry.cc + * DMLC_REGISTRY_ENABLE(MyObjective); + * DMLC_REGISTRY_LINK_TAG(regression_op); + * DMLC_REGISTRY_LINK_TAG(rank_op); + * + * // in file regression_op.cc + * // declare tag of this file. + * DMLC_REGISTRY_FILE_TAG(regression_op); + * DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg); + * // ... + * + * // in file rank_op.cc + * // declare tag of this file. + * DMLC_REGISTRY_FILE_TAG(rank_op); + * DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank); + * + * \endcode + * + * \param UniqueTag The unique tag used to represent. + * \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG + */ +#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ + int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ + static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __(); +} // namespace dmlc +#endif // DMLC_REGISTRY_H_ diff --git a/nnvm/include/dmlc/thread_local.h b/nnvm/include/dmlc/thread_local.h new file mode 100644 index 0000000..d6596d6 --- /dev/null +++ b/nnvm/include/dmlc/thread_local.h @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file thread_local.h + * \brief Portable thread local storage. + */ +#ifndef DMLC_THREAD_LOCAL_H_ +#define DMLC_THREAD_LOCAL_H_ + +#include <mutex> +#include <memory> +#include <vector> + +namespace dmlc { + +// macro hanlding for threadlocal variables +#ifdef __GNUC__ + #define MX_TREAD_LOCAL __thread +#elif __STDC_VERSION__ >= 201112L + #define MX_TREAD_LOCAL _Thread_local +#elif defined(_MSC_VER) + #define MX_TREAD_LOCAL __declspec(thread) +#endif + +#ifndef MX_TREAD_LOCAL +#message("Warning: Threadlocal is not enabled"); +#endif + +/*! + * \brief A threadlocal store to store threadlocal variables. + * Will return a thread local singleton of type T + * \tparam T the type we like to store + */ +template<typename T> +class ThreadLocalStore { + public: + /*! \return get a thread local singleton */ + static T* Get() { + static MX_TREAD_LOCAL T* ptr = nullptr; + if (ptr == nullptr) { + ptr = new T(); + Singleton()->RegisterDelete(ptr); + } + return ptr; + } + + private: + /*! \brief constructor */ + ThreadLocalStore() {} + /*! \brief destructor */ + ~ThreadLocalStore() { + for (size_t i = 0; i < data_.size(); ++i) { + delete data_[i]; + } + } + /*! \return singleton of the store */ + static ThreadLocalStore<T> *Singleton() { + static ThreadLocalStore<T> inst; + return &inst; + } + /*! + * \brief register str for internal deletion + * \param str the string pointer + */ + void RegisterDelete(T *str) { + std::unique_lock<std::mutex> lock(mutex_); + data_.push_back(str); + lock.unlock(); + } + /*! \brief internal mutex */ + std::mutex mutex_; + /*!\brief internal data */ + std::vector<T*> data_; +}; + +} // namespace dmlc + +#endif // DMLC_THREAD_LOCAL_H_ diff --git a/nnvm/include/dmlc/timer.h b/nnvm/include/dmlc/timer.h new file mode 100644 index 0000000..c97059f --- /dev/null +++ b/nnvm/include/dmlc/timer.h @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file timer.h + * \brief cross platform timer for timing + * \author Tianqi Chen + */ +#ifndef DMLC_TIMER_H_ +#define DMLC_TIMER_H_ + +#include "base.h" + +#if DMLC_USE_CXX11 +#include <chrono> +#endif + +#include <time.h> +#ifdef __MACH__ +#include <mach/clock.h> +#include <mach/mach.h> +#endif +#include "./logging.h" + +namespace dmlc { +/*! + * \brief return time in seconds + */ +inline double GetTime(void) { + #if DMLC_USE_CXX11 + return std::chrono::duration<double>( + std::chrono::high_resolution_clock::now().time_since_epoch()).count(); + #elif defined __MACH__ + clock_serv_t cclock; + mach_timespec_t mts; + host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); + CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time"; + mach_port_deallocate(mach_task_self(), cclock); + return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9; + #else + #if defined(__unix__) || defined(__linux__) + timespec ts; + CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time"; + return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9; + #else + return static_cast<double>(time(NULL)); + #endif + #endif +} +} // namespace dmlc +#endif // DMLC_TIMER_H_ diff --git a/nnvm/include/dmlc/type_traits.h b/nnvm/include/dmlc/type_traits.h new file mode 100644 index 0000000..73abfba --- /dev/null +++ b/nnvm/include/dmlc/type_traits.h @@ -0,0 +1,171 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file type_traits.h + * \brief type traits information header + */ +#ifndef DMLC_TYPE_TRAITS_H_ +#define DMLC_TYPE_TRAITS_H_ + +#include "./base.h" +#if DMLC_USE_CXX11 +#include <type_traits> +#endif +#include <string> + +namespace dmlc { +/*! + * \brief whether a type is pod type + * \tparam T the type to query + */ +template<typename T> +struct is_pod { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_pod<T>::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + + +/*! + * \brief whether a type is integer type + * \tparam T the type to query + */ +template<typename T> +struct is_integral { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_integral<T>::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + +/*! + * \brief whether a type is floating point type + * \tparam T the type to query + */ +template<typename T> +struct is_floating_point { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_floating_point<T>::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + +/*! + * \brief whether a type is arithemetic type + * \tparam T the type to query + */ +template<typename T> +struct is_arithmetic { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_arithmetic<T>::value; +#else + /*! \brief the value of the traits */ + static const bool value = (dmlc::is_integral<T>::value || + dmlc::is_floating_point<T>::value); +#endif +}; + +/*! + * \brief the string representation of type name + * \tparam T the type to query + * \return a const string of typename. + */ +template<typename T> +inline const char* type_name() { + return ""; +} + +/*! + * \brief whether a type have save/load function + * \tparam T the type to query + */ +template<typename T> +struct has_saveload { + /*! \brief the value of the traits */ + static const bool value = false; +}; + +/*! + * \brief template to select type based on condition + * For example, IfThenElseType<true, int, float>::Type will give int + * \tparam cond the condition + * \tparam Then the typename to be returned if cond is true + * \tparam The typename to be returned if cond is false +*/ +template<bool cond, typename Then, typename Else> +struct IfThenElseType; + +/*! \brief macro to quickly declare traits information */ +#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \ + template<> \ + struct Trait<Type> { \ + static const bool value = Value; \ + } + +/*! \brief macro to quickly declare traits information */ +#define DMLC_DECLARE_TYPE_NAME(Type, Name) \ + template<> \ + inline const char* type_name<Type>() { \ + return Name; \ + } + +//! \cond Doxygen_Suppress +// declare special traits when C++11 is not available +#if DMLC_USE_CXX11 == 0 +DMLC_DECLARE_TRAITS(is_pod, char, true); +DMLC_DECLARE_TRAITS(is_pod, int8_t, true); +DMLC_DECLARE_TRAITS(is_pod, int16_t, true); +DMLC_DECLARE_TRAITS(is_pod, int32_t, true); +DMLC_DECLARE_TRAITS(is_pod, int64_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint8_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint16_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint32_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint64_t, true); +DMLC_DECLARE_TRAITS(is_pod, float, true); +DMLC_DECLARE_TRAITS(is_pod, double, true); + +DMLC_DECLARE_TRAITS(is_integral, char, true); +DMLC_DECLARE_TRAITS(is_integral, int8_t, true); +DMLC_DECLARE_TRAITS(is_integral, int16_t, true); +DMLC_DECLARE_TRAITS(is_integral, int32_t, true); +DMLC_DECLARE_TRAITS(is_integral, int64_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint8_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint16_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint32_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint64_t, true); + +DMLC_DECLARE_TRAITS(is_floating_point, float, true); +DMLC_DECLARE_TRAITS(is_floating_point, double, true); + +#endif + +DMLC_DECLARE_TYPE_NAME(float, "float"); +DMLC_DECLARE_TYPE_NAME(double, "double"); +DMLC_DECLARE_TYPE_NAME(int, "int"); +DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)"); +DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)"); +DMLC_DECLARE_TYPE_NAME(std::string, "string"); +DMLC_DECLARE_TYPE_NAME(bool, "boolean"); + +template<typename Then, typename Else> +struct IfThenElseType<true, Then, Else> { + typedef Then Type; +}; + +template<typename Then, typename Else> +struct IfThenElseType<false, Then, Else> { + typedef Else Type; +}; +//! \endcond +} // namespace dmlc +#endif // DMLC_TYPE_TRAITS_H_ diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h index 755f272..dbae458 100644 --- a/nnvm/include/nnvm/tuple.h +++ b/nnvm/include/nnvm/tuple.h @@ -9,6 +9,7 @@ #include <vector> #include <type_traits> #include <algorithm> +#include <utility> #include <iostream> #include "./base.h" diff --git a/nnvm/python/nnvm/ctypes/symbol.py b/nnvm/python/nnvm/ctypes/symbol.py index 3bd5e65..3f5cb4e 100644 --- a/nnvm/python/nnvm/ctypes/symbol.py +++ b/nnvm/python/nnvm/ctypes/symbol.py @@ -98,12 +98,12 @@ class SymbolBase(object): **kwargs The attributes to set """ - keys = _base.c_array(_ctypes.c_char_p, - [_base.c_str(key) for key in kwargs.keys()]) - vals = _base.c_array(_ctypes.c_char_p, - [_base.c_str(str(val)) for val in kwargs.values()]) - num_args = _base.nn_uint(len(kwargs)) - _check_call(_LIB.NNSymbolSetAttrs( + keys = c_array(ctypes.c_char_p, + [c_str(key) for key in kwargs.keys()]) + vals = c_array(ctypes.c_char_p, + [c_str(str(val)) for val in kwargs.values()]) + num_args = nn_uint(len(kwargs)) + check_call(_LIB.NNSymbolSetAttrs( self.handle, num_args, keys, vals)) diff --git a/nnvm/python/nnvm/libinfo.py b/nnvm/python/nnvm/libinfo.py index 6648470..e15d4d8 100644 --- a/nnvm/python/nnvm/libinfo.py +++ b/nnvm/python/nnvm/libinfo.py @@ -27,9 +27,9 @@ def find_lib_path(): elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None): dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) if os.name == 'nt': - dll_path = [os.path.join(p, 'libnnvm.dll') for p in dll_path] + dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path] else: - dll_path = [os.path.join(p, 'libnnvm.so') for p in dll_path] + dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] if len(lib_path) == 0: raise RuntimeError('Cannot find the files.\n' + diff --git a/nnvm/python/setup.py b/nnvm/python/setup.py index 58b6e7e..74ada0c 100644 --- a/nnvm/python/setup.py +++ b/nnvm/python/setup.py @@ -1,29 +1,32 @@ import os import sys from distutils.core import setup -from Cython.Build import cythonize -from distutils.extension import Extension +def config_cython(): + try: + from Cython.Build import cythonize + from distutils.extension import Extension + if sys.version_info >= (3, 0): + subdir = "_cy3" + else: + subdir = "_cy2" + ret = [] + path = "nnvm/cython" -def config(): - if sys.version_info >= (3, 0): - subdir = "_cy3" - else: - subdir = "_cy2" - ret = [] - path = "nnvm/cython" - - for fn in os.listdir(path): - if not fn.endswith(".pyx"): - continue - ret.append(Extension( - "nnvm/%s/%s" % (subdir, fn[:-4]), - ["nnvm/cython/%s" % fn], - include_dirs=["../include/"], - language="c++")) - return ret + for fn in os.listdir(path): + if not fn.endswith(".pyx"): + continue + ret.append(Extension( + "nnvm/%s/%s" % (subdir, fn[:-4]), + ["nnvm/cython/%s" % fn], + include_dirs=["../include/"], + language="c++")) + return cythonize(ret) + except: + print("Cython is not installed, will compile without cython module") + return [] setup( name='nnvm', - ext_modules = cythonize(config()) + ext_modules = config_cython() ) diff --git a/nnvm/src/example/operator.cc b/nnvm/src/example/operator.cc deleted file mode 100644 index 3a80fef..0000000 --- a/nnvm/src/example/operator.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2016 by Contributors -// This is an example on how we can register operator information to NNVM - -#include <nnvm/base.h> -#include <nnvm/op.h> -#include <nnvm/op_attr_types.h> -#include <nnvm/node.h> -#include <nnvm/graph_attr_types.h> -#include <utility> - -namespace myproject { - -using nnvm::FListInputNames; -using nnvm::FMutateInput; -using nnvm::FInferShape; -using nnvm::FInferType; -using nnvm::FInplaceOption; -using nnvm::NodeAttrs; -using nnvm::TShape; -using nnvm::array_view; - -// simply return the shape as same -inline bool SameShape(const NodeAttrs& attrs, - array_view<TShape*> ishape, - array_view<TShape*> oshape) { - if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false; - for (TShape* pshape : oshape) { - *pshape = *ishape[0]; - } - for (TShape* pshape : ishape) { - *pshape = *ishape[0]; - } - return true; -} - -inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) { - return {{0, 0}}; -} - -// simple demonstration of reshape. -NNVM_REGISTER_OP(reshape) -.describe("reshape source to target shape") -.set_num_inputs(1) -.set_attr_parser( - [](NodeAttrs* attrs) { - // parse attr parser to get target attribute - TShape target; - std::istringstream is(attrs->dict.at("target")); - CHECK(is >> target); - attrs->parsed = std::move(target); - }) -.attr<FInferShape>( - "FInferShape", [] (const NodeAttrs& attrs, - array_view<TShape*> ishape, - array_view<TShape*> oshape) { - // get parsed attribute - const TShape& target = nnvm::get<TShape>(attrs.parsed); - *oshape[0] = target; - if (ishape[0]->ndim() == 0) return false; - CHECK_EQ(ishape[0]->Size(), target.Size()) - << "Reshape op: source target shape mismatch"; - return true; - }) -.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0); - - -NNVM_REGISTER_OP(cast) -.describe("cast source type to target") -.set_num_inputs(1) -.set_attr_parser( - [](NodeAttrs* attrs) { - // parse attr parser to get target attribute - int dtype; - std::istringstream is(attrs->dict.at("dtype")); - CHECK(is >> dtype); - attrs->parsed = std::move(dtype); - }) -.attr<FInferShape>("FInferShape", SameShape) -.attr<FInferType>( - "FInferType", [](const NodeAttrs& attrs, - array_view<int*> itype, - array_view<int*> otype) { - *otype[0] = nnvm::get<int>(attrs.parsed); - return true; - }); - - -NNVM_REGISTER_OP(add) -.describe("add two data together") -.set_num_inputs(2) -.attr<FInferShape>("FInferShape", SameShape) -.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0); - -NNVM_REGISTER_OP(__add_symbol__) -.describe("Alias of add") -.set_num_inputs(2); - -NNVM_REGISTER_OP(exp) -.describe("take exponmential") -.set_num_inputs(1) -.attr("inplace_pair", std::make_pair(0, 0)) -.attr<FInferShape>("FInferShape", SameShape); - -NNVM_REGISTER_OP(cross_device_copy) -.describe("Copy data across device.") -.set_num_inputs(1) -.attr<FInferShape>("FInferShape", SameShape); - - -NNVM_REGISTER_OP(conv2d) -.describe("take conv of input") -.set_num_inputs(2) -.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { - return std::vector<std::string>{"data", "weight"}; - }); - -NNVM_REGISTER_OP(add) -.attr<std::string>("nick_name", "plus"); - -NNVM_REGISTER_OP(assign) -.set_num_inputs(2) -.set_num_outputs(1) -.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) { - return index == 0; - }); - -} // namespace myproject diff --git a/nnvm/src/test_main.cc b/nnvm/src/test_main.cc deleted file mode 100644 index 2d82cfb..0000000 --- a/nnvm/src/test_main.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2016 by Contributors -#include <nnvm/op.h> -#include <nnvm/graph.h> -#include <nnvm/tuple.h> -#include <nnvm/c_api.h> -#include <nnvm/graph_attr_types.h> -#include <nnvm/pass_functions.h> -#include <dmlc/timer.h> -#include <string> - -void test_speed() { - auto add = nnvm::Op::Get("add"); - double tstart = dmlc::GetTime(); - size_t rep = 1000; - size_t n = 1000; - std::unordered_map<std::string, const nnvm::Symbol*> tmp; - std::unordered_map<std::string, std::string> kwargs; - std::vector<const nnvm::Symbol*> vec{2}; - std::string name = "xx"; - for (size_t t = 0; t < rep; ++t) { - nnvm::Symbol s = nnvm::Symbol::CreateVariable("x"); - for (size_t i = 0; i < n; ++i) { - nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, kwargs); - vec[0] = &s; - vec[1] =&s; - tmp.clear(); - nw.Compose(vec, tmp, name); - s = nw; - } - } - double tend = dmlc::GetTime(); - LOG(INFO) << "compose speed = " << n * rep / (tend - tstart) << " ops/sec"; -} - -void test_node_speed() { - using namespace nnvm; - auto add = nnvm::Op::Get("add"); - double tstart = dmlc::GetTime(); - size_t rep = 1000; - size_t n = 1000; - for (size_t t = 0; t < rep; ++t) { - nnvm::Symbol s = nnvm::Symbol::CreateVariable("x"); - for (size_t i = 0; i < n; ++i) { - auto xx = NodeEntry{Node::Create(), 0, 0}; - NodeEntry x = s.outputs[0]; - xx.node->op = add; - xx.node->inputs.emplace_back(x); - xx.node->inputs.emplace_back(x); - Symbol ss; - ss.outputs.push_back(xx); - s = ss; - } - } - double tend = dmlc::GetTime(); - LOG(INFO) << "test_node_speed speed = " << n * rep / (tend - tstart) << " ops/sec"; -} - -void test_api_speed() { - auto add = (void*)nnvm::Op::Get("add"); // NOLINT(*) - double tstart = dmlc::GetTime(); - size_t rep = 1000; - size_t n = 1000; - std::unordered_map<std::string, const nnvm::Symbol*> tmp; - std::vector<const nnvm::Symbol*> vec{2}; - std::string name = "xx"; - for (size_t t = 0; t < rep; ++t) { - SymbolHandle s; - NNSymbolCreateVariable("xx", &s); - for (size_t i = 0; i < n; ++i) { - SymbolHandle arg[2]; - SymbolHandle ss; - NNSymbolCreateAtomicSymbol(add, 0, nullptr, nullptr, &ss); - arg[0] = s; - arg[1] = s; - NNSymbolCompose(ss, "nn", 2, nullptr, arg); - s = ss; - } - } - double tend = dmlc::GetTime(); - LOG(INFO) << "API compose speed = " << n * rep / (tend - tstart) << " ops/sec"; -} - -int main() { - test_speed(); - test_node_speed(); - test_api_speed(); - return 0; -} -- libgit2 0.26.0