Commit 669e9c15 by Yao Wang, committed by Tianqi Chen

Add gradient graph (#280)

* Add creating gradient symbol

* Fix lint

* Address comments

* Fix typo

* Address comment
parent 492a37c5
@@ -38,8 +38,12 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.fill
nnvm.symbol.fill_like
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
nnvm.symbol.ones_like
nnvm.symbol.zeros
nnvm.symbol.zeros_like
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
@@ -113,8 +117,12 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.fill
.. autofunction:: nnvm.symbol.fill_like
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
.. autofunction:: nnvm.symbol.ones_like
.. autofunction:: nnvm.symbol.zeros
.. autofunction:: nnvm.symbol.zeros_like
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
@@ -13,6 +13,7 @@ from ._base import c_array, c_str, nn_uint, py_str, string_types
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Variable, Symbol, Group as _Group
from .symbol import ones_like
class GraphIndex(object):
    """Index for quickly accessing graph attributes.
@@ -270,3 +271,38 @@ def create(symbol):
    check_call(_LIB.NNGraphCreate(
        symbol.handle, ctypes.byref(ghandle)))
    return Graph(ghandle)


def gradients(ys, xs, grad_ys=None):
    """Create gradient symbols of ys with respect to xs.

    Parameters
    ----------
    ys : Symbol or list of Symbol
        Symbols from which the gradient is calculated.
    xs : Symbol or list of Symbol
        Symbols with respect to which the gradient is calculated.
        For a group symbol, gradients for all outputs will be calculated.
    grad_ys : Symbol or list of Symbol, optional
        Head gradients for ys. Defaults to ones_like(y) for each y in ys.

    Returns
    -------
    ret : list of Symbol
        Generated gradient symbols. For each symbol in xs,
        all gradients contributed by ys are merged into a single symbol.
    """
    if isinstance(ys, list):
        ys = _Group(ys)
    g = create(ys)
    g._set_symbol_list_attr('grad_ys', ys)
    g._set_symbol_list_attr('grad_xs', xs)
    ny = len(ys.list_output_names())
    if grad_ys is None:
        grad_ys = [ones_like(ys[i]) for i in range(ny)]
    g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
    sym = g.apply('Gradient').symbol
    nx = len(_Group(xs).list_output_names()) \
        if isinstance(xs, list) else len(xs.list_output_names())
    ret = [sym[i] for i in range(nx)]
    return ret
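For context, a minimal usage sketch of the new API (mirroring test_gradient added later in this commit):

from nnvm import graph
import nnvm.symbol as sym

x = sym.Variable("x")
y = sym.Variable("y")
z1 = sym.elemwise_add(x, sym.sqrt(y))
z2 = sym.log(x)

# One merged gradient symbol per entry of xs, aggregated over all ys.
grads = graph.gradients([z1, z2], [x, y])
assert len(grads) == 2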
@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
return std::move(v[0]);
} else if (v.size() == 0) {
NodePtr zero_node = Node::Create();
zero_node->attrs.op = Op::Get("__zero__");
zero_node->attrs.op = Op::Get("_zeros");
return NodeEntry{zero_node, 0, 0};
} else {
NodePtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("__ewise_sum__");
sum_node->attrs.op = Op::Get("elemwise_sum");
sum_node->inputs = std::move(v);
return NodeEntry{sum_node, 0, 0};
}
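The renames matter because the aggregation rule now has to resolve to operators that actually exist in the registry (see the registrations below). A rough Python sketch of the same rule, where make_node is a hypothetical node constructor used only for illustration:

def aggregate_gradients(grads, make_node):
    # Mirrors DefaultAggregateGradient: a single contribution passes
    # through unchanged, no contribution becomes a "_zeros" node, and
    # multiple contributions are summed with "elemwise_sum".
    if len(grads) == 1:
        return grads[0]
    if len(grads) == 0:
        return make_node("_zeros", inputs=[])
    return make_node("elemwise_sum", inputs=grads)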
@@ -197,7 +197,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_div)
// grad_1 = - grad_y * n0 / n1^2
NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_0",
{ograds[0], n->inputs[0]});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1", {sub0});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1",
{sub0});
NodeEntry sub2 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_2",
{n->inputs[1], n->inputs[1]});
return std::vector<NodeEntry>{
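The comment above is the quotient rule. For y = n0 / n1 with head gradient grad_y:

\[
\frac{\partial y}{\partial n_0} = \frac{1}{n_1}, \qquad
\frac{\partial y}{\partial n_1} = -\frac{n_0}{n_1^{2}}
\]

so grad_1 = -(grad_y * n0) / n1^2, assembled above as sub0 = grad_y * n0, sub1 = -sub0, and sub2 = n1 * n1.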
@@ -240,15 +241,27 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
DMLC_REGISTER_PARAMETER(InitOpParam);
// fill
NNVM_REGISTER_INIT_OP(fill)
// full
NNVM_REGISTER_INIT_OP(full)
.describe(R"code(Fill array with scalar value
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
// fill_like
NNVM_REGISTER_ELEMWISE_UNARY_OP(fill_like)
NNVM_REGISTER_INIT_OP(zeros)
.describe(R"code(Fill target with zeros
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
NNVM_REGISTER_INIT_OP(ones)
.describe(R"code(Fill target with ones
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
// full_like
NNVM_REGISTER_ELEMWISE_UNARY_OP(full_like)
.describe(R"code(Return an scalar value array with the same shape and type
as the input array
@@ -260,8 +273,38 @@ as the input array
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
MakeNode("fill_like", n->attrs.name + "_zero",
{n->inputs[0]}, {{"value", "0"}})
MakeNode("zeros_like", n->attrs.name + "_grad",
{n->inputs[0]})
};
});
NNVM_REGISTER_ELEMWISE_UNARY_OP(zeros_like)
.describe(R"code(Return an array of zeros with the same shape and type
as the input array.
)code")
.add_argument("data", "Symbol", "The input")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
MakeNode("zeros_like", n->attrs.name + "_grad",
{n->inputs[0]})
};
});
NNVM_REGISTER_ELEMWISE_UNARY_OP(ones_like)
.describe(R"code(Return an array of ones with the same shape and type
as the input array.
)code")
.add_argument("data", "Symbol", "The input")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
MakeNode("zeros_like", n->attrs.name + "_grad",
{n->inputs[0]})
};
});
@@ -353,8 +396,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
// y = scalar / n0
// grad_0 = - grad_y * scalar / n0^2
NodeEntry sub0 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_0",
{ograds[0]}, {{"scalar", n->attrs.dict["scalar"]}});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1", {sub0});
{ograds[0]},
{{"scalar", n->attrs.dict["scalar"]}});
NodeEntry sub1 = MakeNode("negative", n->attrs.name + "_grad_sub_1",
{sub0});
NodeEntry sub2 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_2",
{n->inputs[0], n->inputs[0]});
return std::vector<NodeEntry>{
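Same pattern for reverse scalar division: with y = scalar / n0,

\[
\frac{\partial y}{\partial n_0} = -\frac{\text{scalar}}{n_0^{2}}
\]

so grad_0 = -(grad_y * scalar) / n0^2, built as sub0 = grad_y * scalar, sub1 = -sub0, and sub2 = n0 * n0.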
@@ -407,6 +452,63 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
});
struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
int num_args;
DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
.describe("Number of inputs to be summed.");
}
};
DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
bool ElementWiseSumShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(out_attrs->size(), 1);
return ElemwiseAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
attrs, in_attrs, out_attrs, TShape());
}
bool ElementWiseSumType(const NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(out_attrs->size(), 1);
return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}
std::vector<NodeEntry> ElementWiseSumGrad(
const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
// Wrap the head gradient in identity nodes so shape constraints are easier to propagate during shape inference.
const Op* copy_op = Op::Get("identity");
CHECK_EQ(ograds.size(), 1);
std::vector<NodeEntry> ret;
NodeEntry n_out{n, 0, 0};
for (size_t i = 0; i < n->inputs.size(); i++) {
NodePtr id_node = Node::Create();
id_node->attrs.op = copy_op;
id_node->inputs = {ograds[0]};
ret.push_back(NodeEntry{id_node, 0, 0});
}
return ret;
}
NNVM_REGISTER_OP(elemwise_sum)
.describe(R"code(Adds all input arguments element-wise.
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
return ret;
})
.set_attr<nnvm::FInferShape>("FInferShape", ElementWiseSumShape)
.set_attr<nnvm::FInferType>("FInferType", ElementWiseSumType)
.set_attr<nnvm::FGradient>("FGradient", ElementWiseSumGrad)
.add_argument("args", "Symbol[]", "Positional input arguments");
} // namespace top
} // namespace nnvm
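Assuming the Python frontend generates a wrapper for the new variadic op under nnvm.symbol, as it does for the other registrations in this file, usage would look roughly like:

import nnvm.symbol as sym

a = sym.Variable("a")
b = sym.Variable("b")
c = sym.Variable("c")

# num_args tells the attribute parser how many positional inputs to expect.
out = sym.elemwise_sum(a, b, c, num_args=3)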
@@ -112,6 +112,23 @@ def test_print_graph_ir():
assert("y_bias" in ir1)
assert("shape=" in ir2)

def test_gradient():
    x = sym.Variable("x")
    y = sym.Variable("y")
    z1 = sym.elemwise_add(x, sym.sqrt(y))
    z2 = sym.log(x)
    gradient = graph.gradients([z1, z2], [x, y])
    assert len(gradient) == 2

    g1 = sym.Variable("g1")
    g2 = sym.Variable("g2")
    grad_ys = [g1, g2]
    gradient = graph.gradients(sym.Group([z1, z2]),
                               sym.Group([x, y]), grad_ys=grad_ys)
    g_graph = graph.create(sym.Group(gradient)).ir()
    assert len(gradient) == 2
    assert "g1" in g_graph
    assert "g2" in g_graph
if __name__ == "__main__":
    test_print_graph_ir()
@@ -123,3 +140,4 @@ if __name__ == "__main__":
    test_infer_type()
    test_plan_memory()
    test_list_args()
    test_gradient()