Commit 0f4c151f by 雾雨魔理沙, committed by Thierry Moreau

[Relay][Training] Add gradient for max. (#3915)

* save

* save
parent 83d2418a
@@ -25,7 +25,7 @@ from ..expr import Tuple, TupleGetItem, const
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
-from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
+from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
from .transform import (
    broadcast_to_like,
    collapse_sum_like,
@@ -269,6 +269,18 @@ def conv2d_grad(orig, grad):
    return [backward_data, backward_weight]

+@register_gradient("max")
+def max_grad(orig, grad):
+    """Returns the gradient of max"""
+    # Only axis=0 is supported, since broadcasting orig back to x behaves
+    # incorrectly for other axes
+    x, axis = orig.args[0], orig.attrs.axis
+    assert axis is not None and len(axis) == 1 and int(axis[0]) == 0
+    orig = broadcast_to_like(orig, x)
+    grad = broadcast_to_like(grad, x)
+    indicators = cast_like(equal(orig, x), grad)
+    return [indicators * grad]
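For intuition, here is a minimal NumPy sketch (not part of the commit; names are illustrative) of the indicator trick used above: the reduced max is broadcast back to the input's shape, an elementwise equality picks out the positions that attain the max, and the upstream gradient flows only through those positions. It also suggests why the assert restricts to axis=0: reducing the leading axis leaves a shape that trailing-dimension broadcasting restores correctly, which need not hold for other axes.

import numpy as np

x = np.array([[1.0, 5.0],
              [3.0, 5.0],
              [2.0, 2.0]])
y = x.max(axis=0)                                # forward result, shape (2,)
grad_y = np.ones_like(y)                         # upstream gradient w.r.t. y

orig_b = np.broadcast_to(y, x.shape)             # mirrors broadcast_to_like(orig, x)
grad_b = np.broadcast_to(grad_y, x.shape)        # mirrors broadcast_to_like(grad, x)
indicators = (orig_b == x).astype(grad_b.dtype)  # mirrors cast_like(equal(orig, x), grad)
grad_x = indicators * grad_b                     # gradient w.r.t. x
# grad_x == [[0., 1.],
#            [1., 1.],
#            [0., 0.]]
# Note the tie in column 1: every entry attaining the max receives the full
# upstream gradient, so tied maxima do not split it between them.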
@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
"""Gradient of softmax"""
......
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
from tvm import relay
from tvm.relay.testing import check_grad
@@ -30,6 +31,16 @@ def test_sum_grad():
    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)

+def test_max_grad():
+    s = (5, 10)
+    t = relay.TensorType(s)
+    x = relay.var("x", t)
+    axis = 0
+    z = relay.max(x, axis)
+    fwd_func = relay.Function([x], z)
+    check_grad(fwd_func, eps=1e-7, rtol=1)
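check_grad compares the analytic gradient registered above against numeric finite differences; the small eps and the loose rtol=1 are presumably needed because max is only piecewise differentiable, so the numeric estimate is noisy near ties. Below is a self-contained sketch of that style of check (illustrative only; numeric_grad is a hypothetical helper, not TVM API):

import numpy as np

def numeric_grad(f, x, eps=1e-7):
    # Central-difference estimate of d(sum(f(x)))/dx, one element at a time.
    g = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + eps
        plus = f(x).sum()
        x[idx] = old - eps
        minus = f(x).sum()
        x[idx] = old
        g[idx] = (plus - minus) / (2 * eps)
        it.iternext()
    return g

rng = np.random.default_rng(0)
x = rng.uniform(1.0, 2.0, size=(5, 10))  # continuous values: ties almost surely absent
approx = numeric_grad(lambda a: a.max(axis=0), x)
# Analytic gradient of sum(max(x, axis=0)): indicator of the per-column maxima.
exact = (x == x.max(axis=0, keepdims=True)).astype(x.dtype)
assert np.allclose(approx, exact, rtol=1, atol=1e-3)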
if __name__ == "__main__":
-    test_sum_grad()
+    pytest.main()