Commit 0f4c151f by 雾雨魔理沙 Committed by Thierry Moreau

[Relay][Training] Add gradient for max. (#3915)

* save

* save
parent 83d2418a
...@@ -25,7 +25,7 @@ from ..expr import Tuple, TupleGetItem, const ...@@ -25,7 +25,7 @@ from ..expr import Tuple, TupleGetItem, const
from . import nn as _nn from . import nn as _nn
from .op import register_gradient from .op import register_gradient
from .reduce import sum as _sum from .reduce import sum as _sum
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
from .transform import ( from .transform import (
broadcast_to_like, broadcast_to_like,
collapse_sum_like, collapse_sum_like,
...@@ -269,6 +269,18 @@ def conv2d_grad(orig, grad): ...@@ -269,6 +269,18 @@ def conv2d_grad(orig, grad):
return [backward_data, backward_weight] return [backward_data, backward_weight]
@register_gradient("max")
def max_grad(orig, grad):
"""Returns the gradient of max"""
# Only support axis=0, since broadcasting orig to x behaves incorrectly
x, axis = orig.args[0], orig.attrs.axis
assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
orig = broadcast_to_like(orig, x)
grad = broadcast_to_like(grad, x)
indicators = cast_like(equal(orig, x), grad)
return [indicators * grad]
@register_gradient("nn.softmax") @register_gradient("nn.softmax")
def softmax_grad(orig, grad): def softmax_grad(orig, grad):
"""Gradient of softmax""" """Gradient of softmax"""
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import pytest
from tvm import relay from tvm import relay
from tvm.relay.testing import check_grad from tvm.relay.testing import check_grad
...@@ -30,6 +31,16 @@ def test_sum_grad(): ...@@ -30,6 +31,16 @@ def test_sum_grad():
verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True) verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
def test_max_grad():
s = (5, 10)
t = relay.TensorType(s)
x = relay.var("x", t)
axis = 0
z = relay.max(x, axis)
fwd_func = relay.Function([x], z)
check_grad(fwd_func, eps=1e-7, rtol=1)
if __name__ == "__main__": if __name__ == "__main__":
test_sum_grad() pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment