Commit 46fa6eeb by Altan Haan, committed by Wuwei Lin

[Relay][Training] Add and fix gradients (#4126)

* add and fix gradients

* fix linter issues
parent 1c0e7435
@@ -48,6 +48,9 @@ from .transform import (
    tile,
    transpose,
    where,
    repeat,
    expand_dims,
    full_like
)
@@ -198,6 +201,7 @@ def clip_grad(orig, grad):
@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
    """Returns the gradient of max_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):
@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
    """Returns the gradient of avg_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
    return [pool_grad]
@register_gradient("nn.global_avg_pool2d")
def global_avg_pool2d_grad(orig, grad):
"""Returns the gradient of global_avg_pool2d."""
data = orig.args[0]
shape = data.checked_type.shape
layout = orig.attrs.layout
# we assume NCHW or NHWC layout for now, but easy to add more
assert layout in ["NCHW", "NHWC"]
if layout == "NCHW":
pool_size = shape[2], shape[3]
elif layout == "NHWC":
pool_size = shape[1], shape[2]
pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
strides=(1, 1), padding=(0, 0),
layout=layout)
return [pool_grad]
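The rule this reuses is straightforward: global average pooling is just avg_pool2d with the entire spatial extent as the window, so each input cell should receive out_grad / (H * W). A minimal NumPy sketch of that expectation, with illustrative shapes that are not taken from the patch:

import numpy as np

# NCHW input with a 4x4 spatial extent (hypothetical values)
x = np.random.rand(1, 2, 4, 4).astype("float32")
out_grad = np.ones((1, 2, 1, 1), dtype="float32")
# every spatial position gets an equal share of the output gradient
expected_input_grad = np.broadcast_to(out_grad / (4 * 4), x.shape)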
# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
    return [backward_data, backward_weight]
def _get_reduce_axis(call):
    """Helper function that returns the reduce axis of the call as plain python ints."""
    x, axis = call.args[0], call.attrs.axis
    shape = x.checked_type.concrete_shape
    # should never exclude when axis is None
    assert not (axis is None and call.attrs.exclude)
    if axis is None:
        return None
    # convert to nonnegative integers and sort
    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
    if call.attrs.exclude:
        axis = [ax for ax in range(len(shape)) if ax not in axis]
    return axis
def _unreduce_expand(x, axis):
    """Helper function that returns x expanded on the reduced dimensions in axis."""
    # assume axis is sorted nonnegative ints
    for ax in axis:
        x = expand_dims(x, ax)
    return x
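For intuition, these two helpers normalize the reduce axes to sorted nonnegative ints (resolving negative axes and the exclude flag) and then re-insert the reduced dimensions so a reduced gradient can be repeated or broadcast back to the input shape. A small pure-Python mimic of the axis normalization, written only for illustration (the helper name below is hypothetical):

def _normalize_axes(rank, axis, exclude=False):
    # hypothetical stand-in for _get_reduce_axis, operating on plain ints
    axis = sorted(ax if ax >= 0 else rank + ax for ax in axis)
    if exclude:
        axis = [ax for ax in range(rank) if ax not in axis]
    return axis

# axis=(0, -1) on a rank-3 tensor with exclude=True leaves only axis 1 reduced
assert _normalize_axes(3, (0, -1), exclude=True) == [1]
assert _normalize_axes(3, (-1,)) == [2]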
@register_gradient("max") @register_gradient("max")
def max_grad(orig, grad): def max_grad(orig, grad):
"""Returns the gradient of max""" """Returns the gradient of max"""
# Only support axis=0, since broadcasting orig to x behaves incorrectly x, axis = orig.args[0], _get_reduce_axis(orig)
x, axis = orig.args[0], orig.attrs.axis shape = x.checked_type.concrete_shape
assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
orig = broadcast_to_like(orig, x) repeated = orig
grad = broadcast_to_like(grad, x) if axis is None:
indicators = cast_like(equal(orig, x), grad) repeated = full_like(x, repeated)
return [indicators * grad] else:
# expand dims (if necessary) and repeat along each axis
if not orig.attrs.keepdims:
repeated = _unreduce_expand(repeated, axis)
grad = _unreduce_expand(grad, axis)
for ax in axis:
repeated = repeat(repeated, shape[ax], ax)
indicators = cast_like(equal(repeated, x), grad)
num_selected = _sum(indicators, axis, keepdims=True)
# spread error across all max weights
return [indicators * grad / num_selected]
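In words: the gradient of max flows only to the positions that attain the maximum, and when several positions tie, dividing by num_selected splits the incoming gradient evenly among them so the total gradient mass is preserved. A small NumPy sketch of the same rule for a reduction along axis 1 (values are illustrative only):

import numpy as np

x = np.array([[1.0, 3.0, 3.0],
              [2.0, 2.0, 0.0]])
out_grad = np.array([6.0, 4.0])                    # gradient w.r.t. max over axis 1
repeated = x.max(axis=1, keepdims=True)            # max broadcast back over the reduced axis
indicators = (x == repeated).astype(x.dtype)       # 1.0 where x attains the max
num_selected = indicators.sum(axis=1, keepdims=True)
dx = indicators * out_grad[:, None] / num_selected
# dx == [[0., 3., 3.], [2., 2., 0.]]: tied maxima share the gradient equally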
@register_gradient("nn.softmax") @register_gradient("nn.softmax")
...@@ -372,7 +434,11 @@ def negative_grad(orig, grad): ...@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
@register_gradient("sum") @register_gradient("sum")
def sum_grad(orig, grad): def sum_grad(orig, grad):
"""Returns grad broadcasted to data dims""" """Returns grad broadcasted to data dims"""
data = orig.args[0] data, axis = orig.args[0], _get_reduce_axis(orig)
if not orig.attrs.keepdims:
if axis is None:
axis = list(range(len(data.checked_type.concrete_shape)))
grad = _unreduce_expand(grad, axis)
return [broadcast_to_like(grad, data)] return [broadcast_to_like(grad, data)]
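Since every element contributes to a sum with weight one, the gradient of sum is just the output gradient broadcast back over the reduced axes; the _unreduce_expand call restores the dropped dimensions when keepdims is false so broadcast_to_like sees a compatible rank. A NumPy sketch of the expected result (shapes chosen for illustration):

import numpy as np

x = np.random.rand(4, 2, 3).astype("float32")
out_grad = np.ones((4, 3), dtype="float32")        # gradient of x.sum(axis=1)
# re-insert the reduced axis, then broadcast to the input shape
dx = np.broadcast_to(np.expand_dims(out_grad, 1), x.shape)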
@@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
def test_max_pool2d_grad():
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)
@@ -75,7 +74,6 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
def test_avg_pool2d_grad():
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False, count_include_pad=True)
@@ -83,6 +81,30 @@ def test_avg_pool2d_grad():
                           ceil_mode=False, count_include_pad=False)
def verify_global_avg_pool2d_grad(x_shape):
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.global_avg_pool2d(x)
    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))
    data = np.random.rand(*x_shape).astype("float32")
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
                                           strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
                                           ceil_mode=False)
    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)

def test_global_avg_pool2d_grad():
    verify_global_avg_pool2d_grad((1, 4, 16, 16))
    verify_global_avg_pool2d_grad((1, 8, 8, 24))
def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
    try:
        import torch
@@ -155,6 +177,7 @@ def test_batch_flatten_grad():
if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
    test_global_avg_pool2d_grad()
    test_conv2d_grad()
    test_dense_grad()
    test_batch_flatten_grad()
@@ -29,18 +29,21 @@ def test_sum_grad():
    verify_sum_grad((4, 2))
    verify_sum_grad((4, 2), axis=-1, keepdims=True)
    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
    verify_sum_grad((4, 2, 1), axis=1)

def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
    check_grad(fwd_func, scale=1e-3)

def test_max_grad():
    verify_max_grad((10, 10), axis=None)
    verify_max_grad((10, 10), axis=-1)
    verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
    verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)

if __name__ == "__main__":
    pytest.main()