Commit 1a00cab9 authored by 雾雨魔理沙, committed by Wuwei Lin

[Relay] add some check for the ad algorithm (#3585)

* do

* fix test
parent 313bc9de
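
In short, the gradient pass now requires its input to be a type-checked function whose parameters are tensors, so every caller runs type inference before calling gradient. A minimal sketch of the resulting calling convention, mirroring the updated test_ad/test_add below (illustrative, not part of the patch itself):

```python
import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

# Build a simple function and type check it before differentiating;
# without run_infer_type the new CHECKs in the gradient pass fire.
t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
func = relay.Function([x], x + x)
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
# The backward function returns the forward value plus one adjoint per input.
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
```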
@@ -95,22 +95,37 @@ def divide_grad(orig, grad):
             collapse_sum_like(- (grad * orig / y), y)]
+@register_gradient("zeros")
+def zeros_grad(orig, grad):
+    """Returns []"""
+    return []
+@register_gradient("ones")
+def ones_grad(orig, grad):
+    """Returns []"""
+    return []
 @register_gradient("zeros_like")
 def zeros_like_grad(orig, grad):
     """Returns [0]"""
     return [orig]
 @register_gradient("ones_like")
 def ones_like_grad(orig, grad):
     """Returns [0]"""
     return [zeros_like(orig.args[0])]
 @register_gradient("collapse_sum_like")
 def collapse_sum_like_grad(orig, grad):
     """Returns [broadcast_to_like(grad, x), 0]"""
     x, y = orig.args
     return [broadcast_to_like(grad, x), zeros_like(y)]
 @register_gradient("abs")
 def abs_grad(orig, grad):
     """Returns grad * (select(x < 0, -1, 1))."""
@@ -119,6 +134,7 @@ def abs_grad(orig, grad):
     ones = ones_like(x)
     return [where(less(x, zeros), -ones * grad, ones * grad)]
 @register_gradient("clip")
 def clip_grad(orig, grad):
     """Returns grad * (select(x < min || max < x , 0, 1))."""
......
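
For context, each function registered with register_gradient receives the original call and the output adjoint and returns one adjoint per input; zeros and ones create constants with no tensor inputs, so their gradients return an empty list. A hedged sketch of exercising the new registration (assumes the higher-order AD handles a literal zeros call end to end on this revision):

```python
import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

shape = (10, 10)
t = relay.TensorType(shape, "float32")
x = relay.var("x", t)
# x + zeros(...) pulls zeros_grad into the backward program; the zeros call
# contributes no adjoints, so only d(out)/dx is returned.
func = relay.Function([x], x + relay.zeros(shape, "float32"))
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
```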
@@ -333,6 +333,9 @@ Expr Gradient(const Expr& re, const Module& mod) {
   auto f = e.as<FunctionNode>();
   CHECK(f) << "input need to be a function";
   CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
+  for (const auto& p : f->params) {
+    CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
+  }
   Expr body = LetList::With([&](LetList* ll) {
     Var bp = ll->Push(BPEmpty());
     Expr rev = ReverseAD(bp)(e);
......
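
The effect of the new check from the Python side: a function whose parameter is not a tensor (a tuple here) is now rejected up front instead of failing later inside the pass. A sketch under the assumption that the failed CHECK surfaces as a TVMError:

```python
import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((3,), "float32")
xs = relay.var("xs", relay.TupleType([t, t]))
func = relay.Function([xs], relay.TupleGetItem(xs, 0) + relay.TupleGetItem(xs, 1))
func = run_infer_type(func)
try:
    gradient(func)
except tvm.TVMError as err:
    # Expected message: "input parameters need to be tensor"
    print(err)
```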
@@ -21,6 +21,7 @@ from tvm.relay.analysis import detect_feature
 from tvm.relay.transform import gradient
 from tvm.relay.feature import Feature
 from tvm.relay.prelude import Prelude
+from tvm.relay.testing import run_infer_type
 def test_prelude():
     p = Prelude()
@@ -47,6 +48,7 @@ def test_ad():
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x + x)
+    func = run_infer_type(func)
     mod = relay.Module.from_expr(gradient(func))
     mod = relay.transform.InferType()(mod)
     back_func = mod["main"]
......
@@ -18,14 +18,7 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.transform import gradient
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = relay.transform.InferType()(mod)
-    return mod["main"]
 def sigmoid(x):
     one = np.ones_like(x)
@@ -49,6 +42,7 @@ def test_unary_op():
         data = np.random.rand(*shape).astype(dtype)
         ref_grad = ref(data)
         fwd_func = relay.Function([x], y)
+        fwd_func = run_infer_type(fwd_func)
         bwd_func = run_infer_type(gradient(fwd_func))
         for target, ctx in ctx_list():
@@ -81,6 +75,7 @@ def test_binary_op():
         y_data = np.random.rand(*s).astype(t.dtype)
         ref_grad0, ref_grad1 = ref(x_data, y_data)
         fwd_func = relay.Function([x, y], z)
+        fwd_func = run_infer_type(fwd_func)
         bwd_func = run_infer_type(gradient(fwd_func))
         for target, ctx in ctx_list():
......
@@ -18,13 +18,7 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.transform import gradient
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, run_infer_type
-def run_infer_type(expr):
-    mod = relay.Module.from_expr(expr)
-    mod = relay.transform.InferType()(mod)
-    return mod["main"]
 def test_clip():
     ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
@@ -35,6 +29,7 @@ def test_clip():
     data = np.random.rand(10, 4).astype("float32") * 11.0
     ref_grad = ref(data)
     fwd_func = relay.Function([x], y)
+    fwd_func = run_infer_type(fwd_func)
     bwd_func = run_infer_type(gradient(fwd_func))
     for target, ctx in ctx_list():
......
@@ -35,6 +35,7 @@ def test_id():
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x)
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func, mode="first_order"))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     ex = create_executor()
@@ -50,6 +51,7 @@ def test_add():
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x + x)
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     ex = create_executor()
@@ -66,6 +68,7 @@ def test_temp_add():
     x = relay.var("x", t)
     y = x + x
     func = relay.Function([x], y + y)
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     ex = create_executor()
@@ -81,6 +84,7 @@ def test_sub():
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x - x)
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     ex = create_executor()
@@ -104,6 +108,7 @@ def test_broadcast_add():
     x = relay.var("x", t1)
     y = relay.var("y", t2)
     func = relay.Function([x, y], x + y)
+    func = run_infer_type(func)
     full_func = run_infer_type(gradient(func))
     assert full_func.checked_type == relay.FuncType([t1, t2],
                                                     relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
@@ -131,6 +136,7 @@ def test_broadcast_subtract():
     x = relay.var("x", t1)
     y = relay.var("y", t2)
     func = relay.Function([x, y], x - y)
+    func = run_infer_type(func)
     full_func = run_infer_type(gradient(func))
     assert full_func.checked_type == relay.FuncType([t1, t2],
                                                     relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
@@ -156,6 +162,7 @@ def test_tuple():
                                               relay.TupleGetItem(tup, 0) +
                                               relay.TupleGetItem(tup, 1) -
                                               relay.TupleGetItem(tup, 2)))
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
     x_nd = rand(dtype, *shape)
@@ -184,8 +191,8 @@ def test_pow():
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
     func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
-    func = gradient(func, mod=mod)
     mod["main"] = func
+    mod["main"] = gradient(mod["main"], mod=mod)
     m = transform.InferType()(mod)
     back_func = m["main"]
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
@@ -207,6 +214,7 @@ def test_ref():
     body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body)
     body = relay.Let(r, relay.RefCreate(x), body)
     func = relay.Function([x], body)
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     x_nd = rand(dtype, *shape)
@@ -222,6 +230,7 @@ def test_square_second_order():
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     func = relay.Function([x], x * x)
+    func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     y = relay.var("y", t)
     back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0))
@@ -242,6 +251,7 @@ def test_if():
     net = relay.If(cond, x, y)
     net = relay.log(net)
     func = relay.Function(free_vars(net), net)
+    func = run_infer_type(func)
     net = run_infer_type(func)
     net = gradient(net, mode='higher_order')
     net = run_infer_type(net)
......
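
One test above (test_pow) also changes how gradient is invoked when the function uses module-level definitions from the Prelude: installing the function in the module first type checks it, so the gradient pass sees fully typed tensor parameters. A short sketch of that pattern, using the same names as the test (illustrative setup, not copied verbatim from the patch):

```python
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr

mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))

# Adding the function to the module runs type inference on it, so the
# gradient pass receives a checked function with tensor parameters.
mod["main"] = func
mod["main"] = gradient(mod["main"], mod=mod)
back_func = transform.InferType()(mod)["main"]
```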
@@ -25,7 +25,7 @@ from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead,
 from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
 from tvm.relay import GlobalVar, Call
 from tvm.relay.transform import gradient
-from tvm.relay.testing import add_nat_definitions, make_nat_expr
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type
 def check_eval(expr, expected_result, mod=None, rtol=1e-07):
     ctx = tvm.context("llvm", 0)
@@ -54,7 +54,7 @@ def dcpe(expr, mod=None, grad=False):
     passes = [transform.PartialEvaluate(),
               transform.DeadCodeElimination(inline_once=True)]
     if grad:
-        expr = gradient(expr)
+        expr = gradient(run_infer_type(expr))
     if mod:
         assert isinstance(expr, Function)
         mod["main"] = expr
......
@@ -81,6 +81,7 @@ def test_cps_pe():
     destroy_ref(F)
     G = relay.Function([cond], relay.If(cond, one, two))
+    G = run_infer_type(G)
     G = relay.transform.gradient(G)
     destroy_ref(G)
@@ -91,6 +92,7 @@ def test_cps_pe():
     H = relay.If(cond, x, y)
     H = relay.add(H, z)
     H = relay.Function([cond,x,y,z], H)
+    H = run_infer_type(H)
     H = relay.transform.gradient(H)
     destroy_ref(H)
......