Commit 69312744 by sgrechanik-h Committed by Tianqi Chen

[NNVM] Fix grads for sum and expand_like (#1455)

parent ddd249f2
...@@ -15,6 +15,10 @@ reg.register_schedule("expand_dims", _fschedule_broadcast) ...@@ -15,6 +15,10 @@ reg.register_schedule("expand_dims", _fschedule_broadcast)
@reg.register_compute("expand_like") @reg.register_compute("expand_like")
def compute_expand_like(attrs, inputs, _): def compute_expand_like(attrs, inputs, _):
"""Compute definition of expand_like""" """Compute definition of expand_like"""
if len(inputs[0].shape) == len(inputs[1].shape):
# If the number of dimensions is not changed then it is just broadcasting
return topi.broadcast_to(inputs[0], inputs[1].shape)
exclude = attrs.get_bool("exclude") exclude = attrs.get_bool("exclude")
axis = attrs.get_int_tuple("axis") axis = attrs.get_int_tuple("axis")
if exclude: if exclude:
......
...@@ -170,12 +170,18 @@ Example:: ...@@ -170,12 +170,18 @@ Example::
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){ const std::vector<NodeEntry>& ograds){
const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed); const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
std::ostringstream axis; axis << param.axis; bool exclude = param.exclude;
TShape p_axis = param.axis;
if (!param.exclude && param.axis.ndim() == 0) {
exclude = true;
p_axis = TShape();
}
std::ostringstream axis; axis << p_axis;
return std::vector<NodeEntry>{ return std::vector<NodeEntry>{
MakeNode("expand_like", n->attrs.name + "_grad", MakeNode("expand_like", n->attrs.name + "_grad",
{ograds[0], n->inputs[0]}, {ograds[0], n->inputs[0]},
{{"axis", axis.str()}, {{"axis", axis.str()},
{"exclude", std::to_string(param.exclude)}}) {"exclude", std::to_string(exclude)}})
}; };
}); });
......
...@@ -251,7 +251,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``. ...@@ -251,7 +251,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
NNVM_REGISTER_OP(expand_like) NNVM_REGISTER_OP(expand_like)
.describe(R"code(Expand an input array with the shape of second array. .describe(R"code(Expand an input array with the shape of second array.
This operation can be thought of as a composition of expand_dims and broadcast_to.
If the dimensions are already expanded then it just broadcasts.
Examples:: Examples::
input = [ 12. 19. 27.] input = [ 12. 19. 27.]
input.shape = (3,) input.shape = (3,)
...@@ -282,11 +283,23 @@ Examples:: ...@@ -282,11 +283,23 @@ Examples::
std::ostringstream axis; std::ostringstream axis;
axis << param.axis; axis << param.axis;
return std::vector<NodeEntry>{ if (param.axis.ndim() == 0 && !param.exclude) {
// Special case needed because sum interprets axis=[] differently
return std::vector<NodeEntry>{
ograds[0],
MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
};
}
auto sum_node =
MakeNode("sum", n->attrs.name + "_sum_grad",
{ograds[0]}, {ograds[0]},
{{"axis", axis.str()}, {{"axis", axis.str()},
{"exclude", std::to_string(param.exclude)}}), {"exclude", std::to_string(param.exclude)}});
return std::vector<NodeEntry>{
MakeNode("reshape_like", n->attrs.name + "_grad",
{sum_node, n->inputs[0]}),
MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]}) MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
}; };
}) })
......
...@@ -378,6 +378,13 @@ def verify_expand_like(in_shape, out_shape, axis, exclude): ...@@ -378,6 +378,13 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
def forward(x, y): def forward(x, y):
odim = len(out_shape) odim = len(out_shape)
if len(x.shape) == len(y.shape):
return np.broadcast_to(x, y.shape)
if x.shape == (1,) and len(y.shape) == odim:
x = np.reshape(x, ())
real_axis = [i if i >= 0 else i + odim for i in axis] real_axis = [i if i >= 0 else i + odim for i in axis]
real_axis = sorted(real_axis) real_axis = sorted(real_axis)
if exclude: if exclude:
...@@ -391,11 +398,17 @@ def verify_expand_like(in_shape, out_shape, axis, exclude): ...@@ -391,11 +398,17 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
def backward(head_grads, x, y): def backward(head_grads, x, y):
odim = len(out_shape) odim = len(out_shape)
keepdims = len(x.shape) == len(y.shape)
if x.shape == (1,) and len(y.shape) == odim:
x = np.reshape(x, ())
real_axis = [i if i >= 0 else i + odim for i in axis] real_axis = [i if i >= 0 else i + odim for i in axis]
real_axis = sorted(real_axis) real_axis = sorted(real_axis)
if exclude: if exclude:
real_axis = list(set(range(odim)) - set(real_axis)) real_axis = list(set(range(odim)) - set(real_axis))
return [np.sum(head_grads, axis=tuple(real_axis)), return [np.sum(head_grads, axis=tuple(real_axis), keepdims=keepdims),
np.zeros_like(y)] np.zeros_like(y)]
...@@ -410,6 +423,11 @@ def test_expand_like(): ...@@ -410,6 +423,11 @@ def test_expand_like():
verify_expand_like((2,), (2, 3), [1], False) verify_expand_like((2,), (2, 3), [1], False)
verify_expand_like((3, 4), (3, 5, 4), [1], False) verify_expand_like((3, 4), (3, 5, 4), [1], False)
verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True) verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True)
verify_expand_like((2, 3), (2, 3), [], False)
verify_expand_like((1,), (2, 3), [0, 1], False)
verify_expand_like((1, 1), (2, 3), [0, 1], False)
verify_expand_like((2, 1), (2, 3), [1], False)
verify_expand_like((1, 3), (2, 3), [0], False)
def verify_elemwise_sum(num_args): def verify_elemwise_sum(num_args):
......
...@@ -65,15 +65,15 @@ def expand_like(a, shape_like, axis): ...@@ -65,15 +65,15 @@ def expand_like(a, shape_like, axis):
""" """
odim = len(axis) + len(a.shape) odim = len(axis) + len(a.shape)
if odim != len(shape_like.shape): if odim != len(shape_like.shape):
if len(a.shape) == 1 and len(axis) == len(shape_like.shape):
# A special case: `a` is a scalar represented as a 1-dim tensor
return tvm.compute(shape_like.shape, lambda *idxs: a(0))
raise ValueError("shape inconsistent when expand_like ({}, {}, {})".format( raise ValueError("shape inconsistent when expand_like ({}, {}, {})".format(
len(axis), len(a.shape), len(shape_like.shape))) len(axis), len(a.shape), len(shape_like.shape)))
real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis) real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis)
real_axis = sorted(real_axis) real_axis = sorted(real_axis)
if not real_axis:
return a
def _compute(*idxs): def _compute(*idxs):
indices = [] indices = []
axis_index = 0 axis_index = 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment