Commit 0555a03f by ziheng Committed by Tianqi Chen

[RELAY/OP] Gradient of relay level1 ops (#2633)

parent 76812dea
...@@ -51,6 +51,9 @@ class Expr(RelayNode): ...@@ -51,6 +51,9 @@ class Expr(RelayNode):
""" """
return _make.cast(self, dtype) return _make.cast(self, dtype)
def __neg__(self):
return _op_make.negative(self)
def __add__(self, other): def __add__(self, other):
if isinstance(other, Expr): if isinstance(other, Expr):
return _op_make.add(self, other) return _op_make.add(self, other)
......
...@@ -18,6 +18,7 @@ from . import op_attrs ...@@ -18,6 +18,7 @@ from . import op_attrs
# operator registry # operator registry
from . import _tensor from . import _tensor
from . import _tensor_grad
from . import _transform from . import _transform
from . import _reduce from . import _reduce
from ..expr import Expr from ..expr import Expr
......
...@@ -3,25 +3,7 @@ ...@@ -3,25 +3,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import topi import topi
from .op import register_compute, register_schedule, register_pattern from .op import register_compute, register_schedule, register_pattern
from .op import register_gradient
from .op import schedule_injective, OpPattern from .op import schedule_injective, OpPattern
from .transform import collapse_sum_like
from .tensor import negative
def add_grad(orig, grad):
return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])]
register_gradient("add", add_grad)
def subtract_grad(orig, grad):
return [collapse_sum_like(grad, orig.args[0]),
collapse_sum_like(negative(grad), orig.args[1])]
register_gradient("subtract", subtract_grad)
schedule_broadcast = schedule_injective schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective schedule_elemwise = schedule_injective
......
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..expr import const
from .op import register_gradient
from .transform import collapse_sum_like, where
from .tensor import exp, negative, power, less
from .tensor import zeros_like, ones_like
@register_gradient("log")
def log_grad(orig, grad):
"""Returns [grad * (1 / x)]"""
x = orig.args[0]
return [grad * ones_like(x) / x]
@register_gradient("exp")
def exp_grad(orig, grad):
"""Returns [grad * exp(x)]"""
return [grad * exp(orig.args[0])]
@register_gradient("sqrt")
def sqrt_grad(orig, grad):
"""Returns [grad * 0.5 * (x ^ -0.5)]"""
a = const(0.5) # (TODO) type?
return [grad * a * power(orig.args[0], negative(a))]
@register_gradient("sigmoid")
def sigmoid_grad(orig, grad):
"""Returns [grad * sigmoid(x) * (1 - sigmoid(x))]."""
return [grad * orig * (ones_like(orig) - orig)]
@register_gradient("tanh")
def tanh_grad(orig, grad):
"""Returns grad * (1 - tanh(x) * tanh(x))."""
return [grad * ones_like(orig) - orig * orig]
@register_gradient("nn.relu")
def relu_grad(orig, grad):
"""Returns grad * (select(x < 0, 0, 1))."""
x = orig.args[0]
zeros = zeros_like(x)
ones = ones_like(x)
return [where(less(x, zeros), zeros, ones * grad)]
@register_gradient("add")
def add_grad(orig, grad):
"""Returns [grad, grad]"""
return [collapse_sum_like(grad, orig.args[0]),
collapse_sum_like(grad, orig.args[1])]
@register_gradient("subtract")
def subtract_grad(orig, grad):
"""Returns [grad, -grad]"""
return [collapse_sum_like(grad, orig.args[0]),
collapse_sum_like(negative(grad), orig.args[1])]
@register_gradient("multiply")
def multiply_grad(orig, grad):
"""Returns [grad * y, grad * x]"""
x, y = orig.args
return [collapse_sum_like(grad * y, x),
collapse_sum_like(grad * x, y)]
@register_gradient("divide")
def divide_grad(orig, grad):
"""Returns [grad / y, - grad * (x / y) / y]"""
x, y = orig.args
return [collapse_sum_like(grad / y, x),
collapse_sum_like(- (grad * orig / y), y)]
...@@ -168,7 +168,7 @@ def register_pattern(op_name, pattern, level=10): ...@@ -168,7 +168,7 @@ def register_pattern(op_name, pattern, level=10):
""" """
return register(op_name, "TOpPattern", pattern, level) return register(op_name, "TOpPattern", pattern, level)
def register_gradient(op_name, fgradient, level=10): def register_gradient(op_name, fgradient=None, level=10):
"""Register operator pattern for an op. """Register operator pattern for an op.
Parameters Parameters
......
import tvm
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type
from tvm.relay.testing import ctx_list
def sigmoid(x):
one = np.ones_like(x)
return one / (one + np.exp(-x))
def relu(x):
x_copy = np.copy(x)
np.maximum(x_copy, 0, x_copy)
return x_copy
def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
dtype = 'float32'
tp = relay.TensorType(shape, dtype)
x = relay.var("x", tp)
y = opfunc(x)
if ref is not None:
data = np.random.rand(*shape).astype(dtype)
ref_grad = ref(data)
fwd_func = relay.Function([x], y)
bwd_func = infer_type(gradient(fwd_func))
for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
for opfunc, ref in [(tvm.relay.log, lambda x: 1 / x),
(tvm.relay.exp, np.exp),
(tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
(tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
(tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x)))]:
check_single_op(opfunc, ref)
def test_binary_op():
def inst(vars, sh):
return [vars.get(s, s) for s in sh]
def check_binary_op(opfunc, ref):
s = (5, 10, 5)
t = relay.TensorType((5, 10, 5))
x = relay.var("x", t)
y = relay.var("y", t)
z = opfunc(x, y)
x_data = np.random.rand(*s).astype(t.dtype)
y_data = np.random.rand(*s).astype(t.dtype)
ref_grad0, ref_grad1 = ref(x_data, y_data)
fwd_func = relay.Function([x, y], z)
bwd_func = infer_type(gradient(fwd_func))
for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data)
np.testing.assert_allclose(op_grad0.asnumpy(), ref_grad0, rtol=0.01)
np.testing.assert_allclose(op_grad1.asnumpy(), ref_grad1, rtol=0.01)
for opfunc, ref in [(relay.add, lambda x, y: [np.ones_like(x), np.ones_like(y)]),
(relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]),
(relay.multiply, lambda x, y: [y, x]),
(relay.divide, lambda x, y: [1 / y, - x / (y**2)])]:
check_binary_op(opfunc, ref)
if __name__ == "__main__":
test_unary_op()
test_binary_op()
...@@ -39,11 +39,11 @@ def test_unary_op(): ...@@ -39,11 +39,11 @@ def test_unary_op():
for opfunc, ref in [(tvm.relay.log, np.log), for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp), (tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt), (tvm.relay.sqrt, np.sqrt),
(tvm.relay.sigmoid, sigmoid), (tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh), (tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]: (relay.nn.relu, relu)]:
check_single_op(opfunc, ref) check_single_op(opfunc, ref)
...@@ -84,9 +84,9 @@ def test_binary_op(): ...@@ -84,9 +84,9 @@ def test_binary_op():
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
for opfunc, ref in [(relay.add, np.add), for opfunc, ref in [(relay.add, np.add),
(relay.subtract, np.subtract), (relay.subtract, np.subtract),
(relay.multiply, np.multiply), (relay.multiply, np.multiply),
(relay.divide, np.divide)]: (relay.divide, np.divide)]:
check_binary_op(opfunc, ref) check_binary_op(opfunc, ref)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment