Commit ac070f83 by Tianqi Chen

[PASS] Add gradient pass (#28)

parent 803db5d1
......@@ -15,6 +15,10 @@ using nnvm::FMutateInputs;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::FInplaceOption;
using nnvm::Node;
using nnvm::NodePtr;
using nnvm::NodeEntry;
using nnvm::FGradient;
using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;
......@@ -37,6 +41,17 @@ inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs)
return {{0, 0}};
}
// quick helper to make a node
inline NodeEntry MakeNode(const char* op_name,
                          std::string node_name,
                          std::vector<NodeEntry> inputs) {
  NodePtr p = Node::Create();
  p->op = nnvm::Op::Get(op_name);
  p->attrs.name = std::move(node_name);
  p->inputs = std::move(inputs);
  return NodeEntry{p, 0, 0};
}
// simple demonstration of reshape.
NNVM_REGISTER_OP(reshape)
.describe("reshape source to target shape")
......@@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast)
return true;
});
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
    "FGradient", [](const NodePtr& n,
                    const std::vector<NodeEntry>& ograds) {
      // d/dx exp(x) = exp(x): reuse the forward node's own output
      // (NodeEntry{n, 0, 0}) instead of recomputing the exponential.
      return std::vector<NodeEntry>{
        MakeNode("mul", n->attrs.name + "_grad",
                 {ograds[0], NodeEntry{n, 0, 0}})
      };
    });
NNVM_REGISTER_OP(identity)
.describe("identity function")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
    "FGradient", [](const NodePtr& n,
                    const std::vector<NodeEntry>& ograds) {
      // identity passes the output gradient straight through.
      return std::vector<NodeEntry>{ograds[0]};
    });
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
    "FGradient", [](const NodePtr& n,
                    const std::vector<NodeEntry>& ograds){
      // d/da (a + b) = d/db (a + b) = 1: the output gradient
      // flows unchanged to both inputs.
      return std::vector<NodeEntry>{ograds[0], ograds[0]};
    });
NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add")
.set_num_inputs(2);
NNVM_REGISTER_OP(mul)
.describe("multiply two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
    "FGradient", [](const NodePtr& n,
                    const std::vector<NodeEntry>& ograds){
      // d/da (a * b) = b and d/db (a * b) = a:
      // multiply the output gradient by the other input.
      return std::vector<NodeEntry>{
        MakeNode("mul", n->attrs.name + "_grad_0",
                 {ograds[0], n->inputs[1]}),
        MakeNode("mul", n->attrs.name + "_grad_1",
                 {ograds[0], n->inputs[0]})
      };
    });
NNVM_REGISTER_OP(__ewise_sum__)
.describe("elementwise sum")
.set_num_inputs(nnvm::kVarg);
NNVM_REGISTER_OP(__zero__)
.describe("set output to zero")
.set_num_inputs(0);
NNVM_REGISTER_OP(__one__)
.describe("set output to one")
.set_num_inputs(0);
NNVM_REGISTER_OP(cross_device_copy)
.describe("Copy data across device.")
......
......@@ -58,6 +58,11 @@
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
......@@ -69,6 +74,7 @@
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
......@@ -82,6 +88,13 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief helper macro to suppress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
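The two-level definition matters: DMLC_STR_CONCAT expands its arguments before pasting, so tokens like __COUNTER__ are substituted first; a single-level ## would paste the literal token instead. A minimal sketch of the pattern (the MAKE_UNIQUE helper is hypothetical, for illustration only):

// Hypothetical helper: generate a unique variable name per use site.
#define MAKE_UNIQUE(prefix) DMLC_STR_CONCAT(prefix, __COUNTER__)
static int MAKE_UNIQUE(tag_) = 0;  // expands to e.g. static int tag_0 = 0;

NNVM_REGISTER_OP further below relies on exactly this expansion order.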
......
......@@ -25,7 +25,9 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11
namespace dmlc {
......@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName)                          \
  static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager&           \
  __make_AnyJSONType ## _ ## KeyName ## __
/*!
* \def DMLC_JSON_ENABLE_ANY
......@@ -475,7 +478,7 @@ struct Handler {
}
};
#if DMLC_STRICT_CXX11
// Manager to store json serialization strategy.
class AnyJSONManager {
public:
......@@ -561,7 +564,7 @@ struct Handler<any> {
CHECK(!reader->NextArrayItem()) << "invalid any json format";
}
};
#endif // DMLC_STRICT_CXX11
} // namespace json
......
......@@ -251,7 +251,8 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
  static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager&        \
  __make__ ## PType ## ParamManager__ =                                \
      (*PType::__MANAGER__())                                          \
//! \endcond
......
......@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
  static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
      ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name)
/*!
......@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
  static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
      __dmlc_registry_file_tag_ ## UniqueTag ## __();
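For context, a sketch of how this link-tag pair is typically wired up (DMLC_REGISTRY_FILE_TAG is the counterpart macro from the same header, not shown in this hunk; my_ops is a hypothetical tag):

// In the .cc file that contains the registrations:
DMLC_REGISTRY_FILE_TAG(my_ops);
// In a translation unit that is always linked, to keep those registrations alive:
DMLC_REGISTRY_LINK_TAG(my_ops);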
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
......@@ -260,6 +260,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
 * \brief Set an attribute in JSON format.
 * This feature allows passing graph attributes back and forth at reasonable speed.
......@@ -273,6 +274,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
                                const char* key,
                                const char* json_value);
/*!
 * \brief Get a serialized attribute from the graph.
 * This feature allows passing graph attributes back and forth at reasonable speed.
......@@ -289,6 +291,21 @@ NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
const char* key,
const char** json_out,
int *success);
/*!
 * \brief Set an attribute whose type is std::vector<NodeEntry> in C++.
 * This feature allows passing a list of symbolic variables for gradient requests.
 *
 * \note This is a beta feature, only used for testing purposes.
 *
 * \param handle The graph handle.
 * \param key The key of the attribute.
 * \param list The symbol whose outputs represent the list of NodeEntry to be passed.
 * \return 0 when success, -1 when failure happens
 */
NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
                                          const char* key,
                                          SymbolHandle list);
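A minimal usage sketch of this call (g and grad_ys are assumed to be valid handles obtained elsewhere; per the convention above, a non-zero return signals failure):

// Attach the outputs of a grouped symbol as the "grad_ys" attribute.
if (NNGraphSetNodeEntryListAttr_(g, "grad_ys", grad_ys) != 0) {
  // the call returns -1 on failure
}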
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
......
......@@ -279,10 +279,8 @@ class OpMap {
};
// internal macros to make registration variables
#define NNVM_REGISTER_VAR_DEF(OpName)                                  \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
/*!
* \def NNVM_REGISTER_OP
......@@ -300,7 +298,7 @@ class OpMap {
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this.
......
......@@ -11,6 +11,7 @@
#include <utility>
#include <functional>
#include "./base.h"
#include "./node.h"
#include "./tuple.h"
namespace nnvm {
......@@ -107,6 +108,19 @@ using TIsBackwardOp = bool;
using FInplaceOption = std::function<
    std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;

/*!
 * \brief Get the gradient node of the op node.
 * This function generates the backward graph of the node.
 * \param nodeptr The node to take the gradient of.
 * \param out_grads Gradients of the current node's outputs.
 * \return Gradients of the node's inputs.
 *
 * \note Register under "FGradient"
 */
using FGradient = std::function<std::vector<NodeEntry>(
    const NodePtr& nodeptr,
    const std::vector<NodeEntry>& out_grads)>;
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
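For reference, a minimal sketch of what an FGradient implementation looks like, using a hypothetical unary negate operator (the operator and the MakeNode helper from the example file above are assumptions for illustration, not part of this header):

// Hypothetical: gradient of negate(x) is negate(ograd), since d/dx(-x) = -1.
NNVM_REGISTER_OP(negate)
.set_num_inputs(1)
.attr<FGradient>(
    "FGradient", [](const NodePtr& n,
                    const std::vector<NodeEntry>& ograds) {
      return std::vector<NodeEntry>{
        MakeNode("negate", n->attrs.name + "_grad", {ograds[0]})
      };
    });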
......@@ -109,6 +109,37 @@ inline Graph PlaceDevice(Graph graph,
return ApplyPass(std::move(graph), {"PlaceDevice"});
}
/*!
 * \brief Get the gradient graph whose outputs are the gradients of ys with respect to xs.
 * \param graph The source graph.
 * \param ys The entries to take the gradient from.
 * \param xs The inputs to take the gradient with respect to.
 * \param ys_out_grad The symbols for the additional gradients to be propagated back to ys.
 * \param aggregate_fun Aggregation function applied to aggregate gradient entries.
 * \param mirror_fun Optional mirror function for the mirror optimization that saves memory.
 * \return A new graph whose outputs are the gradients with respect to xs.
 */
inline Graph Gradient(
    Graph graph,
    std::vector<NodeEntry> ys,
    std::vector<NodeEntry> xs,
    std::vector<NodeEntry> ys_out_grad,
    std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
    std::function<int(const Node& node)> mirror_fun = nullptr) {
  graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
  graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
  graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));

  if (aggregate_fun != nullptr) {
    graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
  }
  if (mirror_fun != nullptr) {
    graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
  }
  return ApplyPass(std::move(graph), {"Gradient"});
}
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
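A minimal C++ usage sketch of the helper above (assumes ys, xs, and ys_out_grad are std::vector<NodeEntry> values built elsewhere):

// Build a forward graph from ys, then request d(ys)/d(xs).
nnvm::Graph fwd;
fwd.outputs = ys;
nnvm::Graph grad_g = nnvm::pass::Gradient(std::move(fwd), ys, xs, ys_out_grad);
// grad_g.outputs[i] is the gradient entry corresponding to xs[i].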
......@@ -10,7 +10,7 @@ from ._base import _LIB
from ._base import c_array, c_str, nn_uint, py_str, string_types
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Symbol, Group as _Group
class Graph(object):
......@@ -56,8 +56,27 @@ class Graph(object):
        else:
            return None
    def _set_symbol_list_attr(self, key, value):
        """Set a list-of-NodeEntry attribute of the graph.

        Parameters
        ----------
        key : string
            The key of the attribute.

        value : Symbol or list of Symbol
            The symbol(s) whose outputs form the NodeEntry list.
        """
        if isinstance(value, list):
            value = _Group(value)
        if not isinstance(value, Symbol):
            raise ValueError("value must be a Symbol or a list of Symbols")
        check_call(_LIB.NNGraphSetNodeEntryListAttr_(
            self.handle, c_str(key), value.handle))
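A short usage sketch of this method, mirroring the test at the bottom of this commit (y, x0, and yg are symbols created elsewhere):

g = graph.create(y)
g._set_symbol_list_attr('grad_ys', y)            # a bare Symbol is accepted
g._set_symbol_list_attr('grad_xs', [x0])         # a list is grouped automatically
g._set_symbol_list_attr('grad_ys_out_grad', yg)
grad_g = g.apply('Gradient')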
    def _set_json_attr(self, key, value, type_name=None):
        """Set the attribute of the graph.

        Parameters
        ----------
......
......@@ -35,6 +35,17 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR(delete s);
}
int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
                                 const char* key,
                                 SymbolHandle list) {
  API_BEGIN();
  Symbol* s = static_cast<Symbol*>(list);
  Graph* g = static_cast<Graph*>(handle);
  g->attrs[std::string(key)]
      = std::make_shared<any>(s->outputs);
  API_END();
}
int NNGraphSetJSONAttr(GraphHandle handle,
                       const char* key,
                       const char* json_value) {
......
/*!
 * Copyright (c) 2016 by Contributors
 * \file gradients.cc
 * \brief Passes that take the gradient of the graph.
 * This code was adapted from the MXNet codebase by Min Lin.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <algorithm>
#include <functional>
namespace nnvm {
namespace pass {
namespace {
// default aggregate gradient function
// requires the operators __zero__ and __ewise_sum__ to be present
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
  if (v.size() == 1) {
    // single gradient: pass it through unchanged
    return std::move(v[0]);
  } else if (v.size() == 0) {
    // no gradient flows here: emit a __zero__ node
    NodePtr zero_node = Node::Create();
    zero_node->op = Op::Get("__zero__");
    return NodeEntry{zero_node, 0, 0};
  } else {
    // multiple gradients: sum them with __ewise_sum__
    NodePtr sum_node = Node::Create();
    sum_node->op = Op::Get("__ewise_sum__");
    sum_node->inputs = std::move(v);
    return NodeEntry{sum_node, 0, 0};
  }
}

// helper entry
struct GradEntry {
  NodeEntry sum{nullptr, 0, 0};
  std::vector<NodeEntry> grads;
};
Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;

  CHECK_NE(src.attrs.count("grad_ys"), 0)
      << "Gradient requires grad_ys to be present.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0)
      << "Gradient requires grad_ys_out_grad to be present.";
  CHECK_NE(src.attrs.count("grad_xs"), 0)
      << "Gradient requires grad_xs to be present.";
  const std::vector<NodeEntry>& ys =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys");
  const std::vector<NodeEntry>& ys_out_grad =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
  const std::vector<NodeEntry>& xs =
      src.GetAttr<std::vector<NodeEntry> >("grad_xs");
  using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
  }

  // topo sort
  std::vector<NodePtr> topo_order;
  std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  CHECK_EQ(ys.size(), ys_out_grad.size());
  for (size_t i = 0; i < ys.size(); ++i) {
    output_grads[ys[i].node.get()][ys[i].index].grads = { ys_out_grad[i] };
  }

  // construct the mirror (memory-reduction) strategy if needed
  std::unordered_map<Node*, NodePtr> mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& n : topo_order) {
      if (mirror_fun(*n)) {
        NodePtr new_node = Node::Create();
        *new_node = *n;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[n.get()] = std::move(new_node);
      } else {
        mirror_map[n.get()] = n;
      }
    }
  }

  // traverse backward
  static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
  std::vector<NodeEntry> out_agg_grads;
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;
    out_agg_grads.clear();
    for (GradEntry& e : output_grads.at(ptr.get())) {
      e.sum = agg_fun(std::move(e.grads));
      out_agg_grads.push_back(e.sum);
    }
    std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op]
        (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
    auto git = input_grads.begin();
    for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
      output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
    }
  }

  // take out the xs' grads
  Graph ret;
  ret.outputs.reserve(xs.size());
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // aggregate the gradients if the sum has not been formed yet
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
    }
    ret.outputs.emplace_back(std::move(entry.sum));
  }
  return ret;
}
// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"grad_ys\"] wrt src.attrs[\"grad_xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");
} // namespace
} // namespace pass
} // namespace nnvm
import json
import nnvm.symbol as sym
import nnvm.graph as graph
def grad(ys, xs, ys_grads):
    g = graph.create(ys)
    g._set_symbol_list_attr('grad_ys', ys)
    g._set_symbol_list_attr('grad_xs', xs)
    g._set_symbol_list_attr('grad_ys_out_grad', ys_grads)
    return g.apply('Gradient')


def test_graph_gradient():
    x0 = sym.Variable('x0')
    x1 = sym.Variable('x1')
    yg = sym.Variable('yg')
    y = sym.exp(sym.mul(x0, x1))
    # expected gradient: d y / d x0 = yg * exp(x0 * x1) * x1
    grad_graph = grad(y, [x0], yg)
    print("Original graph")
    print(y.debug_str())
    print("Gradient graph")
    print(grad_graph.symbol.debug_str())


if __name__ == "__main__":
    test_graph_gradient()