Commit 98a67d9b by Tianqi Chen

Change function def to Node ref for more flexiblity (#27)

* Remove warning in g++5

* Change function def to Node ref for more flexiblity
parent 6ffeae97
...@@ -15,12 +15,13 @@ using nnvm::FMutateInputs; ...@@ -15,12 +15,13 @@ using nnvm::FMutateInputs;
using nnvm::FInferShape; using nnvm::FInferShape;
using nnvm::FInferType; using nnvm::FInferType;
using nnvm::FInplaceOption; using nnvm::FInplaceOption;
using nnvm::Node;
using nnvm::NodeAttrs; using nnvm::NodeAttrs;
using nnvm::TShape; using nnvm::TShape;
using nnvm::array_view; using nnvm::array_view;
// simply return the shape as same // simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs, inline bool SameShape(const Node& n,
std::vector<TShape> *ishape, std::vector<TShape> *ishape,
std::vector<TShape> *oshape) { std::vector<TShape> *oshape) {
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false; if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
...@@ -33,7 +34,7 @@ inline bool SameShape(const NodeAttrs& attrs, ...@@ -33,7 +34,7 @@ inline bool SameShape(const NodeAttrs& attrs,
return true; return true;
} }
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) { inline std::vector<std::pair<int, int> > InplaceIn0Out0(const Node& n) {
return {{0, 0}}; return {{0, 0}};
} }
...@@ -50,11 +51,11 @@ NNVM_REGISTER_OP(reshape) ...@@ -50,11 +51,11 @@ NNVM_REGISTER_OP(reshape)
attrs->parsed = std::move(target); attrs->parsed = std::move(target);
}) })
.attr<FInferShape>( .attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs, "FInferShape", [] (const Node& n,
std::vector<TShape> *ishape, std::vector<TShape> *ishape,
std::vector<TShape> *oshape) { std::vector<TShape> *oshape) {
// get parsed attribute // get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed); const TShape& target = nnvm::get<TShape>(n.attrs.parsed);
(*oshape)[0] = target; (*oshape)[0] = target;
if ((*ishape)[0].ndim() == 0) return false; if ((*ishape)[0].ndim() == 0) return false;
CHECK_EQ((*ishape)[0].Size(), target.Size()) CHECK_EQ((*ishape)[0].Size(), target.Size())
...@@ -77,10 +78,10 @@ NNVM_REGISTER_OP(cast) ...@@ -77,10 +78,10 @@ NNVM_REGISTER_OP(cast)
}) })
.attr<FInferShape>("FInferShape", SameShape) .attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>( .attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs, "FInferType", [](const Node& n,
std::vector<int> *itype, std::vector<int> *itype,
std::vector<int> *otype) { std::vector<int> *otype) {
(*otype)[0] = nnvm::get<int>(attrs.parsed); (*otype)[0] = nnvm::get<int>(n.attrs.parsed);
return true; return true;
}); });
...@@ -109,7 +110,7 @@ NNVM_REGISTER_OP(cross_device_copy) ...@@ -109,7 +110,7 @@ NNVM_REGISTER_OP(cross_device_copy)
NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(conv2d)
.describe("take conv of input") .describe("take conv of input")
.set_num_inputs(2) .set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { .attr<FListInputNames>("FListInputNames", [](const Node& n) {
return std::vector<std::string>{"data", "weight"}; return std::vector<std::string>{"data", "weight"};
}); });
...@@ -119,7 +120,7 @@ NNVM_REGISTER_OP(add) ...@@ -119,7 +120,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(assign) NNVM_REGISTER_OP(assign)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) { .attr<FMutateInputs>("FMutateInputs", [](const Node& n) {
return std::vector<uint32_t>{0}; return std::vector<uint32_t>{0};
}); });
......
...@@ -58,6 +58,11 @@ ...@@ -58,6 +58,11 @@
__cplusplus >= 201103L || defined(_MSC_VER)) __cplusplus >= 201103L || defined(_MSC_VER))
#endif #endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6 /// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) #if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6 #if __GNUC__ == 4 && __GNUC_MINOR__ < 6
...@@ -69,6 +74,7 @@ ...@@ -69,6 +74,7 @@
#endif #endif
#endif #endif
/*! /*!
* \brief Enable std::thread related modules, * \brief Enable std::thread related modules,
* Used to disable some module in mingw compile. * Used to disable some module in mingw compile.
...@@ -82,6 +88,13 @@ ...@@ -82,6 +88,13 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER)) #define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif #endif
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */ /*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y #define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) #define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
......
...@@ -25,7 +25,9 @@ ...@@ -25,7 +25,9 @@
#include <typeindex> #include <typeindex>
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h" #include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11 #endif // DMLC_USE_CXX11
namespace dmlc { namespace dmlc {
...@@ -320,7 +322,8 @@ class JSONObjectReadHelper { ...@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
}; };
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ #define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __ static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __
/*! /*!
* \def DMLC_JSON_ENABLE_ANY * \def DMLC_JSON_ENABLE_ANY
...@@ -475,7 +478,7 @@ struct Handler { ...@@ -475,7 +478,7 @@ struct Handler {
} }
}; };
#if DMLC_USE_CXX11 #if DMLC_STRICT_CXX11
// Manager to store json serialization strategy. // Manager to store json serialization strategy.
class AnyJSONManager { class AnyJSONManager {
public: public:
...@@ -561,7 +564,7 @@ struct Handler<any> { ...@@ -561,7 +564,7 @@ struct Handler<any> {
CHECK(!reader->NextArrayItem()) << "invalid any json format"; CHECK(!reader->NextArrayItem()) << "invalid any json format";
} }
}; };
#endif // DMLC_USE_CXX11 #endif // DMLC_STRICT_CXX11
} // namespace json } // namespace json
......
...@@ -251,7 +251,8 @@ struct Parameter { ...@@ -251,7 +251,8 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \ return &inst.manager; \
} \ } \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \ static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \ (*PType::__MANAGER__()) \
//! \endcond //! \endcond
......
...@@ -216,7 +216,7 @@ class FunctionRegEntryBase { ...@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase * \sa FactoryRegistryEntryBase
*/ */
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ #define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \ ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*! /*!
...@@ -272,6 +272,7 @@ class FunctionRegEntryBase { ...@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
*/ */
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ #define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __(); static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc } // namespace dmlc
#endif // DMLC_REGISTRY_H_ #endif // DMLC_REGISTRY_H_
...@@ -17,7 +17,6 @@ namespace nnvm { ...@@ -17,7 +17,6 @@ namespace nnvm {
// Forward declare node. // Forward declare node.
class Node; class Node;
/*! /*!
* \brief we always used NodePtr for a reference pointer * \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case. * to the node, so this alias can be changed in case.
...@@ -48,8 +47,6 @@ struct NodeEntry { ...@@ -48,8 +47,6 @@ struct NodeEntry {
struct NodeAttrs { struct NodeAttrs {
/*! \brief name of the node */ /*! \brief name of the node */
std::string name; std::string name;
/*! \brief Vector representation of positional attributes */
std::vector<double> scalars;
/*! \brief The dictionary representation of attributes */ /*! \brief The dictionary representation of attributes */
std::unordered_map<std::string, std::string> dict; std::unordered_map<std::string, std::string> dict;
/*! /*!
...@@ -108,7 +105,7 @@ inline uint32_t Node::num_outputs() const { ...@@ -108,7 +105,7 @@ inline uint32_t Node::num_outputs() const {
if (this->op->get_num_outputs == nullptr) { if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs; return this->op->num_outputs;
} else { } else {
return this->op->get_num_outputs(this->attrs); return this->op->get_num_outputs(*this);
} }
} }
...@@ -117,7 +114,7 @@ inline uint32_t Node::num_inputs() const { ...@@ -117,7 +114,7 @@ inline uint32_t Node::num_inputs() const {
if (this->op->get_num_inputs == nullptr) { if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs; return this->op->num_inputs;
} else { } else {
return this->op->get_num_inputs(this->attrs); return this->op->get_num_inputs(*this);
} }
} }
......
...@@ -102,16 +102,16 @@ class Op { ...@@ -102,16 +102,16 @@ class Op {
uint32_t num_outputs = 1; uint32_t num_outputs = 1;
/*! /*!
* \brief get number of outputs given information about the node. * \brief get number of outputs given information about the node.
* \param attrs The attribute of the node * \param n The node
* \return number of outputs. * \return number of outputs.
*/ */
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr; std::function<uint32_t(const Node& n)> get_num_outputs = nullptr;
/*! /*!
* \brief get number of inputs given information about the node. * \brief get number of inputs given information about the node.
* \param attrs The attribute of the node * \param n The node
* \return number of inputs * \return number of inputs
*/ */
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr; std::function<uint32_t(const Node& n)> get_num_inputs = nullptr;
/*! /*!
* \brief Attribute parser to parse the NodeAttrs information. * \brief Attribute parser to parse the NodeAttrs information.
* *
...@@ -136,11 +136,11 @@ class Op { ...@@ -136,11 +136,11 @@ class Op {
* attrs->parsed = std::move(param); * attrs->parsed = std::move(param);
* } * }
* // The other function that can utilize the parsed result. * // The other function that can utilize the parsed result.
* TShape SumInferShape(const NodeAttrs& attrs, * TShape SumInferShape(const NodePtr& ptr,
* const std::vector<TShape>& ishapes) { * const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param * // we can use the parsed version of param
* // without repeatively parsing the parameter * // without repeatively parsing the parameter
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed); * const SumParam& param = nnvm::get<SumParam>(ptr->attrs.parsed);
* } * }
* \endcode * \endcode
*/ */
...@@ -180,7 +180,7 @@ class Op { ...@@ -180,7 +180,7 @@ class Op {
* \param fn The function to be set. * \param fn The function to be set.
* \return reference to self. * \return reference to self.
*/ */
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*) inline Op& set_num_inputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
/*! /*!
* \brief Set the num_outputs * \brief Set the num_outputs
* \param n The number of outputs to be set. * \param n The number of outputs to be set.
...@@ -192,7 +192,7 @@ class Op { ...@@ -192,7 +192,7 @@ class Op {
* \param fn The function to be set. * \param fn The function to be set.
* \return reference to self. * \return reference to self.
*/ */
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*) inline Op& set_num_outputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
/*! /*!
* \brief Set the attr_parser function. * \brief Set the attr_parser function.
* \param fn The number of outputs to be set. * \param fn The number of outputs to be set.
...@@ -279,10 +279,8 @@ class OpMap { ...@@ -279,10 +279,8 @@ class OpMap {
}; };
// internal macros to make // internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \ #define NNVM_REGISTER_VAR_DEF(OpName) \
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
/*! /*!
* \def NNVM_REGISTER_OP * \def NNVM_REGISTER_OP
...@@ -300,7 +298,7 @@ class OpMap { ...@@ -300,7 +298,7 @@ class OpMap {
* \endcode * \endcode
*/ */
#define NNVM_REGISTER_OP(OpName) \ #define NNVM_REGISTER_OP(OpName) \
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this. // implementations of template functions after this.
...@@ -377,7 +375,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) ...@@ -377,7 +375,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return *this; return *this;
} }
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*) inline Op& Op::set_num_inputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
this->get_num_inputs = fn; this->get_num_inputs = fn;
return *this; return *this;
} }
...@@ -387,7 +385,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) ...@@ -387,7 +385,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return *this; return *this;
} }
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*) inline Op& Op::set_num_outputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
this->get_num_outputs = fn; this->get_num_outputs = fn;
return *this; return *this;
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <functional> #include <functional>
#include "./base.h" #include "./base.h"
#include "./tuple.h" #include "./tuple.h"
#include "./node.h"
namespace nnvm { namespace nnvm {
...@@ -21,34 +22,34 @@ namespace nnvm { ...@@ -21,34 +22,34 @@ namespace nnvm {
/*! /*!
* \brief Return list of input arguments names of each operator. * \brief Return list of input arguments names of each operator.
* *
* \param attrs The attributes of the node. * \param n The node.
* \return list of inputs * \return list of inputs
* \note Register under "FListInputNames", default return {"data"}. * \note Register under "FListInputNames", default return {"data"}.
* *
* FListInputNames enables automatic variable creation for missing arguments. * FListInputNames enables automatic variable creation for missing arguments.
*/ */
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>; using FListInputNames = std::function<std::vector<std::string> (const Node& n)>;
/*! /*!
* \brief Return list of output arguments names of each operator. * \brief Return list of output arguments names of each operator.
* *
* \param attrs The attributes of the node. * \param n The node.
* \return list of inputs * \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}. * \note Register under "FListOutputNames", default return {"outputs"}.
* *
* FListOutputNames customized naming for operator outputs. * FListOutputNames customized naming for operator outputs.
*/ */
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>; using FListOutputNames = std::function<std::vector<std::string> (const Node& n)>;
/*! /*!
* \brief Check whether operator will mutate k-th input. * \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node. * \param n The node.
* \return list of input indices it mutates. * \return list of input indices it mutates.
* *
* \note Register under "FMutateInputs", default return false * \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly. * FMutateInputs enables mutation order handling correctly.
*/ */
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>; using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
/*! /*!
* \brief Inference function of certain type. * \brief Inference function of certain type.
...@@ -56,9 +57,9 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr ...@@ -56,9 +57,9 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr
* \return whether all attributes are inferred. * \return whether all attributes are inferred.
*/ */
template<typename AttrType> template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs, using FInferNodeEntryAttr = std::function<bool (const Node& n,
std::vector<AttrType> *in_attrs, std::vector<AttrType> *in_ptr,
std::vector<AttrType> *out_attrs)>; std::vector<AttrType> *out_ptr)>;
/*! /*!
* \brief Shape inference function. * \brief Shape inference function.
* Update the shapes given the input shape information. * Update the shapes given the input shape information.
...@@ -96,7 +97,7 @@ using TIsBackwardOp = bool; ...@@ -96,7 +97,7 @@ using TIsBackwardOp = bool;
/*! /*!
* \brief Get possible inplace options. * \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output. * This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node * \param n The node
* \param in_data The input data. * \param in_data The input data.
* \param out_data The output data. * \param out_data The output data.
* \return list of pair of that maps input->output, * \return list of pair of that maps input->output,
...@@ -105,7 +106,20 @@ using TIsBackwardOp = bool; ...@@ -105,7 +106,20 @@ using TIsBackwardOp = bool;
* \note Register under "FInplaceOption", by default no inplace can happen. * \note Register under "FInplaceOption", by default no inplace can happen.
*/ */
using FInplaceOption = std::function< using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>; std::vector<std::pair<int, int> > (const Node& n)>;
/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
* \param nodeptr The node to take gradient
* \param out_grads Gradient of current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;
} // namespace nnvm } // namespace nnvm
......
...@@ -23,7 +23,7 @@ namespace nnvm { ...@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed. * \param src The graph to be transformed.
* \return The generated graph. * \return The generated graph.
*/ */
typedef std::function<Graph (Graph src)> PassFunction; using PassFunction = std::function<Graph (Graph src)>;
/*! /*!
* \brief Apply a series of pass transformations on g. * \brief Apply a series of pass transformations on g.
......
...@@ -11,9 +11,11 @@ ...@@ -11,9 +11,11 @@
#define NNVM_PASS_FUNCTIONS_H_ #define NNVM_PASS_FUNCTIONS_H_
#include <string> #include <string>
#include <vector>
#include <memory> #include <memory>
#include "./base.h" #include "./base.h"
#include "./pass.h" #include "./pass.h"
#include "./node.h"
#include "./graph_attr_types.h" #include "./graph_attr_types.h"
namespace nnvm { namespace nnvm {
...@@ -109,6 +111,33 @@ inline Graph PlaceDevice(Graph graph, ...@@ -109,6 +111,33 @@ inline Graph PlaceDevice(Graph graph,
return ApplyPass(std::move(graph), {"PlaceDevice"}); return ApplyPass(std::move(graph), {"PlaceDevice"});
} }
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph source graph
* \param ys The entries we want to take gradient from.
* \param xs The input we want to
* \param aggregate_fun aggregation function applied to aggregate the inputs
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \return A new graph, whose outputs corresponds to inputs of xs.
*/
inline Graph Gradient(
Graph graph,
std::vector<NodeEntry> ys,
std::vector<NodeEntry> xs,
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
std::function<int(const Node& node)> mirror_fun = nullptr) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
}
if (mirror_fun != nullptr) {
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
}
return ApplyPass(std::move(graph), {"Gradient"});
}
} // namespace pass } // namespace pass
} // namespace nnvm } // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_ #endif // NNVM_PASS_FUNCTIONS_H_
...@@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr && if (nodes_[nid].source->op != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) { fmutate_inputs.count(nodes_[nid].source->op)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) { for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](*(nodes_[nid].source))) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
} }
} }
......
...@@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) { ...@@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) {
} }
} }
if (fmutate_inputs.count(n->op) != 0) { if (fmutate_inputs.count(n->op) != 0) {
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) { for (uint32_t i : fmutate_inputs[n->op](*n)) {
NodeEntry& e = n->inputs[i]; NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable()) CHECK(e.node->is_variable())
<< "Mutation target can only be Variable"; << "Mutation target can only be Variable";
...@@ -197,7 +197,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const { ...@@ -197,7 +197,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
if (node->is_variable()) { if (node->is_variable()) {
vlist.push_back(node.get()); vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) { } else if (fmutate_inputs.count(node->op)) {
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){ for (uint32_t i : fmutate_inputs[node->op](*node)){
mutable_set.insert(node->inputs[i].node.get()); mutable_set.insert(node->inputs[i].node.get());
} }
} }
...@@ -223,7 +223,7 @@ std::vector<std::string> Symbol::ListOutputNames() const { ...@@ -223,7 +223,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
std::string rname; std::string rname;
FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr); FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr);
if (fn != nullptr) { if (fn != nullptr) {
rname = fn(head.node->attrs)[head.index]; rname = fn(*head.node)[head.index];
} else { } else {
rname = "output"; rname = "output";
if (head.node->num_outputs() != 1) { if (head.node->num_outputs() != 1) {
...@@ -279,7 +279,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -279,7 +279,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
// switch to keyword argument matching // switch to keyword argument matching
if (args.size() != n_req) { if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op, nullptr); FListInputNames fn = flist_inputs.get(n->op, nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs); auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(*n);
if (arg_names.size() != n_req) { if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name; LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name;
} }
......
...@@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret, ...@@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)]; oshape[i] = rshape[idx.entry_id(nid, i)];
} }
num_unknown += num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape)); !(finfer_shape[inode.source->op](*inode.source, &ishape, &oshape));
for (uint32_t i = 0; i < num_inputs; ++i) { for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
} }
......
...@@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) { ...@@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs; std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op)) { if (!n->is_variable() && fmutate_inputs.count(n->op)) {
mutate_inputs = fmutate_inputs[n->op](n->attrs); mutate_inputs = fmutate_inputs[n->op](*n);
} }
std::sort(mutate_inputs.begin(), mutate_inputs.end()); std::sort(mutate_inputs.begin(), mutate_inputs.end());
...@@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) { ...@@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs; std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op)) { if (fmutate_inputs.count(kv.first->op)) {
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs); mutate_inputs = fmutate_inputs[kv.first->op](*kv.first);
} }
std::sort(mutate_inputs.begin(), mutate_inputs.end()); std::sort(mutate_inputs.begin(), mutate_inputs.end());
......
...@@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) { ...@@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) {
if (inode.source->is_variable()) continue; if (inode.source->is_variable()) continue;
// check inplace option // check inplace option
if (finplace_option.count(inode.source->op) != 0) { if (finplace_option.count(inode.source->op) != 0) {
auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs); auto inplace_pairs = finplace_option[inode.source->op](*inode.source);
for (auto& kv : inplace_pairs) { for (auto& kv : inplace_pairs) {
uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment