Commit 65016b65 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Alpha equality tests for Relay exprs (#1871)

parent 1eedc945
...@@ -112,7 +112,7 @@ class Call(Expr): ...@@ -112,7 +112,7 @@ class Call(Expr):
class Let(Expr): class Let(Expr):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details.""" """A variable bindings in Relay, see tvm/relay/expr.h for more details."""
def __init__(self, var, value, body, value_type): def __init__(self, var, value, body, value_type=None):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Let, var, value, body, value_type) _make.Let, var, value, body, value_type)
......
...@@ -268,10 +268,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -268,10 +268,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return; return;
} }
if (func1->type_params.size() != func2->type_params.size()) {
equal = false;
return;
}
for (size_t i = 0U; i < func1->params.size(); i++) { for (size_t i = 0U; i < func1->params.size(); i++) {
this->VisitExpr(func1->params[i], func2->params[i]); this->VisitExpr(func1->params[i], func2->params[i]);
} }
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
if (!equal) {
return;
}
}
equal = equal && AlphaEqual(func1->ret_type, func2->ret_type);
if (!equal) {
return;
}
this->VisitExpr(func1->body, func2->body); this->VisitExpr(func1->body, func2->body);
} else { } else {
equal = false; equal = false;
...@@ -287,10 +304,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -287,10 +304,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
return; return;
} }
if (op->type_args.size() != call->type_args.size()) {
equal = false;
return;
}
// checking attrs by pointer equality for now
equal = equal && (op->attrs == call->attrs);
if (!equal) {
return;
}
for (size_t i = 0U; i < op->args.size(); i++) { for (size_t i = 0U; i < op->args.size(); i++) {
this->VisitExpr(op->args[i], call->args[i]); this->VisitExpr(op->args[i], call->args[i]);
} }
for (size_t i = 0U; i < op->type_args.size(); i++) {
equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
if (!equal) {
return;
}
}
} else { } else {
equal = false; equal = false;
} }
...@@ -301,6 +335,16 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -301,6 +335,16 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
eq_map.Set(op->var, let->var); eq_map.Set(op->var, let->var);
this->VisitExpr(op->value, let->value); this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body); this->VisitExpr(op->body, let->body);
// value_type should match as well (including nulls)
if (op->value_type.defined() != let->value_type.defined()) {
equal = false;
return;
}
if (op->value_type.defined()) {
equal = equal && AlphaEqual(op->value_type, let->value_type);
}
} else { } else {
equal = false; equal = false;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment