Commit d24c7eed by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Ensure nested higher-order functions are treated correctly (#2676)

parent b5f46c42
...@@ -394,9 +394,7 @@ class Prelude: ...@@ -394,9 +394,7 @@ class Prelude:
f = Var("f", FuncType([a], a)) f = Var("f", FuncType([a], a))
x = Var("x", self.nat()) x = Var("x", self.nat())
y = Var("y", self.nat()) y = Var("y", self.nat())
z = Var("z") z_case = Clause(PatternConstructor(self.z), self.id)
z_case = Clause(PatternConstructor(self.z), Function([z], z))
# todo: fix typechecker so Function([z], z) can be replaced by self.id
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.compose(f, self.iterate(f, y))) self.compose(f, self.iterate(f, y)))
self.mod[self.iterate] = Function([f, x], self.mod[self.iterate] = Function([f, x],
......
...@@ -121,7 +121,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -121,7 +121,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
// TODO(tqchen, jroesch): propagate span to solver // TODO(tqchen, jroesch): propagate span to solver
try { try {
return solver_.Unify(t1, t2, expr); // instantiate higher-order func types when unifying because
// we only allow polymorphism at the top level
Type first = t1;
Type second = t2;
if (auto* ft1 = t1.as<FuncTypeNode>()) {
first = InstantiateFuncType(ft1);
}
if (auto* ft2 = t2.as<FuncTypeNode>()) {
second = InstantiateFuncType(ft2);
}
return solver_.Unify(first, second, expr);
} catch (const dmlc::Error &e) { } catch (const dmlc::Error &e) {
this->ReportFatalError( this->ReportFatalError(
expr, expr,
...@@ -351,6 +361,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -351,6 +361,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return Downcast<FuncType>(inst_ty); return Downcast<FuncType>(inst_ty);
} }
// instantiates starting from incompletes
FuncType InstantiateFuncType(const FuncTypeNode* fn_ty) {
if (fn_ty->type_params.size() == 0) {
return GetRef<FuncType>(fn_ty);
}
Array<Type> type_args;
for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
type_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
return InstantiateFuncType(fn_ty, type_args);
}
void AddTypeArgs(const Expr& expr, Array<Type> type_args) { void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr); auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) { if (type_info == type_map_.end()) {
...@@ -464,6 +488,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -464,6 +488,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
arg_types.push_back(GetType(param)); arg_types.push_back(GetType(param));
} }
Type rtype = GetType(f->body); Type rtype = GetType(f->body);
if (auto* ft = rtype.as<FuncTypeNode>()) {
rtype = InstantiateFuncType(ft);
}
if (f->ret_type.defined()) { if (f->ret_type.defined()) {
rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f)); rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
} }
......
...@@ -133,6 +133,58 @@ def test_incomplete_call(): ...@@ -133,6 +133,58 @@ def test_incomplete_call():
assert ft.checked_type == relay.FuncType([tt, f_type], tt) assert ft.checked_type == relay.FuncType([tt, f_type], tt)
def test_higher_order_argument():
a = relay.TypeVar('a')
x = relay.Var('x', a)
id_func = relay.Function([x], x, a, [a])
b = relay.TypeVar('b')
f = relay.Var('f', relay.FuncType([b], b))
y = relay.Var('y', b)
ho_func = relay.Function([f, y], f(y), b, [b])
# id func should be an acceptable argument to the higher-order
# function even though id_func takes a type parameter
ho_call = ho_func(id_func, relay.const(0, 'int32'))
hc = relay.ir_pass.infer_type(ho_call)
expected = relay.scalar_type('int32')
assert hc.checked_type == expected
def test_higher_order_return():
a = relay.TypeVar('a')
x = relay.Var('x', a)
id_func = relay.Function([x], x, a, [a])
b = relay.TypeVar('b')
nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
ft = relay.ir_pass.infer_type(nested_id)
assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
def test_higher_order_nested():
a = relay.TypeVar('a')
x = relay.Var('x', a)
id_func = relay.Function([x], x, a, [a])
choice_t = relay.FuncType([], relay.scalar_type('bool'))
f = relay.Var('f', choice_t)
b = relay.TypeVar('b')
z = relay.Var('z')
top = relay.Function(
[f],
relay.If(f(), id_func, relay.Function([z], z)),
relay.FuncType([b], b),
[b])
expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
ft = relay.ir_pass.infer_type(top)
assert ft.checked_type == expected
def test_tuple(): def test_tuple():
tp = relay.TensorType((10,)) tp = relay.TensorType((10,))
x = relay.var("x", tp) x = relay.var("x", tp)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment