Commit 8a9e65c7 by 雾雨魔理沙 Committed by Wuwei Lin

Add gradient for log-softmax (#4069)

parent 776fd6bd
...@@ -305,6 +305,15 @@ def softmax_grad(orig, grad): ...@@ -305,6 +305,15 @@ def softmax_grad(orig, grad):
return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig] return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig]
@register_gradient("nn.log_softmax")
def log_softmax_grad(orig, grad):
"""Gradient of log_softmax"""
x = orig.args[0]
sm = _nn.softmax(x, axis=orig.attrs.axis)
grad = grad / sm
return softmax_grad(sm, grad)
@register_gradient("nn.bias_add") @register_gradient("nn.bias_add")
def bias_add_grad(orig, grad): def bias_add_grad(orig, grad):
"""Returns gradient of bias_add""" """Returns gradient of bias_add"""
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import numpy as np import numpy as np
import pytest
import tvm import tvm
from tvm import relay from tvm import relay
...@@ -100,7 +101,13 @@ def test_binary_op(): ...@@ -100,7 +101,13 @@ def test_binary_op():
def test_softmax_grad(): def test_softmax_grad():
data = relay.var("data", relay.TensorType((1, 16), "float64")) data = relay.var("data", relay.TensorType((1, 16), "float64"))
fwd_func = relay.Function([data], relay.nn.softmax(data)) fwd_func = relay.Function([data], relay.nn.softmax(data))
check_grad(fwd_func) check_grad(fwd_func, scale=1)
def test_log_softmax_grad():
data = relay.var("data", relay.TensorType((2, 16), "float64"))
fwd_func = relay.Function([data], relay.nn.log_softmax(data))
check_grad(fwd_func, scale=1)
def test_bias_add_grad(): def test_bias_add_grad():
...@@ -111,6 +118,4 @@ def test_bias_add_grad(): ...@@ -111,6 +118,4 @@ def test_bias_add_grad():
if __name__ == "__main__": if __name__ == "__main__":
test_unary_op() pytest.main([__file__])
test_binary_op()
test_bias_add_grad()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment