Commit 5ed251a6 by SWu, committed by Jared Roesch

[Relay] Add grads (#3857)

* Add gradient implementations

* Add docstrings to fix lint errors
parent 360d26dd
@@ -17,16 +17,25 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple

from ..expr import Tuple, TupleGetItem, const
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
from .transform import (
    broadcast_to_like,
    collapse_sum_like,
    reshape,
    reshape_like,
    strided_slice,
    tile,
    transpose,
    where,
)
@register_gradient("log")
@@ -250,3 +259,59 @@ def conv2d_grad(orig, grad):
end=[None, None, filter_h, filter_w])
return [backward_data, backward_weight]
@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
"""Gradient of softmax"""
return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig]
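# A hedged NumPy sketch (not part of the commit) of the softmax rule above: for
# s = softmax(x) along `axis`, the vector-Jacobian product with an upstream
# gradient g is (g - sum(g * s, axis, keepdims=True)) * s. The finite-difference
# check below is illustrative only; shapes and tolerances are assumptions.
import numpy as np

def _softmax_np(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def _softmax_vjp(x, g, axis=-1):
    s = _softmax_np(x, axis)
    return (g - (g * s).sum(axis=axis, keepdims=True)) * s

_x, _g = np.random.randn(1, 16), np.random.randn(1, 16)
_num = np.zeros_like(_x)
for _i in np.ndindex(_x.shape):
    _d = np.zeros_like(_x)
    _d[_i] = 1e-6
    _num[_i] = ((_softmax_np(_x + _d) - _softmax_np(_x - _d)) * _g).sum() / 2e-6
np.testing.assert_allclose(_softmax_vjp(_x, _g), _num, rtol=1e-4, atol=1e-6)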
@register_gradient("nn.bias_add")
def bias_grad(orig, grad):
"""Returns grad"""
data, bias = orig.args
return [collapse_sum_like(grad, data),
collapse_sum_like(grad, bias)]
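# A hedged NumPy sketch (not part of the commit): collapse_sum_like(grad, bias)
# amounts to summing grad over every axis the bias does not have, so the bias
# gradient of y = data + bias (broadcast along the channel axis) is the upstream
# gradient reduced to bias's shape, while the data gradient passes through
# unchanged. The NCHW-style shapes below are illustrative assumptions.
import numpy as np

_grad = np.random.randn(2, 16, 4, 4)      # upstream gradient, same shape as data
_dbias = _grad.sum(axis=(0, 2, 3))        # reduce to the (16,) bias shape
_ddata = _grad                            # identity for the data input
assert _dbias.shape == (16,) and _ddata.shape == _grad.shape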
@register_gradient("nn.dense")
def dense_grad(orig, grad):
"""Returns [grad' @ weight, data @ grad']"""
data, weight = orig.args
return [collapse_sum_like(transpose(grad) * weight, data),
collapse_sum_like(data * transpose(grad), weight)]
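# A hedged NumPy sketch (not part of the commit): nn.dense computes
# y = data @ weight.T with data (batch, units_in) and weight (units_out, units_in),
# so the usual gradients are d_data = grad @ weight and d_weight = grad.T @ data.
# For batch size 1 (the shapes the new tests exercise), the broadcast-and-collapse
# form used above reduces to the same matrices:
import numpy as np

_data, _weight, _grad = np.random.randn(1, 8), np.random.randn(16, 8), np.random.randn(1, 16)
_d_data = (_grad.T * _weight).sum(axis=0, keepdims=True)   # collapse_sum_like(..., data)
_d_weight = _data * _grad.T                                # already weight-shaped for batch 1
np.testing.assert_allclose(_d_data, _grad @ _weight)
np.testing.assert_allclose(_d_weight, _grad.T @ _data)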
@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims"""
data = orig.args[0]
return [reshape_like(grad, data)]
@register_gradient("transpose")
def transpose_grad(orig, grad):
"""Returns grad transposed over the complement of original transpose axes"""
orig_axes = orig.attrs.axes
if orig_axes:
dims = len(orig_axes)
new_axes = [0] * dims
for i in range(dims):
new_axes[int(orig_axes[i])] = i
else:
new_axes = None
return [transpose(grad, axes=new_axes)]
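# A hedged NumPy sketch (not part of the commit): the loop above builds the
# inverse permutation of `axes` (equivalently np.argsort(axes)); transposing the
# upstream gradient by that inverse permutation undoes the forward transpose.
import numpy as np

_axes = (0, 2, 3, 1)
_inv = [0] * len(_axes)
for _i, _a in enumerate(_axes):
    _inv[_a] = _i
assert _inv == list(np.argsort(_axes))                      # [0, 3, 1, 2]
_x = np.random.randn(1, 2, 3, 4)
np.testing.assert_array_equal(np.transpose(np.transpose(_x, _axes), _inv), _x)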
@register_gradient("negative")
def negative_grad(orig, grad):
"""Returns -grad"""
return [-grad]
@register_gradient("sum")
def sum_grad(orig, grad):
"""Returns grad broadcasted to data dims"""
data = orig.args[0]
return [broadcast_to_like(grad, data)]
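# A hedged NumPy sketch (not part of the commit): the gradient of a sum reduction
# spreads the upstream gradient back over the reduced elements, which is what
# broadcast_to_like does. With keepdims=True the reduced axes stay as size-1
# dimensions, so ordinary broadcasting lines the shapes up.
import numpy as np

_x = np.random.randn(4, 2)
_grad = np.random.randn(1, 2)              # upstream grad of _x.sum(axis=0, keepdims=True)
_dx = np.broadcast_to(_grad, _x.shape)     # every summed element receives the same value
assert _dx.shape == _x.shape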
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient
def sigmoid(x):
    one = np.ones_like(x)
@@ -30,6 +32,7 @@ def relu(x):
    np.maximum(x_copy, 0, x_copy)
    return x_copy


def test_unary_op():
    def check_single_op(opfunc, ref):
        shape = (10, 4)
@@ -93,6 +96,20 @@ def test_binary_op():
        check_binary_op(opfunc, ref)


def test_softmax_grad():
    data = relay.var("data", relay.TensorType((1, 16), "float64"))
    fwd_func = relay.Function([data], relay.nn.softmax(data))
    check_grad(fwd_func)


def test_bias_add_grad():
    data = relay.var("data", relay.TensorType((1, 16), "float32"))
    bias = relay.var("bias", relay.TensorType((16,), "float32"))
    fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias))
    check_grad(fwd_func)
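# A hedged sketch (not part of the commit) of the idea behind check_grad: it runs
# the Relay gradient pass over fwd_func and compares the analytic gradients against
# finite differences on random inputs. The helper below is an illustrative stand-in,
# not tvm.relay.testing.check_grad itself; names, epsilon and tolerances are assumptions.
def _numerical_grad(f, x, eps=1e-4):
    """Central-difference gradient of a scalar-valued f at x."""
    g = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        d = np.zeros_like(x)
        d[i] = eps
        g[i] = (f(x + d) - f(x - d)) / (2 * eps)
    return g

# e.g. validating the bias_add rule by hand: d/d(bias) of sum(data + bias) is all ones
_data, _bias = np.random.randn(1, 16), np.random.randn(16)
_num = _numerical_grad(lambda b: (_data + b).sum(), _bias)
np.testing.assert_allclose(_num, np.ones(16), rtol=1e-3, atol=1e-4)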


if __name__ == "__main__":
    test_unary_op()
    test_binary_op()
    test_bias_add_grad()
@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import topi
import topi.testing
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient
def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
@@ -129,7 +129,32 @@ def test_conv2d_grad():
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')


def verify_dense_grad(d_shape, w_shape):
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    weight = relay.var("weight", relay.TensorType(w_shape, "float32"))
    fwd_func = relay.Function([data, weight], relay.nn.dense(data, weight))
    check_grad(fwd_func)


def test_dense_grad():
    verify_dense_grad((1, 8), (16, 8))
    verify_dense_grad((1, 4), (3, 4))


def verify_batch_flatten_grad(d_shape):
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
    check_grad(fwd_func)


def test_batch_flatten_grad():
    verify_batch_flatten_grad((1, 2, 3, 4))
    verify_batch_flatten_grad((1, 8))


if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
    test_conv2d_grad()
    test_dense_grad()
    test_batch_flatten_grad()
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient
def test_clip():
    ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
@@ -38,5 +40,24 @@
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def verify_transpose_grad(d_shape, axes=None):
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.transpose(data, axes=axes))
    check_grad(fwd_func)


def test_transpose_grad():
    verify_transpose_grad((1, 2, 3, 4))
    verify_transpose_grad((1, 2, 3, 4), axes=(0, 2, 3, 1))


def test_negative_grad():
    data = relay.var("data", relay.TensorType((10, 4), "float32"))
    fwd_func = relay.Function([data], relay.negative(data))
    check_grad(fwd_func)


if __name__ == "__main__":
    test_clip()
    test_transpose_grad()
    test_negative_grad()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing import check_grad

def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude))
    check_grad(fwd_func)


def test_sum_grad():
    verify_sum_grad((4, 2))
    verify_sum_grad((4, 2), axis=-1, keepdims=True)
    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
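# A hedged NumPy sketch (not part of the commit) of the exclude=True case above:
# relay.sum(data, axis=(1, 2), exclude=True) reduces over every axis NOT listed,
# i.e. axis 0 for a (4, 2, 1) input, so the backward pass broadcasts the (2, 1)
# upstream gradient back to (4, 2, 1). Shapes mirror the test; the numpy usage
# is illustrative only.
import numpy as np

_data = np.random.randn(4, 2, 1)
_fwd = _data.sum(axis=0)                       # exclude=True over (1, 2) == reduce axis 0
_grad = np.random.randn(*_fwd.shape)           # upstream gradient, shape (2, 1)
_dx = np.broadcast_to(_grad, _data.shape)      # gradient of the excluded-axis sum
assert _dx.shape == (4, 2, 1)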


if __name__ == "__main__":
    test_sum_grad()