import numpy as np
import tvm
from tvm import relay


def test_fold_const():
    """Fold a purely constant sub-expression inside a one-argument function.

    The input computes ``x + (c + c) * 2 + c``.  The test expects the
    ``(c + c) * 2`` part to collapse into one constant, while the trailing
    ``+ c`` (whose left operand depends on ``x``) is left as an add.
    """
    c_data = np.array([1, 2, 3]).astype("float32")

    def build_input():
        """Return the unfolded function; the same constant node is reused."""
        const = relay.const(c_data)
        x = relay.var("x")
        doubled = relay.add(const, const)
        scaled = relay.multiply(doubled, relay.const(2, "float32"))
        partial = relay.add(x, scaled)
        return relay.Function([x], relay.add(partial, const))

    def build_expected():
        """Return the function with the constant part precomputed in numpy."""
        x = relay.var("x")
        folded = relay.const((c_data + c_data) * 2)
        partial = relay.add(x, folded)
        return relay.Function([x], relay.add(partial, relay.const(c_data)))

    def always_fail(stmt):
        """Lower pass that raises unconditionally."""
        raise RuntimeError()

    # Constant folding should work in any context: run it under a build
    # config whose extra lower pass always raises, and under a cuda target.
    with tvm.build_config(add_lower_pass=[(0, always_fail)]), \
            tvm.target.create("cuda"):
        folded_func = relay.ir_pass.fold_constant(build_input())
    expected_func = build_expected()
    assert relay.ir_pass.alpha_equal(folded_func, expected_func)


def test_fold_let():
    """Fold constants that are bound through a chain of let bindings.

    The input binds ``t1 = c``, ``t2 = t1 + t1`` and ``t3 = t2 + x``.  The
    test expects only the ``t3`` binding to remain, with ``c + c`` replaced
    by a single constant.
    """
    c_data = np.array(1).astype("float32")

    def build_input():
        """Return the function with three nested let bindings."""
        x = relay.var("x")
        scope = relay.ScopeBuilder()
        first = scope.let("t1", relay.const(c_data))
        second = scope.let("t2", relay.add(first, first))
        third = scope.let("t3", relay.add(second, x))
        scope.ret(third)
        return relay.Function([x], scope.get())

    def build_expected():
        """Return the function with a single let over the folded constant."""
        x = relay.var("x")
        scope = relay.ScopeBuilder()
        folded = relay.const(c_data + c_data)
        third = scope.let("t3", relay.add(folded, x))
        scope.ret(third)
        return relay.Function([x], scope.get())

    folded_func = relay.ir_pass.fold_constant(build_input())
    expected_func = build_expected()
    assert relay.ir_pass.graph_equal(folded_func, expected_func)


def test_fold_tuple():
    """Fold constants that are reached by indexing into a tuple.

    The input packs ``(x, c)`` into a tuple and computes
    ``(tuple[1] + c) + tuple[0]``.  The test expects the result to be
    ``const(c + c) + x``.
    """
    c_data = np.array(1).astype("float32")

    def build_input():
        """Return the function that routes the constant through a tuple."""
        const = relay.const(c_data)
        x = relay.var("x")
        pair = relay.Tuple([x, const])
        const_sum = relay.add(pair[1], const)
        total = relay.add(const_sum, pair[0])
        return relay.Function([x], total)

    def build_expected():
        """Return the function with the constant sum precomputed in numpy."""
        folded = relay.const(c_data + c_data)
        x = relay.var("x")
        return relay.Function([x], relay.add(folded, x))

    folded_func = relay.ir_pass.fold_constant(build_input())
    expected_func = build_expected()
    assert relay.ir_pass.graph_equal(folded_func, expected_func)


def test_fold_concat():
    """Fold a concatenate whose inputs are all constants.

    Two separate constant nodes holding the same data are concatenated
    along axis 0 in a zero-argument function.  The test expects the body to
    become one constant equal to the numpy concatenation.
    """
    c_data = np.array([[1, 2, 3]]).astype("float32")

    def build_input():
        """Return the function concatenating two distinct constant nodes."""
        lhs = relay.const(c_data)
        rhs = relay.const(c_data)
        joined = relay.concatenate((lhs, rhs), axis=0)
        return relay.Function([], joined)

    def build_expected():
        """Return the function whose body is the precomputed concatenation."""
        joined_data = np.concatenate((c_data, c_data), axis=0)
        return relay.Function([], relay.const(joined_data))

    folded_func = relay.ir_pass.fold_constant(build_input())
    expected_func = build_expected()
    assert relay.ir_pass.graph_equal(folded_func, expected_func)


if __name__ == "__main__":
    # Run every test in this file, in definition order, when executed as a
    # script.
    for _test_case in (
            test_fold_const,
            test_fold_let,
            test_fold_tuple,
            test_fold_concat,
    ):
        _test_case()