Commit 4fbb7c89 by Wei Chen Committed by Tianqi Chen

[RELAY] Add occurs check before unification (#2012)

parent a0c813b2
...@@ -61,6 +61,11 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) { ...@@ -61,6 +61,11 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
// - handle shape pattern matching // - handle shape pattern matching
TypeNode* lhs = GetTypeNode(dst); TypeNode* lhs = GetTypeNode(dst);
TypeNode* rhs = GetTypeNode(src); TypeNode* rhs = GetTypeNode(src);
// do occur check so we don't create self-referencing structure
if (lhs->FindRoot() == rhs->FindRoot()) {
return lhs->resolved_type;
}
if (lhs->resolved_type.as<IncompleteTypeNode>()) { if (lhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(lhs, rhs); MergeFromTo(lhs, rhs);
return rhs->resolved_type; return rhs->resolved_type;
......
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
TEST(Relay, SelfReference) {
using namespace tvm;
auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType);
auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType);
auto x = relay::VarNode::make("x", type_a);
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{}));
CHECK_EQ(type_fx->checked_type(), type_a);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
...@@ -107,6 +107,22 @@ def test_type_args(): ...@@ -107,6 +107,22 @@ def test_type_args():
assert sh2[0].value == 1 assert sh2[0].value == 1
assert sh2[1].value == 10 assert sh2[1].value == 10
def test_self_reference():
"""
Program:
def f(x) {
return x;
}
"""
a = relay.TypeVar("a")
x = relay.var("x", a)
sb = relay.ScopeBuilder()
f = relay.Function([x], x)
fx = relay.Call(f, [x])
assert relay.ir_pass.infer_type(x).checked_type == a
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a
if __name__ == "__main__": if __name__ == "__main__":
test_free_expr() test_free_expr()
test_dual_op() test_dual_op()
...@@ -117,3 +133,4 @@ if __name__ == "__main__": ...@@ -117,3 +133,4 @@ if __name__ == "__main__":
test_tuple() test_tuple()
test_free_expr() test_free_expr()
test_type_args() test_type_args()
test_self_reference()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment