Commit 3597dfb6 by Sergei Grechanik Committed by Tianqi Chen

[NNVM] Fix gradients for broadcast_div (#1512)

parent a195f8d0
......@@ -337,8 +337,7 @@ Example::
});
NodeEntry dy = MakeNode("broadcast_div", n->attrs.name + "_drhs_div", {
NodeEntry{n, 0, 0},
MakeNode("__mul_scalar__", n->attrs.name + "_rhs_by_two",
{n->inputs[1]}, {{"scalar", "2"}})
MakeNode("negative", n->attrs.name + "_rhs_neg", {n->inputs[1]})
});
NodeEntry drhs = MakeNode("collapse_sum", n->attrs.name + "_drhs_sum", {
MakeNode("broadcast_mul", n->attrs.name + "_drhs_mul", { dy, ograds[0] }),
......
......@@ -268,7 +268,7 @@ def test_broadcast():
y = sym.broadcast_div(a, b)
def _backward_div(head_grads, a, b):
da = head_grads / b
db = _collapse(head_grads * a / (2 * b**2))
db = _collapse(- head_grads * a / b**2)
return da, db
helper(y, inputs, dtype, lambda a, b: a / b, _backward_div)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment