Commit ac070f83 by Tianqi Chen

[PASS] Add gradient pass (#28)

parent 803db5d1
@@ -15,6 +15,10 @@ using nnvm::FMutateInputs;
 using nnvm::FInferShape;
 using nnvm::FInferType;
 using nnvm::FInplaceOption;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::FGradient;
 using nnvm::NodeAttrs;
 using nnvm::TShape;
 using nnvm::array_view;
@@ -37,6 +41,17 @@ inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs)
   return {{0, 0}};
 }
 
+// quick helper to make a node
+inline NodeEntry MakeNode(const char* op_name,
+                          std::string node_name,
+                          std::vector<NodeEntry> inputs) {
+  NodePtr p = Node::Create();
+  p->op = nnvm::Op::Get(op_name);
+  p->attrs.name = std::move(node_name);
+  p->inputs = std::move(inputs);
+  return NodeEntry{p, 0, 0};
+}
+
 // simple demonstration of reshape.
 NNVM_REGISTER_OP(reshape)
 .describe("reshape source to target shape")
@@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast)
     return true;
   });
 
+NNVM_REGISTER_OP(exp)
+.describe("take exponential")
+.set_num_inputs(1)
+.attr<FInferShape>("FInferShape", SameShape)
+.attr<FGradient>(
+    "FGradient", [](const NodePtr& n,
+                    const std::vector<NodeEntry>& ograds) {
+      return std::vector<NodeEntry>{
+        MakeNode("mul", n->attrs.name + "_grad",
+                 {ograds[0], NodeEntry{n, 0, 0}})
+      };
+    });
+
+NNVM_REGISTER_OP(identity)
+.describe("identity function")
+.set_num_inputs(1)
+.attr<FInferShape>("FInferShape", SameShape)
+.attr<FGradient>(
+    "FGradient", [](const NodePtr& n,
+                    const std::vector<NodeEntry>& ograds) {
+      return std::vector<NodeEntry>{ograds[0]};
+    });
+
 NNVM_REGISTER_OP(add)
 .describe("add two data together")
 .set_num_inputs(2)
 .attr<FInferShape>("FInferShape", SameShape)
-.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
+.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
+.attr<FGradient>(
+    "FGradient", [](const NodePtr& n,
+                    const std::vector<NodeEntry>& ograds){
+      return std::vector<NodeEntry>{ograds[0], ograds[0]};
+    });
 
-NNVM_REGISTER_OP(__add_symbol__)
-.describe("Alias of add")
-.set_num_inputs(2);
+NNVM_REGISTER_OP(mul)
+.describe("multiply two data together")
+.set_num_inputs(2)
+.attr<FInferShape>("FInferShape", SameShape)
+.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
+.attr<FGradient>(
+    "FGradient", [](const NodePtr& n,
+                    const std::vector<NodeEntry>& ograds){
+      return std::vector<NodeEntry>{
+        MakeNode("mul", n->attrs.name + "_grad_0",
+                 {ograds[0], n->inputs[1]}),
+        MakeNode("mul", n->attrs.name + "_grad_1",
+                 {ograds[0], n->inputs[0]})
+      };
+    });
 
-NNVM_REGISTER_OP(exp)
-.describe("take exponential")
-.set_num_inputs(1)
-.attr<FInferShape>("FInferShape", SameShape);
+NNVM_REGISTER_OP(__ewise_sum__)
+.describe("elementwise sum")
+.set_num_inputs(nnvm::kVarg);
+
+NNVM_REGISTER_OP(__zero__)
+.describe("set output to zero")
+.set_num_inputs(0);
+
+NNVM_REGISTER_OP(__one__)
+.describe("set output to one")
+.set_num_inputs(0);
 
 NNVM_REGISTER_OP(cross_device_copy)
 .describe("Copy data across device.")
...
@@ -58,6 +58,11 @@
                               __cplusplus >= 201103L || defined(_MSC_VER))
 #endif
 
+/*! \brief strict CXX11 support */
+#ifndef DMLC_STRICT_CXX11
+#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
+#endif
+
 /// check if g++ is before 4.6
 #if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
 #if __GNUC__ == 4 && __GNUC_MINOR__ < 6
@@ -69,6 +74,7 @@
 #endif
 #endif
+
 /*!
  * \brief Enable std::thread related modules,
  *        Used to disable some module in mingw compile.
@@ -82,6 +88,13 @@
 #define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
 #endif
 
+/*! \brief helper macro to suppress unused warning */
+#if defined(__GNUC__)
+#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
+#else
+#define DMLC_ATTRIBUTE_UNUSED
+#endif
+
 /*! \brief helper macro to generate string concat */
 #define DMLC_STR_CONCAT_(__x, __y) __x##__y
 #define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
...
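The DMLC_ATTRIBUTE_UNUSED macro added above exists because the registration macros touched later in this commit bind each registry entry to a file-scope static variable that is never read again, which triggers -Wunused-variable under GCC/Clang. A minimal, self-contained sketch of the pattern (the Registry struct and __reg_demo__ name are illustrative, not from the commit):

#include <cstdio>

#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif

struct Registry {
  // registration happens as a side effect of initializing a static variable
  static int Register(const char* name) {
    std::printf("registered %s\n", name);
    return 0;
  }
};

// The static is never read again; without the attribute, -Wall warns.
static DMLC_ATTRIBUTE_UNUSED int __reg_demo__ = Registry::Register("demo");

int main() { return 0; }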
@@ -25,7 +25,9 @@
 #include <typeindex>
 #include <typeinfo>
 #include <unordered_map>
+#if DMLC_STRICT_CXX11
 #include "./any.h"
+#endif  // DMLC_STRICT_CXX11
 #endif  // DMLC_USE_CXX11
 
 namespace dmlc {
@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
 };
 
 #define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName)                    \
-  static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
+  static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager&     \
+  __make_AnyJSONType ## _ ## KeyName ## __
 
 /*!
  * \def DMLC_JSON_ENABLE_ANY
@@ -475,7 +478,7 @@ struct Handler {
   }
 };
 
-#if DMLC_USE_CXX11
+#if DMLC_STRICT_CXX11
 // Manager to store json serialization strategy.
 class AnyJSONManager {
  public:
@@ -561,7 +564,7 @@ struct Handler<any> {
     CHECK(!reader->NextArrayItem()) << "invalid any json format";
   }
 };
-#endif  // DMLC_USE_CXX11
+#endif  // DMLC_STRICT_CXX11
 
 }  // namespace json
...
@@ -251,7 +251,8 @@ struct Parameter {
     static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
     return &inst.manager;                                                \
   }                                                                      \
-  static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
+  static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager&          \
+  __make__ ## PType ## ParamManager__ =                                  \
       (*PType::__MANAGER__())                                            \
 
 //! \endcond
...
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
  * \sa FactoryRegistryEntryBase
  */
 #define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name)          \
-  static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =      \
+  static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
       ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name)           \
 
 /*!
@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
  */
 #define DMLC_REGISTRY_LINK_TAG(UniqueTag)                               \
   int __dmlc_registry_file_tag_ ## UniqueTag ## __();                   \
-  static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
+  static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
+      __dmlc_registry_file_tag_ ## UniqueTag ## __();
 
 }  // namespace dmlc
 #endif  // DMLC_REGISTRY_H_
@@ -260,6 +260,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
  * \return 0 when success, -1 when failure happens
  */
 NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
+
 /*!
  * \brief Get/Set an attribute in JSON format.
  *  This feature allows passing graph attributes back and forth at reasonable speed.
@@ -273,6 +274,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
 NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
                                 const char* key,
                                 const char* json_value);
+
 /*!
  * \brief Get a serialized attribute from the graph.
  *  This feature allows passing graph attributes back and forth at reasonable speed.
@@ -289,6 +291,21 @@ NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
                                 const char* key,
                                 const char** json_out,
                                 int *success);
+
+/*!
+ * \brief Set an attribute whose type is std::vector<NodeEntry> in C++.
+ *  This feature allows passing a list of symbolic variables for a gradient request.
+ *
+ * \note This is a beta feature, only used for testing purposes.
+ *
+ * \param handle The graph handle.
+ * \param key The key of the attribute.
+ * \param list The symbol whose outputs represent the list of NodeEntry to be passed.
+ * \return 0 when success, -1 when failure happens
+ */
+NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
+                                          const char* key,
+                                          SymbolHandle list);
 
 /*!
  * \brief Apply pass on the src graph.
 * \param src The source graph handle.
...
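For orientation, a short sketch of driving the new entry point through the C API. It assumes NNSymbolCreateVariable, NNGraphCreate, and NNGraphFree declared elsewhere in this header; a real caller should check every return code for -1:

#include <nnvm/c_api.h>

int main() {
  SymbolHandle x0, yg;
  GraphHandle g;
  NNSymbolCreateVariable("x0", &x0);
  NNSymbolCreateVariable("yg", &yg);
  NNGraphCreate(x0, &g);  /* trivial graph whose output is the variable x0 */
  /* store yg's outputs as the std::vector<NodeEntry> attribute */
  NNGraphSetNodeEntryListAttr_(g, "grad_ys_out_grad", yg);
  NNGraphFree(g);
  return 0;
}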
@@ -279,10 +279,8 @@ class OpMap {
 };
 
 // internal macros to make
-#define NNVM_STR_CONCAT_(__x, __y) __x##__y
-#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
 #define NNVM_REGISTER_VAR_DEF(OpName)                            \
-  static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
+  static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
 
 /*!
  * \def NNVM_REGISTER_OP
@@ -300,7 +298,7 @@ class OpMap {
  * \endcode
  */
 #define NNVM_REGISTER_OP(OpName)                                 \
-  NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =  \
+  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =  \
   ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
 
 // implementations of template functions after this.
...
@@ -11,6 +11,7 @@
 #include <utility>
 #include <functional>
 #include "./base.h"
+#include "./node.h"
 #include "./tuple.h"
 
 namespace nnvm {
@@ -107,6 +108,19 @@ using TIsBackwardOp = bool;
 using FInplaceOption = std::function<
   std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
 
+/*!
+ * \brief Get the gradient node(s) of the op node.
+ *  This function generates the backward graph of the node.
+ * \param nodeptr The node to take the gradient of.
+ * \param out_grads Gradients of the current node's outputs.
+ * \return Gradients of the node's inputs.
+ *
+ * \note Register under "FGradient".
+ */
+using FGradient = std::function<std::vector<NodeEntry>(
+    const NodePtr& nodeptr,
+    const std::vector<NodeEntry>& out_grads)>;
+
 }  // namespace nnvm
 #endif  // NNVM_OP_ATTR_TYPES_H_
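To make the FGradient contract concrete, here is a minimal sketch that registers a gradient for a hypothetical unary negate op. The op name and the MakeNode helper mirror the example operator file earlier in this commit; the block itself is illustrative, not part of the commit:

// assumes MakeNode(const char*, std::string, std::vector<NodeEntry>)
// as defined in the example operator file above
NNVM_REGISTER_OP(negate)
.describe("numeric negation")
.set_num_inputs(1)
.attr<nnvm::FGradient>(
    "FGradient", [](const nnvm::NodePtr& n,
                    const std::vector<nnvm::NodeEntry>& ograds) {
      // d(-x)/dx = -1, so the input gradient is just -ograd
      return std::vector<nnvm::NodeEntry>{
        MakeNode("negate", n->attrs.name + "_grad", {ograds[0]})
      };
    });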
@@ -109,6 +109,37 @@ inline Graph PlaceDevice(Graph graph,
   return ApplyPass(std::move(graph), {"PlaceDevice"});
 }
 
+/*!
+ * \brief Get the gradient graph whose outputs are the gradients of ys with respect to xs.
+ * \param graph The source graph.
+ * \param ys The entries to take the gradient of.
+ * \param xs The inputs to take the gradient with respect to.
+ * \param ys_out_grad The symbols for additional gradients to be propagated back to ys.
+ * \param aggregate_fun Aggregation function applied to aggregate the inputs.
+ * \param mirror_fun Optional mirror function to do mirror optimization and save memory.
+ * \return A new graph whose outputs correspond to the gradients of xs.
+ */
+inline Graph Gradient(
+    Graph graph,
+    std::vector<NodeEntry> ys,
+    std::vector<NodeEntry> xs,
+    std::vector<NodeEntry> ys_out_grad,
+    std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
+    std::function<int(const Node& node)> mirror_fun = nullptr) {
+  graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
+  graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
+  graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
+  if (aggregate_fun != nullptr) {
+    graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
+  }
+  if (mirror_fun != nullptr) {
+    graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
+  }
+  return ApplyPass(std::move(graph), {"Gradient"});
+}
+
 }  // namespace pass
 }  // namespace nnvm
 #endif  // NNVM_PASS_FUNCTIONS_H_
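And an end-to-end sketch of calling this helper from C++: it builds y = mul(x0, x1) from the example operators registered earlier (an assumption of this sketch; any graph whose ops carry FGradient works) and requests dy/dx0 and dy/dx1:

#include <nnvm/graph.h>
#include <nnvm/op.h>
#include <nnvm/pass_functions.h>
#include <string>
#include <vector>

using nnvm::Node;
using nnvm::NodeEntry;
using nnvm::NodePtr;

// a variable is simply a node with no op attached
NodeEntry Var(const std::string& name) {
  NodePtr n = Node::Create();
  n->attrs.name = name;
  return NodeEntry{n, 0, 0};
}

int main() {
  NodeEntry x0 = Var("x0"), x1 = Var("x1");
  NodePtr mul = Node::Create();
  mul->op = nnvm::Op::Get("mul");
  mul->attrs.name = "y";
  mul->inputs = {x0, x1};
  NodeEntry y{mul, 0, 0};

  nnvm::Graph g;
  g.outputs = {y};

  // head gradient fed into y; "__one__" stands in for a tensor of ones
  NodePtr one = Node::Create();
  one->op = nnvm::Op::Get("__one__");
  one->attrs.name = "yg";

  nnvm::Graph gg = nnvm::pass::Gradient(
      g, {y}, {x0, x1}, {NodeEntry{one, 0, 0}});
  // gg.outputs[0] is dy/dx0 (yg * x1); gg.outputs[1] is dy/dx1 (yg * x0)
  return 0;
}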
@@ -10,7 +10,7 @@ from ._base import _LIB
 from ._base import c_array, c_str, nn_uint, py_str, string_types
 from ._base import GraphHandle, SymbolHandle
 from ._base import check_call
-from .symbol import Symbol
+from .symbol import Symbol, Group as _Group
 
 class Graph(object):
@@ -56,8 +56,27 @@ class Graph(object):
         else:
             return None
 
+    def _set_symbol_list_attr(self, key, value):
+        """Set a symbol-list attribute of the graph.
+
+        Parameters
+        ----------
+        key : string
+            The key of the attribute.
+        value : Symbol or list of Symbol
+            The symbols whose outputs form the list of NodeEntry to set.
+        """
+        if isinstance(value, list):
+            value = _Group(value)
+        if not isinstance(value, Symbol):
+            raise ValueError("value needs to be a grouped symbol")
+        check_call(_LIB.NNGraphSetNodeEntryListAttr_(
+            self.handle, c_str(key), value.handle))
+
     def _set_json_attr(self, key, value, type_name=None):
-        """Set the attribute of the symbol.
+        """Set the attribute of the graph.
 
         Parameters
         ----------
...
@@ -35,6 +35,17 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
   API_END_HANDLE_ERROR(delete s);
 }
 
+int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
+                                 const char* key,
+                                 SymbolHandle list) {
+  API_BEGIN();
+  Symbol* s = static_cast<Symbol*>(list);
+  Graph* g = static_cast<Graph*>(handle);
+  g->attrs[std::string(key)]
+      = std::make_shared<any>(s->outputs);
+  API_END();
+}
+
 int NNGraphSetJSONAttr(GraphHandle handle,
                        const char* key,
                        const char* json_value) {
...
/*!
 *  Copyright (c) 2016 by Contributors
 * \file gradients.cc
 * \brief Passes that take gradients of the graph
 *  This code was adapted from the mxnet codebase by Min Lin
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <algorithm>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
namespace nnvm {
namespace pass {
namespace {
// default aggregate gradient function
// requires operators __zero__ and __ewise_sum__ to be present
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
  if (v.size() == 1) {
    return std::move(v[0]);
  } else if (v.size() == 0) {
    NodePtr zero_node = Node::Create();
    zero_node->op = Op::Get("__zero__");
    return NodeEntry{zero_node, 0, 0};
  } else {
    NodePtr sum_node = Node::Create();
    sum_node->op = Op::Get("__ewise_sum__");
    sum_node->inputs = std::move(v);
    return NodeEntry{sum_node, 0, 0};
  }
}

// helper entry holding the gradients flowing into one node output
struct GradEntry {
  NodeEntry sum{nullptr, 0, 0};
  std::vector<NodeEntry> grads;
};
Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;

  CHECK_NE(src.attrs.count("grad_ys"), 0)
      << "Gradient requires grad_ys to be present.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0)
      << "Gradient requires grad_ys_out_grad to be present.";
  CHECK_NE(src.attrs.count("grad_xs"), 0)
      << "Gradient requires grad_xs to be present.";
  const std::vector<NodeEntry>& ys =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys");
  const std::vector<NodeEntry>& ys_out_grad =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
  const std::vector<NodeEntry>& xs =
      src.GetAttr<std::vector<NodeEntry> >("grad_xs");
  using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
  }

  // topo sort
  std::vector<NodePtr> topo_order;
  std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  CHECK_EQ(ys.size(), ys_out_grad.size());
  for (size_t i = 0; i < ys.size(); ++i) {
    output_grads[ys[i].node.get()][ys[i].index].grads = { ys_out_grad[i] };
  }
  // construct the mirror copies (memory-reduction strategy) if needed
  std::unordered_map<Node*, NodePtr> mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& n : topo_order) {
      if (mirror_fun(*n)) {
        NodePtr new_node = Node::Create();
        *new_node = *n;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[n.get()] = std::move(new_node);
      } else {
        mirror_map[n.get()] = n;
      }
    }
  }
  // traverse backward
  static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");

  std::vector<NodeEntry> out_agg_grads;
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;
    out_agg_grads.clear();
    for (GradEntry& e : output_grads.at(ptr.get())) {
      e.sum = agg_fun(std::move(e.grads));
      out_agg_grads.push_back(e.sum);
    }
    std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op]
        (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
    auto git = input_grads.begin();
    for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
      output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
    }
  }
  // take out the xs' grads
  Graph ret;
  ret.outputs.reserve(xs.size());
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // aggregate the sum if it has not been computed yet
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
    }
    ret.outputs.emplace_back(std::move(entry.sum));
  }
  return ret;
}
// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"grad_ys\"] wrt src.attrs[\"grad_xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");
} // namespace
} // namespace pass
} // namespace nnvm
import json
import nnvm.symbol as sym
import nnvm.graph as graph
def grad(ys, xs, ys_grads):
    g = graph.create(ys)
    g._set_symbol_list_attr('grad_ys', ys)
    g._set_symbol_list_attr('grad_xs', xs)
    g._set_symbol_list_attr('grad_ys_out_grad', ys_grads)
    return g.apply('Gradient')
def test_graph_gradient():
    x0 = sym.Variable('x0')
    x1 = sym.Variable('x1')
    yg = sym.Variable('yg')
    y = sym.exp(sym.mul(x0, x1))
    grad_graph = grad(y, [x0], yg)
    print("Original graph")
    print(y.debug_str())
    print("Gradient graph")
    print(grad_graph.symbol.debug_str())

if __name__ == "__main__":
    test_graph_gradient()