Commit 1a00cab9 by 雾雨魔理沙, committed by Wuwei Lin

[Relay] Add some checks for the AD algorithm (#3585)

* do

* fix test

parent 313bc9de
@@ -95,22 +95,37 @@ def divide_grad(orig, grad):
            collapse_sum_like(- (grad * orig / y), y)]

@register_gradient("zeros")
def zeros_grad(orig, grad):
    """Returns []"""
    return []

@register_gradient("ones")
def ones_grad(orig, grad):
    """Returns []"""
    return []

@register_gradient("zeros_like")
def zeros_like_grad(orig, grad):
    """Returns [0]"""
    return [orig]

@register_gradient("ones_like")
def ones_like_grad(orig, grad):
    """Returns [0]"""
    return [zeros_like(orig.args[0])]

@register_gradient("collapse_sum_like")
def collapse_sum_like_grad(orig, grad):
    """Returns [broadcast_to_like(grad, x), 0]"""
    x, y = orig.args
    return [broadcast_to_like(grad, x), zeros_like(y)]

@register_gradient("abs")
def abs_grad(orig, grad):
    """Returns grad * (select(x < 0, -1, 1))."""

@@ -119,6 +134,7 @@ def abs_grad(orig, grad):
    ones = ones_like(x)
    return [where(less(x, zeros), -ones * grad, ones * grad)]

@register_gradient("clip")
def clip_grad(orig, grad):
    """Returns grad * (select(x < min || max < x, 0, 1))."""
......
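Each function registered above follows the same contract: it receives the original call (orig) and the incoming adjoint (grad), and returns one adjoint expression per input of the call, in order. The sketch below exercises the abs registration end to end; the shape and dtype are assumed for illustration, not taken from this commit.

from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((3, 4), "float32")          # assumed shape/dtype
x = relay.var("x", t)
func = run_infer_type(relay.Function([x], relay.abs(x)))
back = run_infer_type(gradient(func))
# abs_grad supplies the adjoint where(x < 0, -grad, grad), so the
# backward function returns (abs(x), (d_abs/dx,)) with matching types:
assert back.checked_type == relay.FuncType(
    [t], relay.TupleType([t, relay.TupleType([t])]))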
@@ -333,6 +333,9 @@ Expr Gradient(const Expr& re, const Module& mod) {
  auto f = e.as<FunctionNode>();
  CHECK(f) << "input need to be a function";
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
  for (const auto& p : f->params) {
    CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
  }
  Expr body = LetList::With([&](LetList* ll) {
    Var bp = ll->Push(BPEmpty());
    Expr rev = ReverseAD(bp)(e);
......
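The new loop makes Gradient fail fast: a function whose parameter is anything but a tensor is rejected up front instead of failing somewhere inside reverse-mode AD. A minimal sketch of what the check now catches, assuming the failed CHECK surfaces in Python as tvm.TVMError:

import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((3,), "float32")
x = relay.var("x", relay.TupleType([t, t]))      # a tuple, not a tensor
func = run_infer_type(relay.Function([x], relay.TupleGetItem(x, 0)))
try:
    gradient(func)                               # hits the new CHECK
except tvm.TVMError as err:
    print(err)                                   # "input parameters need to be tensor"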
@@ -21,6 +21,7 @@ from tvm.relay.analysis import detect_feature
from tvm.relay.transform import gradient
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import run_infer_type

def test_prelude():
    p = Prelude()

@@ -47,6 +48,7 @@ def test_ad():
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    func = relay.Function([x], x + x)
    func = run_infer_type(func)
    mod = relay.Module.from_expr(gradient(func))
    mod = relay.transform.InferType()(mod)
    back_func = mod["main"]
......
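This is the recurring change across every call site below: gradient now inspects checked_type on the parameters (the C++ hunk above), and checked_type is only populated by type inference, so each test passes its function through run_infer_type before differentiating. The minimal pattern, with shape and dtype assumed:

from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((10, 10), "float32")        # assumed shape/dtype
x = relay.var("x", t)
func = relay.Function([x], x + x)
func = run_infer_type(func)                      # populates checked_type
back_func = run_infer_type(gradient(func))       # forward value plus adjoints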
@@ -18,14 +18,7 @@ import numpy as np
import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list

def run_infer_type(expr):
    mod = relay.Module.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    return mod["main"]

from tvm.relay.testing import ctx_list, run_infer_type

def sigmoid(x):
    one = np.ones_like(x)

@@ -49,6 +42,7 @@ def test_unary_op():
    data = np.random.rand(*shape).astype(dtype)
    ref_grad = ref(data)
    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))
    for target, ctx in ctx_list():

@@ -81,6 +75,7 @@ def test_binary_op():
    y_data = np.random.rand(*s).astype(t.dtype)
    ref_grad0, ref_grad1 = ref(x_data, y_data)
    fwd_func = relay.Function([x, y], z)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))
    for target, ctx in ctx_list():
......
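Both op-level gradient test files had carried a private copy of run_infer_type; the commit deletes those and imports the shared helper from tvm.relay.testing instead. Judging from the removed lines above, the helper is just the module round-trip:

from tvm import relay

def run_infer_type(expr):
    """Type-check expr by round-tripping it through a fresh module."""
    mod = relay.Module.from_expr(expr)       # wrap expr as mod["main"]
    mod = relay.transform.InferType()(mod)   # run the InferType pass
    return mod["main"]                       # the function, now typed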
@@ -18,13 +18,7 @@ import numpy as np
import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list

def run_infer_type(expr):
    mod = relay.Module.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    return mod["main"]

from tvm.relay.testing import ctx_list, run_infer_type

def test_clip():
    ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),

@@ -35,6 +29,7 @@ def test_clip():
    data = np.random.rand(10, 4).astype("float32") * 11.0
    ref_grad = ref(data)
    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))
    for target, ctx in ctx_list():
......
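The clip reference gradient (truncated above) mirrors the clip_grad rule from the first hunk, grad * select(x < min || max < x, 0, 1): the upstream gradient passes through inside the clip interval and is zeroed outside it. An illustrative NumPy version, with bounds of 1.0 and 10.0 assumed (the test data above is scaled into [0, 11)):

import numpy as np

def clip_ref_grad(x, lo=1.0, hi=10.0):           # bounds assumed
    # elementwise: 0 outside [lo, hi], 1 inside
    return np.where((x < lo) | (x > hi), np.zeros_like(x), np.ones_like(x))

data = np.random.rand(10, 4).astype("float32") * 11.0
print(clip_ref_grad(data))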
@@ -35,6 +35,7 @@ def test_id():
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    func = relay.Function([x], x)
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func, mode="first_order"))
    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
    ex = create_executor()

@@ -50,6 +51,7 @@ def test_add():
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    func = relay.Function([x], x + x)
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func))
    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
    ex = create_executor()

@@ -66,6 +68,7 @@ def test_temp_add():
    x = relay.var("x", t)
    y = x + x
    func = relay.Function([x], y + y)
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func))
    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
    ex = create_executor()

@@ -81,6 +84,7 @@ def test_sub():
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    func = relay.Function([x], x - x)
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func))
    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
    ex = create_executor()

@@ -104,6 +108,7 @@ def test_broadcast_add():
    x = relay.var("x", t1)
    y = relay.var("y", t2)
    func = relay.Function([x, y], x + y)
    func = run_infer_type(func)
    full_func = run_infer_type(gradient(func))
    assert full_func.checked_type == relay.FuncType([t1, t2],
            relay.TupleType([relay.TensorType(expected_forward.shape, dtype),

@@ -131,6 +136,7 @@ def test_broadcast_subtract():
    x = relay.var("x", t1)
    y = relay.var("y", t2)
    func = relay.Function([x, y], x - y)
    func = run_infer_type(func)
    full_func = run_infer_type(gradient(func))
    assert full_func.checked_type == relay.FuncType([t1, t2],
            relay.TupleType([relay.TensorType(expected_forward.shape, dtype),

@@ -156,6 +162,7 @@ def test_tuple():
                           relay.TupleGetItem(tup, 0) +
                           relay.TupleGetItem(tup, 1) -
                           relay.TupleGetItem(tup, 2)))
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func))
    assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
    x_nd = rand(dtype, *shape)

@@ -184,8 +191,8 @@ def test_pow():
    double = relay.Function([x], x + x)
    i = relay.var("i", t)
    func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
    func = gradient(func, mod=mod)
    mod["main"] = func
    mod["main"] = gradient(mod["main"], mod=mod)
    m = transform.InferType()(mod)
    back_func = m["main"]
    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
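test_pow is the one case where run_infer_type alone does not suffice: the function calls Prelude definitions (nat_iterate), so it can only be type-checked against the module that contains them. The rewritten lines therefore install the function as mod["main"] first, which infers its types inside the module, and then differentiate that typed copy. Reconstructed as a standalone sketch, with the tensor shape and dtype assumed:

from tvm import relay
from tvm.relay import transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr
from tvm.relay.transform import gradient

p = Prelude()
add_nat_definitions(p)                           # brings in nat, nat_iterate, ...
mod = p.mod
t = relay.TensorType((10, 10), "float32")        # assumed shape/dtype
x = relay.var("x", t)
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
mod["main"] = func                               # type-checks func against mod
mod["main"] = gradient(mod["main"], mod=mod)     # differentiate the typed copy
back_func = transform.InferType()(mod)["main"]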
@@ -207,6 +214,7 @@ def test_ref():
    body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body)
    body = relay.Let(r, relay.RefCreate(x), body)
    func = relay.Function([x], body)
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func))
    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
    x_nd = rand(dtype, *shape)

@@ -222,6 +230,7 @@ def test_square_second_order():
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    func = relay.Function([x], x * x)
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func))
    y = relay.var("y", t)
    back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0))

@@ -242,6 +251,7 @@ def test_if():
    net = relay.If(cond, x, y)
    net = relay.log(net)
    func = relay.Function(free_vars(net), net)
    func = run_infer_type(func)
    net = run_infer_type(func)
    net = gradient(net, mode='higher_order')
    net = run_infer_type(net)
......
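The hunks above exercise both entry points of the transform: test_id requests mode="first_order" while the rest use the default higher-order transform, and after this commit both paths require a type-checked input. A small sketch contrasting the two (shape, dtype, and the use of relay.log are illustrative):

from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((10, 4), "float32")
x = relay.var("x", t)
func = run_infer_type(relay.Function([x], relay.log(x)))
fo = run_infer_type(gradient(func, mode="first_order"))
ho = run_infer_type(gradient(func, mode="higher_order"))
# both lower to fn(Tensor) -> (Tensor, (Tensor,)):
assert fo.checked_type == ho.checked_type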
@@ -25,7 +25,7 @@ from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead,
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call
from tvm.relay.transform import gradient
from tvm.relay.testing import add_nat_definitions, make_nat_expr
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type

def check_eval(expr, expected_result, mod=None, rtol=1e-07):
    ctx = tvm.context("llvm", 0)

@@ -54,7 +54,7 @@ def dcpe(expr, mod=None, grad=False):
    passes = [transform.PartialEvaluate(),
              transform.DeadCodeElimination(inline_once=True)]
    if grad:
        expr = gradient(expr)
        expr = gradient(run_infer_type(expr))
    if mod:
        assert isinstance(expr, Function)
        mod["main"] = expr
......
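dcpe now differentiates only after type inference, then partially evaluates and dead-code-eliminates the result. The same chain can be applied directly to a module, as sketched here with an illustrative x * x function:

from tvm import relay
from tvm.relay import transform
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

x = relay.var("x", relay.TensorType((3,), "float32"))
func = run_infer_type(relay.Function([x], x * x))
mod = relay.Module.from_expr(gradient(func))     # differentiate, then wrap
mod = transform.PartialEvaluate()(mod)           # specialize away AD scaffolding
mod = transform.DeadCodeElimination(inline_once=True)(mod)
print(mod["main"])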
@@ -81,6 +81,7 @@ def test_cps_pe():
    destroy_ref(F)

    G = relay.Function([cond], relay.If(cond, one, two))
    G = run_infer_type(G)
    G = relay.transform.gradient(G)
    destroy_ref(G)

@@ -91,6 +92,7 @@ def test_cps_pe():
    H = relay.If(cond, x, y)
    H = relay.add(H, z)
    H = relay.Function([cond,x,y,z], H)
    H = run_infer_type(H)
    H = relay.transform.gradient(H)
    destroy_ref(H)
......