Commit 98a67d9b by Tianqi Chen

Change function def to Node ref for more flexibility (#27)

* Remove warning in g++5

* Change function def to Node ref for more flexibility
parent 6ffeae97
......@@ -15,12 +15,13 @@ using nnvm::FMutateInputs;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::FInplaceOption;
using nnvm::Node;
using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;
// simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs,
inline bool SameShape(const Node& n,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
......@@ -33,7 +34,7 @@ inline bool SameShape(const NodeAttrs& attrs,
return true;
}
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) {
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const Node& n) {
return {{0, 0}};
}
......@@ -50,11 +51,11 @@ NNVM_REGISTER_OP(reshape)
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
"FInferShape", [] (const Node& n,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
// get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed);
const TShape& target = nnvm::get<TShape>(n.attrs.parsed);
(*oshape)[0] = target;
if ((*ishape)[0].ndim() == 0) return false;
CHECK_EQ((*ishape)[0].Size(), target.Size())
......@@ -77,10 +78,10 @@ NNVM_REGISTER_OP(cast)
})
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
"FInferType", [](const Node& n,
std::vector<int> *itype,
std::vector<int> *otype) {
(*otype)[0] = nnvm::get<int>(attrs.parsed);
(*otype)[0] = nnvm::get<int>(n.attrs.parsed);
return true;
});
......@@ -109,7 +110,7 @@ NNVM_REGISTER_OP(cross_device_copy)
NNVM_REGISTER_OP(conv2d)
.describe("take conv of input")
.set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
.attr<FListInputNames>("FListInputNames", [](const Node& n) {
return std::vector<std::string>{"data", "weight"};
});
......@@ -119,7 +120,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
.attr<FMutateInputs>("FMutateInputs", [](const Node& n) {
return std::vector<uint32_t>{0};
});
......
......@@ -58,6 +58,11 @@
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
......@@ -69,6 +74,7 @@
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
......@@ -82,6 +88,13 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*!
 * \brief helper macro to suppress unused warning.
 *  Expands to __attribute__((unused)) when __GNUC__ is defined,
 *  and to nothing otherwise.
 */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
......
......@@ -25,7 +25,9 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11
namespace dmlc {
......@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __
/*!
* \def DMLC_JSON_ENABLE_ANY
......@@ -475,7 +478,7 @@ struct Handler {
}
};
#if DMLC_USE_CXX11
#if DMLC_STRICT_CXX11
// Manager to store json serialization strategy.
class AnyJSONManager {
public:
......@@ -561,7 +564,7 @@ struct Handler<any> {
CHECK(!reader->NextArrayItem()) << "invalid any json format";
}
};
#endif // DMLC_USE_CXX11
#endif // DMLC_STRICT_CXX11
} // namespace json
......
......@@ -251,7 +251,8 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
//! \endcond
......
......@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
......@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
......@@ -17,7 +17,6 @@ namespace nnvm {
// Forward declare node.
class Node;
/*!
* \brief we always use NodePtr as the reference pointer
* to the node, so this alias can be changed if needed.
......@@ -48,8 +47,6 @@ struct NodeEntry {
struct NodeAttrs {
/*! \brief name of the node */
std::string name;
/*! \brief Vector representation of positional attributes */
std::vector<double> scalars;
/*! \brief The dictionary representation of attributes */
std::unordered_map<std::string, std::string> dict;
/*!
......@@ -108,7 +105,7 @@ inline uint32_t Node::num_outputs() const {
if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs;
} else {
return this->op->get_num_outputs(this->attrs);
return this->op->get_num_outputs(*this);
}
}
......@@ -117,7 +114,7 @@ inline uint32_t Node::num_inputs() const {
if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs;
} else {
return this->op->get_num_inputs(this->attrs);
return this->op->get_num_inputs(*this);
}
}
......
......@@ -102,16 +102,16 @@ class Op {
uint32_t num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* \param attrs The attribute of the node
* \param n The node
* \return number of outputs.
*/
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
std::function<uint32_t(const Node& n)> get_num_outputs = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \param n The node
* \return number of inputs
*/
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
std::function<uint32_t(const Node& n)> get_num_inputs = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
......@@ -136,11 +136,11 @@ class Op {
* attrs->parsed = std::move(param);
* }
* // The other function that can utilize the parsed result.
* TShape SumInferShape(const NodeAttrs& attrs,
* TShape SumInferShape(const NodePtr& ptr,
* const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param
* // without repeatedly parsing the parameter
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed);
* const SumParam& param = nnvm::get<SumParam>(ptr->attrs.parsed);
* }
* \endcode
*/
......@@ -180,7 +180,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
inline Op& set_num_inputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
......@@ -192,7 +192,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
inline Op& set_num_outputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The attr_parser function to be set.
......@@ -279,10 +279,8 @@ class OpMap {
};
// internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
/*!
* \def NNVM_REGISTER_OP
......@@ -300,7 +298,7 @@ class OpMap {
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this.
......@@ -377,7 +375,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return *this;
}
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
inline Op& Op::set_num_inputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
......@@ -387,7 +385,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return *this;
}
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
inline Op& Op::set_num_outputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
......
......@@ -12,6 +12,7 @@
#include <functional>
#include "./base.h"
#include "./tuple.h"
#include "./node.h"
namespace nnvm {
......@@ -21,34 +22,34 @@ namespace nnvm {
/*!
* \brief Return list of input arguments names of each operator.
*
* \param attrs The attributes of the node.
* \param n The node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
using FListInputNames = std::function<std::vector<std::string> (const Node& n)>;
/*!
* \brief Return list of output arguments names of each operator.
*
* \param attrs The attributes of the node.
* \param n The node.
* \return list of outputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
* FListOutputNames customized naming for operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
using FListOutputNames = std::function<std::vector<std::string> (const Node& n)>;
/*!
* \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param n The node.
* \return list of input indices it mutates.
*
* \note Register under "FMutateInputs"; operators that do not register it are treated as mutating no inputs.
* FMutateInputs enables mutation order handling correctly.
*/
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
/*!
* \brief Inference function of certain type.
......@@ -56,9 +57,9 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr
* \return whether all attributes are inferred.
*/
template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>;
using FInferNodeEntryAttr = std::function<bool (const Node& n,
std::vector<AttrType> *in_ptr,
std::vector<AttrType> *out_ptr)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
......@@ -96,7 +97,7 @@ using TIsBackwardOp = bool;
/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node
* \param n The node
* \param in_data The input data.
* \param out_data The output data.
* \return list of pair of that maps input->output,
......@@ -105,7 +106,20 @@ using TIsBackwardOp = bool;
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
std::vector<std::pair<int, int> > (const Node& n)>;
/*!
 * \brief Signature of the function that gets the gradient node of an op node.
 *  This function generates the backward graph of the node.
 * \param nodeptr Reference pointer to the node to take the gradient of.
 * \param out_grads Gradients of the current node's outputs, given as node entries.
 * \return Gradients of the inputs, given as node entries.
 *
 * \note Register under "FGradient"
 */
using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;
} // namespace nnvm
......
......@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed.
* \return The generated graph.
*/
typedef std::function<Graph (Graph src)> PassFunction;
using PassFunction = std::function<Graph (Graph src)>;
/*!
* \brief Apply a series of pass transformations on g.
......
......@@ -11,9 +11,11 @@
#define NNVM_PASS_FUNCTIONS_H_
#include <string>
#include <vector>
#include <memory>
#include "./base.h"
#include "./pass.h"
#include "./node.h"
#include "./graph_attr_types.h"
namespace nnvm {
......@@ -109,6 +111,33 @@ inline Graph PlaceDevice(Graph graph,
return ApplyPass(std::move(graph), {"PlaceDevice"});
}
/*!
 * \brief Attach a gradient request to the graph and run the "Gradient" pass on it.
 *
 *  ys and xs are moved into the graph attributes "grad_ys" and "grad_xs".
 *  Each optional callback is stored (under "grad_aggregate_fun" and
 *  "grad_mirror_fun" respectively) only when it is non-empty.
 *
 * \param graph The source graph.
 * \param ys The entries we want to take gradient from.
 * \param xs The inputs we want the gradient for
 *        (NOTE(review): the original description was truncated; confirm).
 * \param aggregate_fun Optional aggregation function applied to aggregate the inputs.
 * \param mirror_fun Optional mirror function to do mirror optimization and save memory.
 * \return The graph returned by the "Gradient" pass; per the original note,
 *         its outputs correspond to the entries of xs.
 */
inline Graph Gradient(
    Graph graph,
    std::vector<NodeEntry> ys,
    std::vector<NodeEntry> xs,
    std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
    std::function<int(const Node& node)> mirror_fun = nullptr) {
  // Short alias; it is not used again once the graph has been moved below.
  auto& attr_table = graph.attrs;
  attr_table["grad_ys"] = std::make_shared<any>(std::move(ys));
  attr_table["grad_xs"] = std::make_shared<any>(std::move(xs));
  // An empty std::function converts to false, which matches comparing it with nullptr.
  if (aggregate_fun) {
    attr_table["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
  }
  if (mirror_fun) {
    attr_table["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
  }
  return ApplyPass(std::move(graph), {"Gradient"});
}
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
......@@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](*(nodes_[nid].source))) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
......
......@@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) {
}
}
if (fmutate_inputs.count(n->op) != 0) {
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) {
for (uint32_t i : fmutate_inputs[n->op](*n)) {
NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
......@@ -197,7 +197,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
if (node->is_variable()) {
vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) {
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){
for (uint32_t i : fmutate_inputs[node->op](*node)){
mutable_set.insert(node->inputs[i].node.get());
}
}
......@@ -223,7 +223,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
std::string rname;
FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr);
if (fn != nullptr) {
rname = fn(head.node->attrs)[head.index];
rname = fn(*head.node)[head.index];
} else {
rname = "output";
if (head.node->num_outputs() != 1) {
......@@ -279,7 +279,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op, nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(*n);
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name;
}
......
......@@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)];
}
num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape));
!(finfer_shape[inode.source->op](*inode.source, &ishape, &oshape));
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
......
......@@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op)) {
mutate_inputs = fmutate_inputs[n->op](n->attrs);
mutate_inputs = fmutate_inputs[n->op](*n);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
......@@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op)) {
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs);
mutate_inputs = fmutate_inputs[kv.first->op](*kv.first);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
......
......@@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) {
if (inode.source->is_variable()) continue;
// check inplace option
if (finplace_option.count(inode.source->op) != 0) {
auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs);
auto inplace_pairs = finplace_option[inode.source->op](*inode.source);
for (auto& kv : inplace_pairs) {
uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment