Commit 2d5a0720 by Wei Chen Committed by Tianqi Chen

[Relay] Add list update to prelude (#2866)

parent 46924406
...@@ -62,6 +62,25 @@ class Prelude: ...@@ -62,6 +62,25 @@ class Prelude:
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y)) s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
def define_list_update(self):
"""Defines a function to update the nth element of a list and return the updated list.
update(l, i, v) : list[a] -> nat -> a -> list[a]
"""
self.update = GlobalVar("update")
a = TypeVar("a")
l = Var("l", self.l(a))
n = Var("n", self.nat())
v = Var("v", a)
y = Var("y")
z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.cons(self.hd(l), self.update(self.tl(l), y, v)))
self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
def define_list_map(self): def define_list_map(self):
"""Defines a function for mapping a function over a list's """Defines a function for mapping a function over a list's
elements. That is, map(f, l) returns a new list where elements. That is, map(f, l) returns a new list where
...@@ -470,6 +489,7 @@ class Prelude: ...@@ -470,6 +489,7 @@ class Prelude:
self.define_nat_add() self.define_nat_add()
self.define_list_length() self.define_list_length()
self.define_list_nth() self.define_list_nth()
self.define_list_update()
self.define_list_sum() self.define_list_sum()
self.define_tree_adt() self.define_tree_adt()
......
...@@ -26,6 +26,7 @@ l = p.l ...@@ -26,6 +26,7 @@ l = p.l
hd = p.hd hd = p.hd
tl = p.tl tl = p.tl
nth = p.nth nth = p.nth
update = p.update
length = p.length length = p.length
map = p.map map = p.map
foldl = p.foldl foldl = p.foldl
...@@ -148,6 +149,23 @@ def test_nth(): ...@@ -148,6 +149,23 @@ def test_nth():
assert got == expected assert got == expected
def test_update():
expected = list(range(10))
l = nil()
# create zero initialized list
for i in range(len(expected)):
l = cons(build_nat(0), l)
# set value
for i, v in enumerate(expected):
l = update(l, build_nat(i), build_nat(v))
got = []
for i in range(len(expected)):
got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
assert got == expected
def test_length(): def test_length():
a = relay.TypeVar("a") a = relay.TypeVar("a")
assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment