Commit 7cc9240a by lixiaoquan Committed by Tianqi Chen

[Relay] Add foldr1 (#2928)

parent 162eab44
......@@ -142,6 +142,29 @@ class Prelude:
self.mod[self.foldr] = Function([f, bv, av],
Match(av, [nil_case, cons_case]), b, [a, b])
def define_list_foldr1(self):
"""Defines a right-way fold over a nonempty list.
foldr1(f, l) : fn<a>(fn(a, a) -> a, list[a]) -> a
foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil)))))
evalutes to f(a1, f(a2, f(..., f(an-1, an)))...)
"""
self.foldr1 = GlobalVar("foldr1")
a = TypeVar("a")
f = Var("f", FuncType([a, a], a))
av = Var("av", self.l(a))
x = Var("x")
y = Var("y")
z = Var("z")
one_case = Clause(PatternConstructor(self.cons,
[PatternVar(x), PatternConstructor(self.nil)]), x)
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]),
f(y, self.foldr1(f, z)))
self.mod[self.foldr1] = Function([f, av],
Match(av, [one_case, cons_case]), a, [a])
def define_list_concat(self):
"""Defines a function that concatenates two lists.
......@@ -471,6 +494,7 @@ class Prelude:
self.define_list_map()
self.define_list_foldl()
self.define_list_foldr()
self.define_list_foldr1()
self.define_list_concat()
self.define_list_filter()
self.define_list_zip()
......
......@@ -144,7 +144,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
for (const auto& con : op->constructors) {
if (!con->belong_to.same_as(op->header)) {
ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to
<< " but " << op << "has header " << op->header));
<< " but " << op << " has header " << op->header));
}
for (const Type& t : con->inputs) {
......
......@@ -31,6 +31,7 @@ length = p.length
map = p.map
foldl = p.foldl
foldr = p.foldr
foldr1 = p.foldr1
sum = p.sum
concat = p.concat
......@@ -228,6 +229,23 @@ def test_foldr():
assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3
def test_foldr1():
a = relay.TypeVar("a")
lhs = mod[p.foldr1].checked_type
rhs = relay.FuncType([relay.FuncType([a, a], a), l(a)], a, [a])
assert lhs == rhs
x = relay.Var("x")
y = relay.Var("y")
f = relay.Function([x, y], add(x, y))
res = intrp.evaluate(foldr1(f,
cons(build_nat(1),
cons(build_nat(2),
cons(build_nat(3), nil())))))
assert count(res) == 6
def test_sum():
assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())
res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil()))))
......@@ -647,6 +665,7 @@ if __name__ == "__main__":
test_map()
test_foldl()
test_foldr()
test_foldr1()
test_concat()
test_filter()
test_zip()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment