Commit c0a5a9be by Wei Chen Committed by Tianqi Chen

[Relay] Add hd,tl,nth for list in Prelude (#2771)

parent 37414470
......@@ -18,6 +18,50 @@ class Prelude:
self.cons = Constructor("cons", [a, self.l(a)], self.l)
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
def define_list_hd(self):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
hd(l) : list[a] -> a
"""
self.hd = GlobalVar("hd")
a = TypeVar("a")
x = Var("x", self.l(a))
y = Var("y")
z = Var("z")
# Don't match nil() since it will break type checking
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
def define_list_tl(self):
"""Defines a function to get the tail of a list.
tl(l) : list[a] -> list[a]
"""
self.tl = GlobalVar("tl")
a = TypeVar("a")
x = Var("x", self.l(a))
y = Var("y")
z = Var("z")
nil_case = Clause(PatternConstructor(self.nil, []), self.nil())
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
self.mod[self.tl] = Function([x], Match(x, [nil_case, cons_case]), self.l(a), [a])
def define_list_nth(self):
"""Defines a function to get the nth element of a list.
nth(l) : list[a] -> a
"""
self.nth = GlobalVar("nth")
a = TypeVar("a")
x = Var("x", self.l(a))
n = Var("n", self.nat())
y = Var("y")
z_case = Clause(PatternConstructor(self.z), self.hd(x))
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
def define_list_map(self):
"""Defines a function for mapping a function over a list's
elements. That is, map(f, l) returns a new list where
......@@ -405,6 +449,8 @@ class Prelude:
def __init__(self, mod):
self.mod = mod
self.define_list_adt()
self.define_list_hd()
self.define_list_tl()
self.define_list_map()
self.define_list_foldl()
self.define_list_foldr()
......@@ -423,6 +469,7 @@ class Prelude:
self.define_nat_double()
self.define_nat_add()
self.define_list_length()
self.define_list_nth()
self.define_list_sum()
self.define_tree_adt()
......
......@@ -23,6 +23,9 @@ none = p.none
nil = p.nil
cons = p.cons
l = p.l
hd = p.hd
tl = p.tl
nth = p.nth
length = p.length
map = p.map
foldl = p.foldl
......@@ -120,6 +123,30 @@ def test_list_constructor():
a = relay.TypeVar("a")
assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat())
def test_hd_tl():
expected = list(range(10))
l = nil()
for i in reversed(expected):
l = cons(build_nat(i), l)
got = []
for i in range(len(expected)):
got.append(count(intrp.evaluate(hd(l))))
l = tl(l)
assert got == expected
def test_nth():
expected = list(range(10))
l = nil()
for i in reversed(expected):
l = cons(build_nat(i), l)
got = []
for i in range(len(expected)):
got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
assert got == expected
def test_length():
a = relay.TypeVar("a")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment