Commit 69312744 by sgrechanik-h, committed by Tianqi Chen

[NNVM] Fix grads for sum and expand_like (#1455)

parent ddd249f2
@@ -15,6 +15,10 @@ reg.register_schedule("expand_dims", _fschedule_broadcast)
 @reg.register_compute("expand_like")
 def compute_expand_like(attrs, inputs, _):
     """Compute definition of expand_like"""
+    if len(inputs[0].shape) == len(inputs[1].shape):
+        # If the number of dimensions is not changed then it is just broadcasting
+        return topi.broadcast_to(inputs[0], inputs[1].shape)
+
     exclude = attrs.get_bool("exclude")
     axis = attrs.get_int_tuple("axis")
     if exclude:
...
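For orientation, the equal-rank fast path added above is ordinary broadcasting. A minimal NumPy sketch of the intended semantics (illustration only, not the topi code):

import numpy as np

# When the input and the target already have the same rank, expand_like
# inserts no axes and degenerates to broadcast_to.
x = np.ones((2, 1))
target = np.zeros((2, 3))
out = np.broadcast_to(x, target.shape)
assert out.shape == (2, 3)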
@@ -170,12 +170,18 @@ Example::
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
     const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
-    std::ostringstream axis; axis << param.axis;
+    bool exclude = param.exclude;
+    TShape p_axis = param.axis;
+    if (!param.exclude && param.axis.ndim() == 0) {
+      exclude = true;
+      p_axis = TShape();
+    }
+    std::ostringstream axis; axis << p_axis;
     return std::vector<NodeEntry>{
       MakeNode("expand_like", n->attrs.name + "_grad",
                {ograds[0], n->inputs[0]},
                {{"axis", axis.str()},
-                {"exclude", std::to_string(param.exclude)}})
+                {"exclude", std::to_string(exclude)}})
     };
   });
...
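The rewrite above handles NNVM's convention that sum with axis=[] reduces over all axes: the gradient must then expand over every axis, which expand_like can only express as an empty axis list plus exclude=1. A rough NumPy sketch of the effect (illustration only):

import numpy as np

# sum with axis=[] is a full reduction to a scalar; its gradient broadcasts
# the upstream scalar gradient back over every axis of the input, which is
# what expand_like with axis=[] and exclude=True produces.
x = np.arange(6.0).reshape(2, 3)
y = x.sum()                               # what sum(axis=[]) means in NNVM
ograd = 1.0                               # gradient w.r.t. the scalar output
grad_x = np.broadcast_to(ograd, x.shape)  # expand over all axes
assert grad_x.shape == x.shape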
@@ -251,7 +251,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
 NNVM_REGISTER_OP(expand_like)
 .describe(R"code(Expand an input array with the shape of second array.
-This operation can always be composed of unsqueezing and expanding dims.
+This operation can be thought of as a composition of expand_dims and broadcast_to.
+If the dimensions are already expanded then it just broadcasts.
 Examples::
   input = [ 12.  19.  27.]
   input.shape = (3,)
...
@@ -282,11 +283,23 @@ Examples::
     std::ostringstream axis;
     axis << param.axis;
-    return std::vector<NodeEntry>{
-      MakeNode("sum", n->attrs.name + "_grad",
+    if (param.axis.ndim() == 0 && !param.exclude) {
+      // Special case needed because sum interprets axis=[] differently
+      return std::vector<NodeEntry>{
+        ograds[0],
+        MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
+      };
+    }
+    auto sum_node =
+      MakeNode("sum", n->attrs.name + "_sum_grad",
                {ograds[0]},
                {{"axis", axis.str()},
-                {"exclude", std::to_string(param.exclude)}}),
+                {"exclude", std::to_string(param.exclude)}});
+    return std::vector<NodeEntry>{
+      MakeNode("reshape_like", n->attrs.name + "_grad",
+               {sum_node, n->inputs[0]}),
+      MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
+    };
   })
...
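The new expand_like gradient sums the head gradient over the expanded axes and then reshapes the result to the exact input shape; reshape_like absorbs rank differences such as a (2, 1) input broadcast to (2, 3). A NumPy sketch of the same computation (illustration only):

import numpy as np

# Gradient of expand_like: reduce the output gradient over the inserted
# (or broadcast) axes, then force the result back into the input's shape.
x = np.ones((2, 1))
head_grads = np.ones((2, 3))
g = np.sum(head_grads, axis=1)   # sum over the expanded axis -> shape (2,)
g = g.reshape(x.shape)           # the reshape_like step -> shape (2, 1)
assert g.shape == x.shape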
@@ -378,6 +378,13 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
     def forward(x, y):
         odim = len(out_shape)
+
+        if len(x.shape) == len(y.shape):
+            return np.broadcast_to(x, y.shape)
+
+        if x.shape == (1,) and len(y.shape) == odim:
+            x = np.reshape(x, ())
+
         real_axis = [i if i >= 0 else i + odim for i in axis]
         real_axis = sorted(real_axis)
         if exclude:
...
@@ -391,11 +398,17 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
     def backward(head_grads, x, y):
         odim = len(out_shape)
+
+        keepdims = len(x.shape) == len(y.shape)
+
+        if x.shape == (1,) and len(y.shape) == odim:
+            x = np.reshape(x, ())
+
         real_axis = [i if i >= 0 else i + odim for i in axis]
         real_axis = sorted(real_axis)
         if exclude:
             real_axis = list(set(range(odim)) - set(real_axis))
-        return [np.sum(head_grads, axis=tuple(real_axis)),
+        return [np.sum(head_grads, axis=tuple(real_axis), keepdims=keepdims),
                 np.zeros_like(y)]
...
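The keepdims flag matters exactly in the equal-rank case: the reference gradient must keep the reduced axes as size-1 dimensions so that it matches the input's shape. For instance (NumPy, illustration only):

import numpy as np

head_grads = np.ones((2, 3))
g = np.sum(head_grads, axis=(1,), keepdims=True)
assert g.shape == (2, 1)    # matches a (2, 1) input that was broadcast
g = np.sum(head_grads, axis=(1,))
assert g.shape == (2,)      # without keepdims the shapes disagree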
@@ -410,6 +423,11 @@ def test_expand_like():
     verify_expand_like((2,), (2, 3), [1], False)
     verify_expand_like((3, 4), (3, 5, 4), [1], False)
     verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True)
+    verify_expand_like((2, 3), (2, 3), [], False)
+    verify_expand_like((1,), (2, 3), [0, 1], False)
+    verify_expand_like((1, 1), (2, 3), [0, 1], False)
+    verify_expand_like((2, 1), (2, 3), [1], False)
+    verify_expand_like((1, 3), (2, 3), [0], False)

 def verify_elemwise_sum(num_args):
...
@@ -65,15 +65,15 @@ def expand_like(a, shape_like, axis):
     """
     odim = len(axis) + len(a.shape)
     if odim != len(shape_like.shape):
+        if len(a.shape) == 1 and len(axis) == len(shape_like.shape):
+            # A special case: `a` is a scalar represented as a 1-dim tensor
+            return tvm.compute(shape_like.shape, lambda *idxs: a(0))
         raise ValueError("shape inconsistent when expand_like ({}, {}, {})".format(
             len(axis), len(a.shape), len(shape_like.shape)))
+
     real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis)
     real_axis = sorted(real_axis)
+    if not real_axis:
+        return a
+
     def _compute(*idxs):
         indices = []
         axis_index = 0
...
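The special case treats a one-element 1-dim tensor as a scalar and replicates it across the whole target shape, which is what the lambda *idxs: a(0) compute expresses. A NumPy sketch of the same behaviour (illustration only, not the tvm.compute form):

import numpy as np

a = np.array([42.0])               # "scalar" stored as a shape-(1,) tensor
shape_like = np.zeros((2, 3))
out = np.full(shape_like.shape, a[0])
assert out.shape == (2, 3) and (out == 42.0).all()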