Commit be9275c9 by Amy Wang Committed by Wuwei Lin

[Relay] Register abs gradient: grad * (select(x < 0, -1, 1)) (#3447)

parent 6577774d
......@@ -110,3 +110,11 @@ def collapse_sum_like_grad(orig, grad):
"""Returns [broadcast_to_like(grad, x), 0]"""
x, y = orig.args
return [broadcast_to_like(grad, x), zeros_like(y)]
@register_gradient("abs")
def abs_grad(orig, grad):
"""Returns grad * (select(x < 0, -1, 1))."""
x = orig.args[0]
zeros = zeros_like(x)
ones = ones_like(x)
return [where(less(x, zeros), -ones * grad, ones * grad)]
......@@ -53,6 +53,7 @@ def test_unary_op():
(tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
(tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
(tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
(tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x)))]:
check_single_op(opfunc, ref)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment