Commit 97be70a0 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] add more function to prelude (#2660)

parent 4ba30478
......@@ -250,12 +250,15 @@ class Interpreter(Executor):
The optimized expression.
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
wrapped_expr = expr if isinstance(expr, Function) else Function([], expr)
if self.mod:
self.mod[self.mod.entry_func] = wrapped_expr
ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod)
simp_expr = ir_pass.simplify_inference(ck_expr)
ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_simp)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused
return ck_fused if isinstance(expr, Function) else Call(ck_fused, [])
def _make_executor(self, expr):
def _interp_wrapper(*args, **kwargs):
......
......@@ -340,7 +340,10 @@ class Prelude:
Match(t, [rose_case]), self.tree(b), [a, b])
def define_tree_size(self):
"""Defines a function that computes the size of a tree as a nat."""
"""Defines a function that computes the size of a tree as a nat.
Signature: fn<a>(t : tree[a]) -> nat
"""
self.size = GlobalVar("size")
a = TypeVar("a")
t = Var("t", self.tree(a))
......@@ -351,6 +354,56 @@ class Prelude:
self.mod[self.size] = Function([t],
Match(t, [rose_case]), self.nat(), [a])
def define_id(self):
"""Defines a function that return it's argument.
Signature: fn<a>(x : a) -> a
"""
self.id = GlobalVar("id")
a = TypeVar("a")
x = Var("x", a)
self.mod[self.id] = Function([x], x, a, [a])
def define_compose(self):
"""Defines a function that compose two function.
Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
"""
self.compose = GlobalVar("compose")
a = TypeVar("a")
b = TypeVar("b")
c = TypeVar("c")
f = Var("f", FuncType([b], c))
g = Var("g", FuncType([a], b))
x = Var("x")
self.mod[self.compose] = Function([f, g],
Function([x], f(g(x))),
FuncType([a], c),
[a, b, c])
def define_iterate(self):
"""Define a function that take a number n, a function f,
and return a closure that apply f n time on it's argument.
Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
"""
self.iterate = GlobalVar("iterate")
a = TypeVar("a")
f = Var("f", FuncType([a], a))
x = Var("x", self.nat())
y = Var("y", self.nat())
z = Var("z")
z_case = Clause(PatternConstructor(self.z), Function([z], z))
# todo: fix typechecker so Function([z], z) can be replaced by self.id
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.compose(f, self.iterate(f, y)))
self.mod[self.iterate] = Function([f, x],
Match(x, [z_case, s_case]),
FuncType([a], a),
[a])
def __init__(self, mod):
self.mod = mod
self.define_list_adt()
......@@ -377,3 +430,7 @@ class Prelude:
self.define_tree_adt()
self.define_tree_map()
self.define_tree_size()
self.define_id()
self.define_compose()
self.define_iterate()
......@@ -83,6 +83,7 @@ void ModuleNode::Add(const GlobalVar& var,
CHECK(AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
var->checked_type_ = type;
AddUnchecked(var, checked_func);
}
......
......@@ -400,11 +400,8 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
}
void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
this->err_reporter_->ReportAt(
this->current_func,
location,
err);
}
err_reporter_->ReportAt(current_func, location, err);
}
// Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
......
......@@ -43,6 +43,9 @@ rose = p.rose
tmap = p.tmap
size = p.size
compose = p.compose
iterate = p.iterate
# this is an example of using the adt value in python side
def count(n):
assert isinstance(n, ConstructorValue)
......@@ -93,6 +96,7 @@ def tree_to_dict(t):
def test_nat_value():
assert count(make_nat(10)) == 10
assert count(intrp.evaluate(s(s(z())))) == 2
def test_nat_constructor():
......@@ -577,6 +581,17 @@ def test_nested_pattern_match():
assert count(res) == 2
def test_compose():
n = relay.Var('n')
inc = relay.Function([n], s(n))
x = relay.Var('x')
res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
assert count(res) == 5
def test_iterate():
expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)])
res = intrp.evaluate(relay.Function([], expr)())
assert count(res) == 12
if __name__ == "__main__":
test_nat_constructor()
......@@ -598,3 +613,5 @@ if __name__ == "__main__":
test_sum()
test_tmap()
test_size()
test_compose()
test_iterate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment