# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr

# Relay module shared by every test in this file. The Prelude is built on top
# of it, and add_nat_definitions is applied to that prelude before the nat
# helpers (p.z, p.s, p.double, p.add, ...) are read off `p` further down.
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)

def count(e):
    """Run the imported ``count_`` helper on ``e`` using the shared prelude ``p``.

    The tests compare the result against plain Python integers.
    """
    prelude = p
    return count_(prelude, e)

# Executor shared by every test; it is created once over `mod` for the
# "llvm" target on device 0.
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")

# Short module-level aliases for the prelude definitions the tests use.
# NOTE(review): `map`, `sum`, `filter` and `zip` below shadow the Python
# builtins for the rest of this module, so test bodies must not rely on the
# builtin versions of those names.

# natural numbers
z = p.z
s = p.s
nat = p.nat
double = p.double
add = p.add

# optional values
optional = p.optional
some = p.some
none = p.none

# lists: type (`l`), constructors and basic list functions
nil = p.nil
cons = p.cons
l = p.l
hd = p.hd
tl = p.tl
nth = p.nth
update = p.update
length = p.length
map = p.map
foldl = p.foldl
foldr = p.foldr
foldr1 = p.foldr1
sum = p.sum

# further list functions
concat = p.concat
filter = p.filter
zip = p.zip
rev = p.rev
unfoldl = p.unfoldl
unfoldr = p.unfoldr
map_accumr = p.map_accumr
map_accuml = p.map_accuml

# rose trees
tree = p.tree
rose = p.rose
tmap = p.tmap
size = p.size

# function helpers
compose = p.compose
iterate = p.iterate

# This is an example of creating an ADT value on the Python side.
def make_nat(n):
    """Build the ``ConstructorValue`` for the natural number ``n``.

    The result is a ``z`` constructor value wrapped in ``n`` layers of the
    ``s`` constructor. It is assembled iteratively, so a large ``n`` does not
    hit Python's recursion limit.

    Raises:
        AssertionError: if ``n`` is negative (same check as ``make_nat_expr``);
            the previous recursive form never terminated for such input.
    """
    assert n >= 0
    ret = ConstructorValue(z, [], [])
    # Wrap from the inside out: each pass adds one successor layer.
    for _ in range(n):
        ret = ConstructorValue(s, [ret], [])
    return ret

def make_nat_expr(n):
    """Build the Relay expression ``s(s(...z()...))`` with ``n`` uses of ``s``.

    Note: this definition replaces the ``make_nat_expr`` imported from
    ``tvm.relay.testing`` at the top of the file; callers in this module pass
    only ``n``.
    """
    assert n >= 0
    remaining = n
    expr = z()
    # Add one successor per unit still to be represented.
    while remaining > 0:
        expr = s(expr)
        remaining = remaining - 1
    return expr

def to_list(l):
    """Convert a prelude list value into a Python list of its elements.

    Walks the ``cons`` cells from the head, collecting each cell's first
    field, and asserts that the chain ends in a ``nil`` value.
    """
    assert isinstance(l, ConstructorValue)
    items = []
    node = l
    while node.tag == p.cons.tag:
        items.append(node.fields[0])
        node = node.fields[1]
    # Anything other than cons must be the terminating nil.
    assert node.tag == p.nil.tag
    return items

def tree_to_dict(t):
    """Convert a ``rose`` tree value into nested Python dictionaries.

    Each node becomes ``{'member': <node value>, 'children': [<dict>, ...]}``
    where the children are the recursively converted subtrees.
    """
    assert isinstance(t, ConstructorValue)
    assert t.tag == p.rose.tag
    member, subtrees = t.fields[0], t.fields[1]
    return {
        'member': member,
        'children': [tree_to_dict(child) for child in to_list(subtrees)],
    }


# Turns a scalar-valued Relay tensor value into a Python number.
def get_scalar(tv):
    """Return ``tv.asnumpy().item()``, i.e. the single element as a Python scalar."""
    array = tv.asnumpy()
    return array.item()


def test_nat_value():
    """Nat values built in Python and by the interpreter count correctly."""
    ten = make_nat_value(p, 10)
    assert count(ten) == 10

    two = intrp.evaluate(s(s(z())))
    assert count(two) == 2


def test_nat_constructor():
    """Functions returning ``z()`` and ``s(z())`` type-check with body type ``nat``."""
    cases = [
        ("test_z", z()),
        ("test_sz", s(z())),
    ]
    for name, body in cases:
        global_var = relay.GlobalVar(name)
        # Adding the function to the module triggers type inference.
        mod[global_var] = relay.Function([], body)
        assert mod[global_var].body.checked_type == nat()


def test_double():
    """``double`` has type ``nat -> nat`` and doubling one yields two."""
    expected_type = relay.FuncType([nat()], nat())
    assert mod[double].checked_type == expected_type

    one = s(z())
    doubled = intrp.evaluate(double(one))
    assert count(doubled) == 2


def test_add():
    """``add`` has type ``(nat, nat) -> nat`` and one plus one is two."""
    expected_type = relay.FuncType([nat(), nat()], nat())
    assert mod[add].checked_type == expected_type

    total = intrp.evaluate(add(s(z()), s(z())))
    assert count(total) == 2


def test_list_constructor():
    """A function returning ``cons(z(), nil())`` has body type ``list[nat]``."""
    global_var = relay.GlobalVar("test_consz")
    singleton = cons(z(), nil())
    mod[global_var] = relay.Function([], singleton)
    assert mod[global_var].body.checked_type == l(nat())

def test_hd_tl():
    """Repeatedly taking ``hd`` and ``tl`` walks the list front to back."""
    expected = list(range(10))

    # Build the list back to front so that it reads 0, 1, ..., 9.
    remaining = nil()
    for value in expected[::-1]:
        remaining = cons(make_nat_expr(value), remaining)

    got = []
    for _ in expected:
        head = intrp.evaluate(hd(remaining))
        got.append(count(head))
        remaining = tl(remaining)

    assert got == expected

def test_nth():
    """``nth`` returns the element stored at each index of a constant list."""
    expected = list(range(10))

    # Build the list back to front so that index i holds the constant i.
    values = nil()
    for value in expected[::-1]:
        values = cons(relay.const(value), values)

    for index, _ in enumerate(expected):
        element = intrp.evaluate(nth(values, relay.const(index)))
        assert get_scalar(element) == index


def test_update():
    """``update`` overwrites each slot of a zero-initialised list."""
    expected = list(range(10))

    # Start from a list of zeros with one slot per expected value.
    values = nil()
    for _ in expected:
        values = cons(make_nat_expr(0), values)

    # Overwrite slot i with the i-th expected value.
    for index, value in enumerate(expected):
        values = update(values, relay.const(index), make_nat_expr(value))

    got = [
        count(intrp.evaluate(nth(values, relay.const(index))))
        for index in range(len(expected))
    ]

    assert got == expected

def test_length():
    """``length`` has type ``list[a] -> int32`` and counts three elements."""
    elem = relay.TypeVar("a")
    expected_type = relay.FuncType([l(elem)], relay.scalar_type('int32'), [elem])
    assert mod[length].checked_type == expected_type

    three_zeros = cons(z(), cons(z(), cons(z(), nil())))
    result = intrp.evaluate(length(three_zeros))
    assert get_scalar(result) == 3


def test_map():
    """``map`` has the expected polymorphic type and applies ``s`` to each element."""
    src = relay.TypeVar("a")
    dst = relay.TypeVar("b")
    expected_type = relay.FuncType(
        [relay.FuncType([src], dst), l(src)], l(dst), [src, dst])
    assert mod[map].checked_type == expected_type

    arg = relay.Var("x")
    add_one = relay.Function([arg], s(arg))
    two_zeros = cons(z(), cons(z(), nil()))
    mapped = to_list(intrp.evaluate(map(add_one, two_zeros)))

    assert len(mapped) == 2
    assert [count(value) for value in mapped] == [1, 1]


def test_foldl():
    """``foldl`` with a prepend-twice step reverses and duplicates the list."""
    acc_type = relay.TypeVar("a")
    elem_type = relay.TypeVar("b")
    expected_type = relay.FuncType(
        [relay.FuncType([acc_type, elem_type], acc_type), acc_type, l(elem_type)],
        acc_type, [acc_type, elem_type])
    assert mod[foldl].checked_type == expected_type

    elem = relay.Var("x")
    acc = relay.Var("y")
    # The step function prepends each element twice to the accumulator.
    rev_dup = relay.Function([acc, elem], cons(elem, cons(elem, acc)))

    numbers = cons(make_nat_expr(1),
                   cons(make_nat_expr(2),
                        cons(make_nat_expr(3), nil())))
    folded = to_list(intrp.evaluate(foldl(rev_dup, nil(), numbers)))

    assert len(folded) == 6
    assert [count(value) for value in folded] == [3, 3, 2, 2, 1, 1]


def test_foldr():
    """``foldr`` with ``cons`` as the step function rebuilds the same list."""
    elem_type = relay.TypeVar("a")
    acc_type = relay.TypeVar("b")
    expected_type = relay.FuncType(
        [relay.FuncType([elem_type, acc_type], acc_type), acc_type, l(elem_type)],
        acc_type, [elem_type, acc_type])
    assert mod[foldr].checked_type == expected_type

    head = relay.Var("x")
    rest = relay.Var("y")
    identity = relay.Function([head, rest], cons(head, rest))

    numbers = cons(make_nat_expr(1),
                   cons(make_nat_expr(2),
                        cons(make_nat_expr(3), nil())))
    rebuilt = to_list(intrp.evaluate(foldr(identity, nil(), numbers)))

    assert len(rebuilt) == 3
    assert [count(value) for value in rebuilt] == [1, 2, 3]


def test_foldr1():
    """``foldr1`` with ``add`` sums a non-empty list of nats."""
    elem_type = relay.TypeVar("a")
    expected_type = relay.FuncType(
        [relay.FuncType([elem_type, elem_type], elem_type), l(elem_type)],
        elem_type, [elem_type])
    assert mod[p.foldr1].checked_type == expected_type

    lhs = relay.Var("x")
    rhs = relay.Var("y")
    plus = relay.Function([lhs, rhs], add(lhs, rhs))

    numbers = cons(make_nat_expr(1),
                   cons(make_nat_expr(2),
                        cons(make_nat_expr(3), nil())))
    total = intrp.evaluate(foldr1(plus, numbers))

    assert count(total) == 6


def test_sum():
    """``sum`` has type ``list[int32] -> int32`` and adds 1 and 2 to get 3."""
    int32 = relay.scalar_type('int32')
    expected_type = relay.FuncType([l(int32)], relay.scalar_type('int32'))
    assert mod[sum].checked_type == expected_type

    one_two = cons(relay.const(1), cons(relay.const(2), nil()))
    total = intrp.evaluate(sum(one_two))
    assert get_scalar(total) == 3


def test_concat():
    """``concat`` joins two lists, keeping the order of both."""
    elem_type = relay.TypeVar("a")
    expected_type = relay.FuncType(
        [l(elem_type), l(elem_type)], l(elem_type), [elem_type])
    assert mod[concat].checked_type == expected_type

    front = cons(make_nat_expr(1), cons(make_nat_expr(2), nil()))
    back = cons(make_nat_expr(3), cons(make_nat_expr(4), nil()))
    catted = to_list(intrp.evaluate(concat(front, back)))

    assert len(catted) == 4
    for position, expected in enumerate([1, 2, 3, 4]):
        assert count(catted[position]) == expected


def test_filter():
    """``filter`` keeps only the nats that match ``s(s(_))`` (greater than one)."""
    elem_type = relay.TypeVar("a")
    predicate_type = relay.FuncType([elem_type], relay.scalar_type("bool"))
    expected_type = relay.FuncType(
        [predicate_type, l(elem_type)], l(elem_type), [elem_type])
    assert mod[filter].checked_type == expected_type

    x = relay.Var("x", nat())
    # s(s(_)) matches any nat with at least two successors.
    at_least_two = relay.PatternConstructor(
        s, [relay.PatternConstructor(s, [relay.PatternWildcard()])])
    clauses = [
        relay.Clause(at_least_two, relay.const(True)),
        relay.Clause(relay.PatternWildcard(), relay.const(False)),
    ]
    greater_than_one = relay.Function([x], relay.Match(x, clauses))

    # Build the list 1, 1, 3, 1, 5, 1 back to front.
    numbers = nil()
    for value in [1, 5, 1, 3, 1, 1]:
        numbers = cons(make_nat_expr(value), numbers)

    filtered = to_list(intrp.evaluate(filter(greater_than_one, numbers)))
    assert len(filtered) == 2
    assert count(filtered[0]) == 3
    assert count(filtered[1]) == 5


def test_zip():
    """``zip`` pairs elements positionally and truncates to the shorter list."""
    left_type = relay.TypeVar("a")
    right_type = relay.TypeVar("b")
    expected_type = relay.FuncType(
        [l(left_type), l(right_type)],
        l(relay.TupleType([left_type, right_type])),
        [left_type, right_type])
    assert mod[zip].checked_type == expected_type

    def summarize(result):
        # Reduce each (nat, list) pair to (nat as int, length of the list).
        return [(count(pair[0]), len(to_list(pair[1])))
                for pair in to_list(result)]

    l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
    # Lists of length 0, 1 and 2 respectively.
    l2 = cons(nil(),
              cons(cons(nil(), nil()),
                   cons(cons(nil(), cons(nil(), nil())),
                        nil())))

    assert summarize(intrp.evaluate(zip(l1, l2))) == [(1, 0), (2, 1), (3, 2)]

    # test truncation: the left list is the shorter one
    l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil()))
    assert summarize(intrp.evaluate(zip(l3, l2))) == [(4, 0), (5, 1)]

    # test truncation: the right list is the shorter one
    l4 = cons(nil(), nil())
    assert summarize(intrp.evaluate(zip(l3, l4))) == [(4, 0)]


def test_rev():
    """``rev`` has type ``list[a] -> list[a]`` and reverses 1, 2, 3."""
    elem_type = relay.TypeVar("a")
    expected_type = relay.FuncType([l(elem_type)], l(elem_type), [elem_type])
    assert mod[rev].checked_type == expected_type

    numbers = cons(make_nat_expr(1),
                   cons(make_nat_expr(2),
                        cons(make_nat_expr(3), nil())))
    backwards = to_list(intrp.evaluate(rev(numbers)))

    assert len(backwards) == 3
    assert [count(value) for value in backwards] == [3, 2, 1]


def test_unfoldr():
    """``unfoldr`` with a count-down step yields 3, 2, 1."""
    seed_type = relay.TypeVar("a")
    out_type = relay.TypeVar("b")
    step_type = relay.FuncType(
        [seed_type], optional(relay.TupleType([seed_type, out_type])))
    # NOTE(review): this expected type is constructed but never compared
    # against mod[unfoldr].checked_type; confirm whether an assert is missing.
    expected_type = relay.FuncType(
        [step_type, seed_type], l(out_type), [seed_type, out_type])

    x = relay.Var("x", nat())
    n = relay.Var("n", nat())
    # s(n) -> some((n, x)): emit the current value and continue from n.
    successor_clause = relay.Clause(
        relay.PatternConstructor(s, [relay.PatternVar(n)]),
        some(relay.Tuple([n, x])))
    # z -> none: stop unfolding.
    zero_clause = relay.Clause(relay.PatternConstructor(z, []), none())
    count_down = relay.Function(
        [x], relay.Match(x, [successor_clause, zero_clause]))

    unfolded = to_list(intrp.evaluate(unfoldr(count_down, make_nat_expr(3))))

    assert len(unfolded) == 3
    assert [count(value) for value in unfolded] == [3, 2, 1]


def test_unfoldl():
    """``unfoldl`` with a count-down step yields 1, 2, 3."""
    seed_type = relay.TypeVar("a")
    out_type = relay.TypeVar("b")
    step_type = relay.FuncType(
        [seed_type], optional(relay.TupleType([seed_type, out_type])))
    # NOTE(review): this expected type is constructed but never compared
    # against mod[unfoldl].checked_type; confirm whether an assert is missing.
    expected_type = relay.FuncType(
        [step_type, seed_type], l(out_type), [seed_type, out_type])

    x = relay.Var("x", nat())
    n = relay.Var("n", nat())
    # s(n) -> some((n, x)): emit the current value and continue from n.
    successor_clause = relay.Clause(
        relay.PatternConstructor(s, [relay.PatternVar(n)]),
        some(relay.Tuple([n, x])))
    # z -> none: stop unfolding.
    zero_clause = relay.Clause(relay.PatternConstructor(z, []), none())
    count_down = relay.Function(
        [x], relay.Match(x, [successor_clause, zero_clause]))

    unfolded = to_list(intrp.evaluate(unfoldl(count_down, make_nat_expr(3))))

    assert len(unfolded) == 3
    assert [count(value) for value in unfolded] == [1, 2, 3]


def test_map_accumr():
    """``map_accumr`` threads the accumulator through the list from the right."""
    acc_type = relay.TypeVar("a")
    in_type = relay.TypeVar("b")
    out_type = relay.TypeVar("c")
    step_type = relay.FuncType(
        [acc_type, in_type], relay.TupleType([acc_type, out_type]))
    expected_type = relay.FuncType(
        [step_type, acc_type, l(in_type)],
        relay.TupleType([acc_type, l(out_type)]),
        [acc_type, in_type, out_type])
    assert mod[map_accumr].checked_type == expected_type

    acc = relay.Var("acc", nat())
    x = relay.Var("x", nat())
    # Both the new accumulator and the emitted value are x + acc.
    step_result = relay.Tuple([add(x, acc), add(x, acc)])
    add_acc_to_each = relay.Function([acc, x], step_result)

    vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
    res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals))

    total = count(res[0])
    new_vals = to_list(res[1])

    assert total == 6
    assert len(new_vals) == 3
    assert [count(value) for value in new_vals] == [6, 5, 3]


def test_map_accuml():
    """``map_accuml`` threads the accumulator through the list from the left."""
    acc_type = relay.TypeVar("a")
    in_type = relay.TypeVar("b")
    out_type = relay.TypeVar("c")
    step_type = relay.FuncType(
        [acc_type, in_type], relay.TupleType([acc_type, out_type]))
    expected_type = relay.FuncType(
        [step_type, acc_type, l(in_type)],
        relay.TupleType([acc_type, l(out_type)]),
        [acc_type, in_type, out_type])
    assert mod[map_accuml].checked_type == expected_type

    acc = relay.Var("acc", nat())
    x = relay.Var("x", nat())
    # The accumulator grows by x; the emitted value is x itself.
    step_result = relay.Tuple([add(x, acc), x])
    add_to_acc = relay.Function([acc, x], step_result)

    vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
    res = intrp.evaluate(map_accuml(add_to_acc, z(), vals))

    total = count(res[0])
    new_vals = to_list(res[1])

    assert total == 6
    assert len(new_vals) == 3
    assert [count(value) for value in new_vals] == [3, 2, 1]


def test_optional_matching():
    """Folding with a ``some``/``none`` match drops the ``none`` entries."""
    x = relay.Var('x')
    y = relay.Var('y')
    v = relay.Var('v')
    # some(v) prepends v to the accumulator; none leaves it untouched.
    keep_clause = relay.Clause(
        relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y))
    skip_clause = relay.Clause(relay.PatternConstructor(none), y)
    condense = relay.Function(
        [x, y], relay.Match(x, [keep_clause, skip_clause]))

    options = cons(
        some(make_nat_expr(3)),
        cons(none(), cons(some(make_nat_expr(1)), nil())))
    reduced = to_list(intrp.evaluate(foldr(condense, nil(), options)))

    assert len(reduced) == 2
    assert count(reduced[0]) == 3
    assert count(reduced[1]) == 1


def test_tmap():
    """``tmap`` applies ``s`` to every member of a two-level rose tree."""
    src = relay.TypeVar("a")
    dst = relay.TypeVar("b")
    expected_type = relay.FuncType(
        [relay.FuncType([src], dst), tree(src)], tree(dst), [src, dst])
    assert mod[tmap].checked_type == expected_type

    arg = relay.Var("x")
    add_one = relay.Function([arg], s(arg))

    # A root holding zero with two leaf children that also hold zero.
    children = cons(rose(z(), nil()),
                    cons(rose(z(), nil()),
                         nil()))
    mapped = intrp.evaluate(tmap(add_one, rose(z(), children)))

    tree_dict = tree_to_dict(mapped)
    assert count(tree_dict['member']) == 1
    assert len(tree_dict['children']) == 2
    for child in tree_dict['children']:
        assert count(child['member']) == 1
        assert len(child['children']) == 0


def test_size():
    """``size`` has type ``tree[a] -> int32`` and counts all ten nodes."""
    elem_type = relay.TypeVar("a")
    expected_type = relay.FuncType(
        [tree(elem_type)], relay.scalar_type('int32'), [elem_type])
    assert mod[size].checked_type == expected_type

    # Each subtree has a root and two leaves: three nodes.
    leaves = cons(rose(z(), nil()),
                  cons(rose(z(), nil()),
                       nil()))
    root = rose(z(), leaves)
    # Three such subtrees plus the top-level node: 3 * 3 + 1 = 10 nodes.
    t = rose(z(), cons(root, cons(root, cons(root, nil()))))

    node_count = intrp.evaluate(size(t))
    assert get_scalar(node_count) == 10


def test_wildcard_match_solo():
    """A match with only a wildcard clause returns its scrutinee unchanged."""
    x = relay.Var('x', nat())
    only_clause = relay.Clause(relay.PatternWildcard(), x)
    copy = relay.Function([x], relay.Match(x, [only_clause]), nat())

    three = s(s(s(z())))
    copied = intrp.evaluate(copy(three))
    assert count(copied) == 3


def test_wildcard_match_order():
    """A leading wildcard clause wins over the more specific clauses after it."""
    x = relay.Var('x', l(nat()))
    y = relay.Var('y')
    a = relay.Var('a')

    wildcard_clause = relay.Clause(relay.PatternWildcard(), z())
    cons_clause = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(y), relay.PatternVar(a)]),
        y)
    nil_clause = relay.Clause(relay.PatternConstructor(nil), s(z()))

    return_zero = relay.Function(
        [x],
        relay.Match(x, [wildcard_clause, cons_clause, nil_clause]),
        nat())

    result = intrp.evaluate(return_zero(cons(s(z()), nil())))
    # wildcard pattern is evaluated first
    assert count(result) == 0


def test_nested_matches():
    """A recursive ``flatten`` written with nested matches flattens a list of lists."""
    elem_type = relay.TypeVar('a')
    outer = relay.Var('x')
    first = relay.Var('y')
    rest = relay.Var('w')
    head = relay.Var('h')
    tail = relay.Var('t')
    flatten = relay.GlobalVar('flatten')

    # flatten could be written using a fold, but this way has nested matches
    # Inner match on the first sub-list: when it is empty move on to the rest,
    # otherwise emit its head and keep flattening its tail.
    inner_nil = relay.Clause(relay.PatternConstructor(nil), flatten(rest))
    inner_cons = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(head), relay.PatternVar(tail)]),
        cons(head, flatten(cons(tail, rest))))
    inner_match = relay.Match(first, [inner_nil, inner_cons])

    # Outer match on the list of lists.
    outer_nil = relay.Clause(relay.PatternConstructor(nil), nil())
    outer_cons = relay.Clause(
        relay.PatternConstructor(
            cons, [relay.PatternVar(first), relay.PatternVar(rest)]),
        inner_match)

    mod[flatten] = relay.Function(
        [outer],
        relay.Match(outer, [outer_nil, outer_cons]),
        l(elem_type), [elem_type])

    first_list = cons(make_nat_expr(1), cons(make_nat_expr(2),
                                             cons(make_nat_expr(3), nil())))
    second_list = cons(make_nat_expr(4), cons(make_nat_expr(5),
                                              cons(make_nat_expr(6), nil())))
    final_list = cons(first_list, cons(second_list, nil()))

    flat = to_list(intrp.evaluate(flatten(final_list)))

    assert len(flat) == 6
    for position, value in enumerate(flat):
        assert count(value) == position + 1


def test_match_full_var():
    """A single variable pattern binds the whole scrutinee, giving an identity function."""
    x = relay.Var('x')
    v = relay.Var('v')
    bind_all = relay.Clause(relay.PatternVar(v), v)
    id_func = relay.Function([x], relay.Match(x, [bind_all]))

    empty_result = intrp.evaluate(id_func(nil()))
    zeros_result = intrp.evaluate(id_func(cons(z(), cons(z(), nil()))))

    assert len(to_list(empty_result)) == 0

    zeroes = to_list(zeros_result)
    assert len(zeroes) == 2
    for value in zeroes:
        assert count(value) == 0


def test_nested_pattern_match():
    """A nested ``cons(h1, cons(h2, t))`` pattern extracts the second element."""
    x = relay.Var('x', l(nat()))
    h1 = relay.Var('h1')
    h2 = relay.Var('h2')
    t = relay.Var('t')

    tail_pattern = relay.PatternConstructor(
        cons, [relay.PatternVar(h2), relay.PatternVar(t)])
    two_or_more = relay.PatternConstructor(
        cons, [relay.PatternVar(h1), tail_pattern])
    clauses = [
        relay.Clause(two_or_more, h2),
        # Lists shorter than two elements fall through to zero.
        relay.Clause(relay.PatternWildcard(), z()),
    ]
    get_second = relay.Function([x], relay.Match(x, clauses))

    one_two = cons(s(z()), cons(s(s(z())), nil()))
    second = intrp.evaluate(get_second(one_two))

    assert count(second) == 2


def test_compose():
    """``compose(inc, double)`` applied to two gives five.

    ``double`` runs first (2 -> 4) and ``inc`` is applied to its result
    (4 -> 5), as the asserted value shows.
    """
    n = relay.Var('n')
    inc = relay.Function([n], s(n))
    res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
    assert count(res) == 5


def test_iterate():
    """Iterating ``double`` twice on three gives twelve."""
    double_twice = iterate(double, relay.const(2))
    application = relay.Call(double_twice, [make_nat_expr(3)])
    # Wrap the call in a nullary function and invoke it for evaluation.
    thunk = relay.Function([], application)
    result = intrp.evaluate(thunk())
    assert count(result) == 12


if __name__ == "__main__":
    # Run every test defined in this file. Earlier versions of this list
    # skipped several of them (nat value, hd/tl, nth, update, optional
    # matching and the pattern-match tests), so running the file as a script
    # covered less than running it under a test runner.
    test_nat_value()
    test_nat_constructor()
    test_double()
    test_add()
    test_list_constructor()
    test_hd_tl()
    test_nth()
    test_update()
    test_length()
    test_map()
    test_foldl()
    test_foldr()
    test_foldr1()
    test_concat()
    test_filter()
    test_zip()
    test_rev()
    test_unfoldl()
    test_unfoldr()
    test_map_accumr()
    test_map_accuml()
    test_optional_matching()
    test_sum()
    test_tmap()
    test_size()
    test_wildcard_match_solo()
    test_wildcard_match_order()
    test_nested_matches()
    test_match_full_var()
    test_nested_pattern_match()
    test_compose()
    test_iterate()