Commit 97be70a0 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] add more function to prelude (#2660)

parent 4ba30478
...@@ -250,12 +250,15 @@ class Interpreter(Executor): ...@@ -250,12 +250,15 @@ class Interpreter(Executor):
The optimized expression. The optimized expression.
""" """
# TODO: We need to move this optimization code into the optimizer/pass manager # TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod) wrapped_expr = expr if isinstance(expr, Function) else Function([], expr)
if self.mod:
self.mod[self.mod.entry_func] = wrapped_expr
ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod)
simp_expr = ir_pass.simplify_inference(ck_expr) simp_expr = ir_pass.simplify_inference(ck_expr)
ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod) ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_simp) fused_expr = ir_pass.fuse_ops(ck_simp)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod) ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused return ck_fused if isinstance(expr, Function) else Call(ck_fused, [])
def _make_executor(self, expr): def _make_executor(self, expr):
def _interp_wrapper(*args, **kwargs): def _interp_wrapper(*args, **kwargs):
......
...@@ -340,7 +340,10 @@ class Prelude: ...@@ -340,7 +340,10 @@ class Prelude:
Match(t, [rose_case]), self.tree(b), [a, b]) Match(t, [rose_case]), self.tree(b), [a, b])
def define_tree_size(self): def define_tree_size(self):
"""Defines a function that computes the size of a tree as a nat.""" """Defines a function that computes the size of a tree as a nat.
Signature: fn<a>(t : tree[a]) -> nat
"""
self.size = GlobalVar("size") self.size = GlobalVar("size")
a = TypeVar("a") a = TypeVar("a")
t = Var("t", self.tree(a)) t = Var("t", self.tree(a))
...@@ -351,6 +354,56 @@ class Prelude: ...@@ -351,6 +354,56 @@ class Prelude:
self.mod[self.size] = Function([t], self.mod[self.size] = Function([t],
Match(t, [rose_case]), self.nat(), [a]) Match(t, [rose_case]), self.nat(), [a])
def define_id(self):
"""Defines a function that return it's argument.
Signature: fn<a>(x : a) -> a
"""
self.id = GlobalVar("id")
a = TypeVar("a")
x = Var("x", a)
self.mod[self.id] = Function([x], x, a, [a])
def define_compose(self):
"""Defines a function that compose two function.
Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
"""
self.compose = GlobalVar("compose")
a = TypeVar("a")
b = TypeVar("b")
c = TypeVar("c")
f = Var("f", FuncType([b], c))
g = Var("g", FuncType([a], b))
x = Var("x")
self.mod[self.compose] = Function([f, g],
Function([x], f(g(x))),
FuncType([a], c),
[a, b, c])
def define_iterate(self):
"""Define a function that take a number n, a function f,
and return a closure that apply f n time on it's argument.
Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
"""
self.iterate = GlobalVar("iterate")
a = TypeVar("a")
f = Var("f", FuncType([a], a))
x = Var("x", self.nat())
y = Var("y", self.nat())
z = Var("z")
z_case = Clause(PatternConstructor(self.z), Function([z], z))
# todo: fix typechecker so Function([z], z) can be replaced by self.id
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.compose(f, self.iterate(f, y)))
self.mod[self.iterate] = Function([f, x],
Match(x, [z_case, s_case]),
FuncType([a], a),
[a])
def __init__(self, mod): def __init__(self, mod):
self.mod = mod self.mod = mod
self.define_list_adt() self.define_list_adt()
...@@ -377,3 +430,7 @@ class Prelude: ...@@ -377,3 +430,7 @@ class Prelude:
self.define_tree_adt() self.define_tree_adt()
self.define_tree_map() self.define_tree_map()
self.define_tree_size() self.define_tree_size()
self.define_id()
self.define_compose()
self.define_iterate()
...@@ -83,6 +83,7 @@ void ModuleNode::Add(const GlobalVar& var, ...@@ -83,6 +83,7 @@ void ModuleNode::Add(const GlobalVar& var,
CHECK(AlphaEqual(type, old_type)) CHECK(AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode."; << "Module#update changes type, not possible in this mode.";
} }
var->checked_type_ = type;
AddUnchecked(var, checked_func); AddUnchecked(var, checked_func);
} }
......
...@@ -400,11 +400,8 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) { ...@@ -400,11 +400,8 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
} }
void TypeSolver::ReportError(const Error& err, const NodeRef& location) { void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
this->err_reporter_->ReportAt( err_reporter_->ReportAt(current_func, location, err);
this->current_func, }
location,
err);
}
// Add type constraint to the solver. // Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
......
...@@ -43,6 +43,9 @@ rose = p.rose ...@@ -43,6 +43,9 @@ rose = p.rose
tmap = p.tmap tmap = p.tmap
size = p.size size = p.size
compose = p.compose
iterate = p.iterate
# this is an example of using the adt value in python side # this is an example of using the adt value in python side
def count(n): def count(n):
assert isinstance(n, ConstructorValue) assert isinstance(n, ConstructorValue)
...@@ -93,6 +96,7 @@ def tree_to_dict(t): ...@@ -93,6 +96,7 @@ def tree_to_dict(t):
def test_nat_value(): def test_nat_value():
assert count(make_nat(10)) == 10 assert count(make_nat(10)) == 10
assert count(intrp.evaluate(s(s(z())))) == 2
def test_nat_constructor(): def test_nat_constructor():
...@@ -577,6 +581,17 @@ def test_nested_pattern_match(): ...@@ -577,6 +581,17 @@ def test_nested_pattern_match():
assert count(res) == 2 assert count(res) == 2
def test_compose():
n = relay.Var('n')
inc = relay.Function([n], s(n))
x = relay.Var('x')
res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
assert count(res) == 5
def test_iterate():
expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)])
res = intrp.evaluate(relay.Function([], expr)())
assert count(res) == 12
if __name__ == "__main__": if __name__ == "__main__":
test_nat_constructor() test_nat_constructor()
...@@ -598,3 +613,5 @@ if __name__ == "__main__": ...@@ -598,3 +613,5 @@ if __name__ == "__main__":
test_sum() test_sum()
test_tmap() test_tmap()
test_size() test_size()
test_compose()
test_iterate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment