Commit a9313787 by larrywyang, committed by Tianqi Chen

[WIP] [NNVM] Fix softmax gradient (#1201)

[NNVM] Fix softmax gradient
parent 61dad72e
@@ -366,22 +366,24 @@ NNVM_REGISTER_OP(softmax)
     // [ ... ,-ynyn + yn]
     //
     // grad_x =
-    // [-y1*(ograd1*y1 - 1 + ograd2*y2 + ..., -y2*(ograd1*y1 - 1 + ograd2*y2, ..., ...]]
+    // [-y1*(ograd1*y1 - ograd1 + ograd2*y2 + ...),
+    //  -y2*(ograd1*y1 - ograd2 + ograd2*y2 + ...),
+    //  ...
+    //  -yn*(ograd1*y1 - ogradn + ograd2*y2 + ...)]
     // grad_x = ograd elemwise_mul output
     // grad_x = sum(grad_x, keepdim, axis)
     // grad_x = grad_x broadcast_mul output
     // grad_x = neg grad_x
-    // grad_x = grad_x + output
+    // grad_x = grad_x + ograd elemwise_mul output
     const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
     NodeEntry output = NodeEntry{n, 0, 0};
     NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
     NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
                               {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
     NodeEntry sub2 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub2", {sub1, output});
-    NodeEntry sub3 = MakeNode("negative", n->attrs.name + "_grad_sub3", {sub2});
     return std::vector<NodeEntry> {
-      MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, output})
+      MakeNode("elemwise_sub", n->attrs.name + "_grad", {sub0, sub2})
     };
   });
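For reference, the comment above reduces to grad_x = ograd*y - y*sum(ograd*y, axis), which is exactly what the new elemwise_sub(sub0, sub2) graph computes. The following standalone NumPy sketch (not part of the patch; the softmax helper, shapes, and tolerances are illustrative) checks that formula against a central-difference estimate:

# Sketch only: verify grad_x = ograd*y - y*sum(ograd*y) for y = softmax(x).
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.RandomState(0)
x = rng.randn(4, 7)       # illustrative shape
ograd = rng.randn(4, 7)   # incoming gradient

y = softmax(x)
analytic = ograd * y - y * np.sum(ograd * y, axis=-1, keepdims=True)

# central-difference estimate of d(sum(ograd * softmax(x))) / dx
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp = x.copy()
    xp[idx] += eps
    xm = x.copy()
    xm[idx] -= eps
    numeric[idx] = np.sum(ograd * (softmax(xp) - softmax(xm))) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-7))  # expected: True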
@@ -414,31 +416,33 @@ NNVM_REGISTER_OP(log_softmax)
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
-    // grad_x = grad_y dot jacobian of softmax
+    // grad_x = grad_y dot jacobian of logsoftmax
     //
-    // jacobian of softmax
+    // jacobian of logsoftmax
     // [-y1 + 1, -y2, ... ]
     // [ ... , -y2 + 1, ... ]
     // [ ... ... ]
     // [ ... ,-yn + 1]
     //
     // grad_x =
-    // [-(ograd1*y1 - 1 + ograd2*y2 + ..., -(ograd1*y1 - 1 + ograd2*y2, ..., ...]]
-    // grad_x = ograd elemwise_mul output
-    // grad_x = sum(grad_x, keepdim, axis)
+    // [ograd1 - exp(y1)*(ograd1 + ... + ogradn),
+    //  ograd2 - exp(y2)*(ograd1 + ... + ogradn),
+    //  ...
+    //  ogradn - exp(yn)*(ograd1 + ... + ogradn)]
+    // grad_x = sum(ograd, keepdim, axis)
+    // sigma = exp(output)
+    // grad_x = grad_x elemwise_mul sigma
     // grad_x = neg grad_x
-    // grad_x = grad_x + ones_like(grad_x)
+    // grad_x = grad_x + ograd
     const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
     NodeEntry output = NodeEntry{n, 0, 0};
-    NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
-    NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
+    NodeEntry sub0 = MakeNode("sum", n->attrs.name + "_grad_sub0", {ograds[0]},
                               {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
-    NodeEntry sub2 = MakeNode("full_like", n->attrs.name + "_grad_sub2", {n->inputs[0]},
-                              {{"fill_value", "-1"}});
-    NodeEntry sub3 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub3", {sub1, sub2});
+    NodeEntry sub1 = MakeNode("exp", n->attrs.name + "_grad_sub1", {output});
+    NodeEntry sub2 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub2", {sub0, sub1});
     return std::vector<NodeEntry> {
-      MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, ograds[0]})
+      MakeNode("elemwise_sub", n->attrs.name + "_grad", {ograds[0], sub2})
     };
   })
 .set_support_level(1);
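Similarly, with y = log_softmax(x) the new graph computes grad_x = ograd - exp(y) * sum(ograd, axis), which matches the updated reference backward in the Python test below. A standalone NumPy sketch (illustration only, not the NNVM test helper) that cross-checks it numerically:

# Sketch only: verify grad_x = ograd - exp(y)*sum(ograd) for y = log_softmax(x).
import numpy as np

def log_softmax(x, axis=-1):
    # numerically stable log-softmax along the given axis
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

rng = np.random.RandomState(0)
x = rng.randn(4, 7)       # illustrative shape
ograd = rng.randn(4, 7)   # incoming gradient

y = log_softmax(x)
analytic = ograd - np.exp(y) * np.sum(ograd, axis=-1, keepdims=True)

# central-difference estimate of d(sum(ograd * log_softmax(x))) / dx
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp = x.copy()
    xp[idx] += eps
    xm = x.copy()
    xm[idx] -= eps
    numeric[idx] = np.sum(ograd * (log_softmax(xp) - log_softmax(xm))) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-7))  # expected: True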
......
@@ -217,7 +217,7 @@ def test_softmax():
     dtype = "float32"
     dshape = (10, 1000)
     inputs = [('x', dshape, x)]
-    helper(y, inputs, dtype, forward), backward
+    helper(y, inputs, dtype, forward, backward)

 def test_log_softmax():
@@ -229,7 +229,7 @@ def test_log_softmax():
     def backward(head_grads, x):
         y = topi.testing.log_softmax_python(x)
-        grad = head_grads - np.sum(y * head_grads, axis=1, keepdims=True)
+        grad = head_grads - np.exp(y) * np.sum(head_grads, axis=1, keepdims=True)
         return [grad]
     dtype = "float32"
......