Commit 669e9c15 by Yao Wang, committed by Tianqi Chen

Add gradient graph (#280)

* Add creating gradient symbol

* Fix lint

* Address comments

* Fix typo

* Address comment
Parent: 492a37c5
@@ -38,8 +38,12 @@ This level enables fully connected multi-layer perceptron.
    nnvm.symbol.elemwise_sub
    nnvm.symbol.elemwise_mul
    nnvm.symbol.elemwise_div
-   nnvm.symbol.fill
-   nnvm.symbol.fill_like
+   nnvm.symbol.full
+   nnvm.symbol.full_like
+   nnvm.symbol.ones
+   nnvm.symbol.ones_like
+   nnvm.symbol.zeros
+   nnvm.symbol.zeros_like
    nnvm.symbol.flatten
    nnvm.symbol.concatenate
    nnvm.symbol.expand_dims
@@ -113,8 +117,12 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.elemwise_sub
 .. autofunction:: nnvm.symbol.elemwise_mul
 .. autofunction:: nnvm.symbol.elemwise_div
-.. autofunction:: nnvm.symbol.fill
-.. autofunction:: nnvm.symbol.fill_like
+.. autofunction:: nnvm.symbol.full
+.. autofunction:: nnvm.symbol.full_like
+.. autofunction:: nnvm.symbol.ones
+.. autofunction:: nnvm.symbol.ones_like
+.. autofunction:: nnvm.symbol.zeros
+.. autofunction:: nnvm.symbol.zeros_like
 .. autofunction:: nnvm.symbol.flatten
 .. autofunction:: nnvm.symbol.concatenate
 .. autofunction:: nnvm.symbol.expand_dims
...
@@ -13,6 +13,7 @@ from ._base import c_array, c_str, nn_uint, py_str, string_types
 from ._base import GraphHandle, SymbolHandle
 from ._base import check_call
 from .symbol import Variable, Symbol, Group as _Group
+from .symbol import ones_like
 
 class GraphIndex(object):
     """Index for quickly accessing graph attributes.
@@ -270,3 +271,38 @@ def create(symbol):
     check_call(_LIB.NNGraphCreate(
         symbol.handle, ctypes.byref(ghandle)))
     return Graph(ghandle)
+
+
+def gradients(ys, xs, grad_ys=None):
+    """Create gradient symbols of ys with respect to xs.
+
+    Parameters
+    ----------
+    ys : Symbol or list of Symbol
+        Symbols from which the gradient is calculated.
+    xs : Symbol or list of Symbol
+        Symbols the gradient is taken with respect to.
+        For a group symbol, gradients for all outputs will be calculated.
+    grad_ys : Symbol or list of Symbol
+        Head gradients for ys.
+
+    Returns
+    -------
+    ret : list of Symbol
+        Generated gradient symbols. For each entry of xs,
+        all gradients from ys are merged into a single symbol.
+    """
+    if isinstance(ys, list):
+        ys = _Group(ys)
+    g = create(ys)
+    g._set_symbol_list_attr('grad_ys', ys)
+    g._set_symbol_list_attr('grad_xs', xs)
+    ny = len(ys.list_output_names())
+    if grad_ys is None:
+        grad_ys = [ones_like(ys[i]) for i in range(ny)]
+    g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
+    sym = g.apply('Gradient').symbol
+    nx = len(_Group(xs).list_output_names()) \
+        if isinstance(xs, list) else len(xs.list_output_names())
+    ret = [sym[i] for i in range(nx)]
+    return ret
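
A minimal usage sketch of the new API (assuming an NNVM build is importable; the variable names here are illustrative only):

    import nnvm.symbol as sym
    import nnvm.graph as graph

    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.elemwise_mul(x, y)          # z = x * y
    # grad_ys defaults to ones_like(z), so this yields dz/dx and dz/dy
    dx, dy = graph.gradients(z, [x, y])
    # inspect the generated gradient graph
    print(graph.create(sym.Group([dx, dy])).ir())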
@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
     return std::move(v[0]);
   } else if (v.size() == 0) {
     NodePtr zero_node = Node::Create();
-    zero_node->attrs.op = Op::Get("__zero__");
+    zero_node->attrs.op = Op::Get("_zeros");
     return NodeEntry{zero_node, 0, 0};
   } else {
     NodePtr sum_node = Node::Create();
-    sum_node->attrs.op = Op::Get("__ewise_sum__");
+    sum_node->attrs.op = Op::Get("elemwise_sum");
     sum_node->inputs = std::move(v);
     return NodeEntry{sum_node, 0, 0};
   }
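
The aggregation rule switched to the new op names here is simple: zero incoming gradient paths produce a zeros node, a single path passes through unchanged, and several paths are summed. A hypothetical Python rendering of the same rule, for illustration only (the helper name and the placeholder shape are made up; sym is as imported in the sketch above):

    def aggregate(grads):
        # no paths: the gradient is identically zero
        if len(grads) == 0:
            return sym.zeros(shape=(1,))  # placeholder; the real pass leaves shape to inference
        # one path: pass the gradient through unchanged
        if len(grads) == 1:
            return grads[0]
        # several paths: sum all contributions
        return sym.elemwise_sum(*grads, num_args=len(grads))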
...
@@ -197,7 +197,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
     // grad_1 = - grad_y * n0 / n1^2
     NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_0",
                               {ograds[0], n->inputs[0]});
-    NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1", {sub0});
+    NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1",
+                              {sub0});
     NodeEntry sub2 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_2",
                               {n->inputs[1], n->inputs[1]});
     return std::vector<NodeEntry>{
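
For reference, the nodes built here implement the usual quotient rule. With y = n0 / n1 and head gradient g_y = ∂L/∂y:

    \frac{\partial L}{\partial n_0} = \frac{g_y}{n_1},
    \qquad
    \frac{\partial L}{\partial n_1} = -\frac{g_y \, n_0}{n_1^{2}}

so sub0 = g_y * n0, sub1 = -sub0, and sub2 = n1^2 assemble the numerator and denominator of the second term.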
@@ -240,15 +241,27 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
 DMLC_REGISTER_PARAMETER(InitOpParam);
 
-// fill
-NNVM_REGISTER_INIT_OP(fill)
+// full
+NNVM_REGISTER_INIT_OP(full)
 .describe(R"code(Fill array with scalar value
 )code" NNVM_ADD_FILELINE)
 .set_support_level(1);
 
-// fill_like
-NNVM_REGISTER_ELEMWISE_UNARY_OP(fill_like)
+NNVM_REGISTER_INIT_OP(zeros)
+.describe(R"code(Fill target with zeros
+)code" NNVM_ADD_FILELINE)
+.set_support_level(1);
+
+NNVM_REGISTER_INIT_OP(ones)
+.describe(R"code(Fill target with ones
+)code" NNVM_ADD_FILELINE)
+.set_support_level(1);
+
+// full_like
+NNVM_REGISTER_ELEMWISE_UNARY_OP(full_like)
 .describe(R"code(Return a scalar value array with the same shape and type
 as the input array
@@ -260,8 +273,38 @@ as the input array
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
     return std::vector<NodeEntry>{
-      MakeNode("fill_like", n->attrs.name + "_zero",
-               {n->inputs[0]}, {{"value", "0"}})
+      MakeNode("zeros_like", n->attrs.name + "_grad",
+               {n->inputs[0]})
+    };
+});
+
+NNVM_REGISTER_ELEMWISE_UNARY_OP(zeros_like)
+.describe(R"code(Return an array of zeros with the same shape and type
+as the input array.
+)code")
+.add_argument("data", "Symbol", "The input")
+.set_attr<FGradient>(
+  "FGradient", [](const NodePtr& n,
+                  const std::vector<NodeEntry>& ograds){
+    return std::vector<NodeEntry>{
+      MakeNode("zeros_like", n->attrs.name + "_grad",
+               {n->inputs[0]})
+    };
+});
+
+NNVM_REGISTER_ELEMWISE_UNARY_OP(ones_like)
+.describe(R"code(Return an array of ones with the same shape and type
+as the input array.
+)code")
+.add_argument("data", "Symbol", "The input")
+.set_attr<FGradient>(
+  "FGradient", [](const NodePtr& n,
+                  const std::vector<NodeEntry>& ograds){
+    return std::vector<NodeEntry>{
+      MakeNode("zeros_like", n->attrs.name + "_grad",
+               {n->inputs[0]})
     };
   });
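
All three *_like ops register the same FGradient: their output depends only on the input's shape and type, not on its values, so the only correctly shaped gradient is zeros_like of the input. A quick hedged check through the Python frontend (assumes the ops are exposed as listed in the docs above, and reuses the sym/graph imports from the earlier sketch):

    x = sym.Variable("x")
    z = sym.elemwise_mul(x, sym.ones_like(x))   # z == x elementwise
    dx = graph.gradients(z, x)[0]
    # dx sums ones_like(x) (from the mul) with the zero gradient that
    # flows through ones_like, so it is identically one.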
@@ -353,8 +396,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
     // y = scalar / n0
     // grad_0 = - grad_y * scalar / n0^2
-    NodeEntry sub0 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_0",
-                              {ograds[0]}, {{"scalar", n->attrs.dict["scalar"]}});
-    NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1", {sub0});
+    NodeEntry sub0 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_0",
+                              {ograds[0]},
+                              {{"scalar", n->attrs.dict["scalar"]}});
+    NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1",
+                              {sub0});
     NodeEntry sub2 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_2",
                               {n->inputs[0], n->inputs[0]});
     return std::vector<NodeEntry>{
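
The scalar variant follows the same pattern: with y = c / n0 for a constant c and head gradient g_y,

    \frac{\partial L}{\partial n_0} = -\frac{g_y \, c}{n_0^{2}}

which is exactly sub1 over sub2, with sub0 = g_y * c, sub1 = -sub0, and sub2 = n0^2.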
@@ -407,6 +452,63 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
   });
+
+struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
+  int num_args;
+  DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+        .describe("Number of inputs to be summed.");
+  }
+};
+
+DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
+
+bool ElementWiseSumShape(const NodeAttrs& attrs,
+                         std::vector<TShape> *in_attrs,
+                         std::vector<TShape> *out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1);
+  return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
+      attrs, in_attrs, out_attrs, TShape());
+}
+
+bool ElementWiseSumType(const NodeAttrs& attrs,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1);
+  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
+      attrs, in_attrs, out_attrs, -1);
+}
+
+std::vector<NodeEntry> ElementWiseSumGrad(
+    const NodePtr& n,
+    const std::vector<NodeEntry>& ograds) {
+  // Route the single output gradient through an identity node per input,
+  // which keeps shape inference on the gradient graph straightforward.
+  const Op* copy_op = Op::Get("identity");
+  CHECK_EQ(ograds.size(), 1);
+  std::vector<NodeEntry> ret;
+  NodeEntry n_out{n, 0, 0};
+  for (size_t i = 0; i < n->inputs.size(); i++) {
+    NodePtr id_node = Node::Create();
+    id_node->attrs.op = copy_op;
+    id_node->inputs = {ograds[0]};
+    ret.push_back(NodeEntry{id_node, 0, 0});
+  }
+  return ret;
+}
+
+NNVM_REGISTER_OP(elemwise_sum)
+.describe(R"code(Adds all input arguments element-wise.
+
+)code" NNVM_ADD_FILELINE)
+.set_attr_parser(ParamParser<ElementWiseSumParam>)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
+  return ret;
+})
+.set_attr<nnvm::FInferShape>("FInferShape", ElementWiseSumShape)
+.set_attr<nnvm::FInferType>("FInferType", ElementWiseSumType)
+.set_attr<nnvm::FGradient>("FGradient", ElementWiseSumGrad)
+.add_argument("args", "Symbol[]", "Positional input arguments");
 }  // namespace top
 }  // namespace nnvm
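
Because num_args is parsed from the attribute dictionary, the generated Python operator should accept a variable number of positional inputs plus the count. A hedged sketch of how that would look (the exact variadic calling convention depends on the frontend's generated wrappers):

    a = sym.Variable("a")
    b = sym.Variable("b")
    c = sym.Variable("c")
    s = sym.elemwise_sum(a, b, c, num_args=3)   # single node computing a + b + c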
@@ -112,6 +112,23 @@ def test_print_graph_ir():
     assert("y_bias" in ir1)
     assert("shape=" in ir2)
+
+def test_gradient():
+    x = sym.Variable("x")
+    y = sym.Variable("y")
+    z1 = sym.elemwise_add(x, sym.sqrt(y))
+    z2 = sym.log(x)
+    gradient = graph.gradients([z1, z2], [x, y])
+    assert len(gradient) == 2
+
+    g1 = sym.Variable("g1")
+    g2 = sym.Variable("g2")
+    grad_ys = [g1, g2]
+    gradient = graph.gradients(sym.Group([z1, z2]),
+                               sym.Group([x, y]), grad_ys=grad_ys)
+    g_graph = graph.create(sym.Group(gradient)).ir()
+    assert len(gradient) == 2
+    assert "g1" in g_graph
+    assert "g2" in g_graph
 if __name__ == "__main__":
     test_print_graph_ir()
@@ -123,3 +140,4 @@ if __name__ == "__main__":
     test_infer_type()
     test_plan_memory()
     test_list_args()
+    test_gradient()