Commit 65016b65 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Alpha equality tests for Relay exprs (#1871)

parent 1eedc945
...@@ -112,7 +112,7 @@ class Call(Expr): ...@@ -112,7 +112,7 @@ class Call(Expr):
class Let(Expr): class Let(Expr):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details.""" """A variable bindings in Relay, see tvm/relay/expr.h for more details."""
def __init__(self, var, value, body, value_type): def __init__(self, var, value, body, value_type=None):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Let, var, value, body, value_type) _make.Let, var, value, body, value_type)
......
...@@ -268,10 +268,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -268,10 +268,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return; return;
} }
if (func1->type_params.size() != func2->type_params.size()) {
equal = false;
return;
}
for (size_t i = 0U; i < func1->params.size(); i++) { for (size_t i = 0U; i < func1->params.size(); i++) {
this->VisitExpr(func1->params[i], func2->params[i]); this->VisitExpr(func1->params[i], func2->params[i]);
} }
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
if (!equal) {
return;
}
}
equal = equal && AlphaEqual(func1->ret_type, func2->ret_type);
if (!equal) {
return;
}
this->VisitExpr(func1->body, func2->body); this->VisitExpr(func1->body, func2->body);
} else { } else {
equal = false; equal = false;
...@@ -287,10 +304,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -287,10 +304,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return; return;
} }
if (op->type_args.size() != call->type_args.size()) {
equal = false;
return;
}
// checking attrs by pointer equality for now
equal = equal && (op->attrs == call->attrs);
if (!equal) {
return;
}
for (size_t i = 0U; i < op->args.size(); i++) { for (size_t i = 0U; i < op->args.size(); i++) {
this->VisitExpr(op->args[i], call->args[i]); this->VisitExpr(op->args[i], call->args[i]);
} }
for (size_t i = 0U; i < op->type_args.size(); i++) {
equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
if (!equal) {
return;
}
}
} else { } else {
equal = false; equal = false;
} }
...@@ -301,6 +335,16 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -301,6 +335,16 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
eq_map.Set(op->var, let->var); eq_map.Set(op->var, let->var);
this->VisitExpr(op->value, let->value); this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body); this->VisitExpr(op->body, let->body);
// value_type should match as well (including nulls)
if (op->value_type.defined() != let->value_type.defined()) {
equal = false;
return;
}
if (op->value_type.defined()) {
equal = equal && AlphaEqual(op->value_type, let->value_type);
}
} else { } else {
equal = false; equal = false;
} }
......
...@@ -14,12 +14,6 @@ def test_tensor_type_alpha_equal(): ...@@ -14,12 +14,6 @@ def test_tensor_type_alpha_equal():
t2 = relay.TensorType((), "float32") t2 = relay.TensorType((), "float32")
assert t1 == t2 assert t1 == t2
def test_constant_alpha_equal():
x = convert(1)
y = convert(2)
assert alpha_equal(x, x)
assert not alpha_equal(x, y)
assert alpha_equal(x, convert(1))
def test_incomplete_type_alpha_equal(): def test_incomplete_type_alpha_equal():
t1 = relay.IncompleteType(relay.Kind.Shape) t1 = relay.IncompleteType(relay.Kind.Shape)
...@@ -167,6 +161,79 @@ def test_type_relation_alpha_equal(): ...@@ -167,6 +161,79 @@ def test_type_relation_alpha_equal():
assert bigger != diff_num_inputs assert bigger != diff_num_inputs
def test_constant_alpha_equal():
x = convert(1)
y = convert(2)
assert alpha_equal(x, x)
assert not alpha_equal(x, y)
assert alpha_equal(x, convert(1))
def test_var_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# normally only pointer equality
assert alpha_equal(v1, v1)
assert not alpha_equal(v1, v2)
# let node allows for setting the eq_map
l1 = relay.Let(v1, convert(1), v1, None)
l2 = relay.Let(v2, convert(1), v2, None)
l3 = relay.Let(v1, convert(1), v2, None)
assert alpha_equal(l1, l2)
assert not alpha_equal(l1, l3)
def test_global_var_alpha_equal():
v1 = relay.GlobalVar("v1")
v2 = relay.GlobalVar("v2")
# only pointer equality suffices (smoke test)
assert alpha_equal(v1, v1)
assert not alpha_equal(v1, v2)
def test_tuple_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# unit value is a valid tuple
assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
tup = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])])
same = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])])
assert alpha_equal(tup, same)
# use the eq_map
let_tup = relay.Let(v1, tup, v1, None)
let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
relay.Tuple([convert(4)])]),
v2, None)
assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
assert not alpha_equal(tup, more_fields)
fewer_fields = relay.Tuple([v1, convert(2), convert(3)])
assert not alpha_equal(tup, fewer_fields)
different_end = relay.Tuple([v1, convert(2), convert(3),
relay.Tuple([convert(5)])])
assert not alpha_equal(tup, different_end)
different_start = relay.Tuple([v2, convert(2), convert(3),
relay.Tuple([convert(4)])])
assert not alpha_equal(tup, different_start)
longer_at_end = relay.Tuple([v1, convert(2), convert(3),
relay.Tuple([convert(4), convert(5)])])
assert not alpha_equal(tup, longer_at_end)
def test_tuple_get_item_alpha_equal(): def test_tuple_get_item_alpha_equal():
x = relay.Var('x') x = relay.Var('x')
y = relay.Var('y') y = relay.Var('y')
...@@ -174,6 +241,198 @@ def test_tuple_get_item_alpha_equal(): ...@@ -174,6 +241,198 @@ def test_tuple_get_item_alpha_equal():
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2)) assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_param_alpha_equal():
# only checks equality of the types
v1 = relay.Var("v1")
v2 = relay.Var("v2")
p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32"))
p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32"))
assert alpha_equal(p1, p2)
p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
assert not alpha_equal(p1, p3)
p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3),
"float32")]))
assert not alpha_equal(p1, p4)
def test_function_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
v3 = relay.Var("v3")
v4 = relay.Var("v4")
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
tt3 = relay.TupleType([tt1, tt2])
tp1 = relay.TypeParam("tp1", relay.Kind.Type)
tp2 = relay.TypeParam("tp2", relay.Kind.Type)
tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)]
basic_tps = [tp1, tp2]
func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)],
tt2, v2, basic_tps)
mapped = relay.Function(basic_args, tt2, v4, basic_tps)
assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2),
relay.Param(v2, tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, more_params)
params_unordered = relay.Function([relay.Param(v3, tt2),
relay.Param(v4, tt1)],
tt1, v3, basic_tps)
assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([relay.Param(v3, tt3),
relay.Param(v4, tt2)],
tt2, v4, basic_tps)
assert not alpha_equal(func, params_mismatch)
# also would not typecheck
ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps)
assert not alpha_equal(func, ret_type_mismatch)
# also mis-typed
different_body = relay.Function(basic_args, tt2, v3, basic_tps)
assert not alpha_equal(func, different_body)
fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1])
assert not alpha_equal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3])
assert not alpha_equal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1])
assert not alpha_equal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4])
assert not alpha_equal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4]))
assert not alpha_equal(func, tupled_example)
def test_call_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# attrs are compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((), "int8")
basic_args = [convert(1), convert(2), v2, relay.Tuple([])]
# manually writing out args to ensure that args does not rely on
# pointer equality
call = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([])],
attr1, [tt1])
same = relay.Call(v1, basic_args, attr1, [tt1])
assert alpha_equal(call, same)
different_fn = relay.Call(v2, basic_args, attr1, [tt1])
assert not alpha_equal(call, different_fn)
fewer_args = relay.Call(v1, [convert(1), convert(2), v2], attr1, [tt1])
assert not alpha_equal(call, fewer_args)
reordered_args = relay.Call(v1, [convert(2), convert(1),
relay.Tuple([]), v2], attr1, [tt1])
assert not alpha_equal(call, reordered_args)
different_args = relay.Call(v1, [convert(1), convert(2), convert(3)],
attr1, [tt1])
assert not alpha_equal(call, different_args)
more_args = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([]),
convert(3), convert(4)], attr1, [tt1])
assert not alpha_equal(call, more_args)
different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
assert not alpha_equal(call, different_attrs)
no_type_args = relay.Call(v1, basic_args, attr1)
assert not alpha_equal(call, no_type_args)
more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
assert not alpha_equal(call, more_type_args)
different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
assert not alpha_equal(call, different_type_arg)
def test_let_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
v3 = relay.Var("v3")
let = relay.Let(v1, convert(2), v1)
mapped = relay.Let(v2, convert(2), v2)
assert alpha_equal(let, mapped)
mismatched_var = relay.Let(v2, convert(2), v3)
assert not alpha_equal(let, mismatched_var)
different_value = relay.Let(v2, convert(3), v2)
assert not alpha_equal(let, different_value)
different_body = relay.Let(v2, convert(3), convert(12))
assert not alpha_equal(let, different_body)
# specified types must match
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8")
let_with_type = relay.Let(v1, convert(2), v1, tt1)
same_type = relay.Let(v1, convert(2), v1, tt1)
assert alpha_equal(let_with_type, same_type)
assert not alpha_equal(let, let_with_type)
different_type = relay.Let(v1, convert(2), v1, tt2)
assert not alpha_equal(let_with_type, different_type)
def test_if_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
if_sample = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)]))
same = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)]))
assert alpha_equal(if_sample, same)
different_cond = relay.If(v2, convert(1), relay.Tuple([convert(2), convert(3)]))
assert not alpha_equal(if_sample, different_cond)
different_true = relay.If(v1, convert(2), relay.Tuple([convert(2), convert(3)]))
assert not alpha_equal(if_sample, different_true)
different_false = relay.If(v1, convert(1), relay.Tuple([]))
assert not alpha_equal(if_sample, different_false)
def test_op_alpha_equal():
# only checks names
op1 = relay.op.get("add")
op2 = relay.op.get("add")
assert alpha_equal(op1, op2)
op3 = relay.op.get("take")
assert not alpha_equal(op1, op3)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
...@@ -182,4 +441,14 @@ if __name__ == "__main__": ...@@ -182,4 +441,14 @@ if __name__ == "__main__":
test_func_type_alpha_equal() test_func_type_alpha_equal()
test_tuple_type_alpha_equal() test_tuple_type_alpha_equal()
test_type_relation_alpha_equal() test_type_relation_alpha_equal()
test_constant_alpha_equal()
test_var_alpha_equal()
test_global_var_alpha_equal()
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal() test_tuple_get_item_alpha_equal()
test_param_alpha_equal()
test_function_alpha_equal()
test_call_alpha_equal()
test_let_alpha_equal()
test_if_alpha_equal()
test_op_alpha_equal()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment