Commit 8a9e65c7 by 雾雨魔理沙 (MarisaKirisame), committed by Wuwei Lin

Add gradient for log-softmax (#4069)

parent 776fd6bd
python/tvm/relay/op/_tensor_grad.py
@@ -305,6 +305,15 @@ def softmax_grad(orig, grad):
     return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig]
 
 
+@register_gradient("nn.log_softmax")
+def log_softmax_grad(orig, grad):
+    """Gradient of log_softmax"""
+    x = orig.args[0]
+    sm = _nn.softmax(x, axis=orig.attrs.axis)
+    grad = grad / sm
+    return softmax_grad(sm, grad)
+
+
 @register_gradient("nn.bias_add")
 def bias_add_grad(orig, grad):
     """Returns gradient of bias_add"""
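The new rule composes with the existing softmax gradient instead of deriving log-softmax from scratch: since log_softmax(x) = log(softmax(x)), the upstream gradient g is first divided elementwise by sm, and the result is pushed through softmax_grad. Passing sm as the orig argument works because softmax_grad only reads orig.attrs.axis and the value of orig, and sm is itself a softmax call with the same axis attribute. Algebraically the composition collapses to g - softmax(x) * sum(g, axis). Below is a minimal NumPy sketch, not part of the commit, that checks this simplification numerically; all names in it are local to the sketch.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16))   # arbitrary test point
g = rng.standard_normal((2, 16))   # arbitrary upstream gradient
sm = softmax(x)

# What the softmax_grad pattern computes when handed grad = g / sm:
# (grad - sum(grad * sm, axis, keepdims)) * sm
composed = (g / sm - (g / sm * sm).sum(axis=-1, keepdims=True)) * sm

# Closed-form log-softmax vector-Jacobian product it should reduce to.
closed_form = g - sm * g.sum(axis=-1, keepdims=True)

assert np.allclose(composed, closed_form)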
tests/python/relay/test_op_grad_level1.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 import tvm
 from tvm import relay
@@ -100,7 +101,13 @@ def test_binary_op():
 def test_softmax_grad():
     data = relay.var("data", relay.TensorType((1, 16), "float64"))
     fwd_func = relay.Function([data], relay.nn.softmax(data))
-    check_grad(fwd_func)
+    check_grad(fwd_func, scale=1)
+
+
+def test_log_softmax_grad():
+    data = relay.var("data", relay.TensorType((2, 16), "float64"))
+    fwd_func = relay.Function([data], relay.nn.log_softmax(data))
+    check_grad(fwd_func, scale=1)
 
 
 def test_bias_add_grad():
@@ -111,6 +118,4 @@ def test_bias_add_grad():
 if __name__ == "__main__":
-    test_unary_op()
-    test_binary_op()
-    test_bias_add_grad()
+    pytest.main([__file__])
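The tests drive TVM's check_grad helper, which compares the gradient produced by Relay's AD pass against central finite differences on randomly generated inputs; the scale argument adjusts the magnitude of those random inputs, and scale=1 keeps them around unit size. As a rough illustration, here is a self-contained NumPy sketch, not TVM's actual check_grad, doing the same finite-difference comparison for the log-softmax gradient formula; the shapes mirror test_log_softmax_grad.

import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # stabilized, like typical kernels
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16))  # same shape as in test_log_softmax_grad
g = rng.standard_normal((2, 16))  # random upstream gradient

# Analytic vector-Jacobian product implemented by the new gradient rule.
sm = np.exp(log_softmax(x))
analytic = g - sm * g.sum(axis=-1, keepdims=True)

# Central finite differences of <log_softmax(x), g>, one entry at a time.
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    d = np.zeros_like(x)
    d[idx] = eps
    numeric[idx] = ((log_softmax(x + d) - log_softmax(x - d)) * g).sum() / (2 * eps)

assert np.allclose(numeric, analytic, atol=1e-5)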