# test_pass_simplify.py: tests for the tvm.ir_pass Simplify and CanonicalSimplify passes.
import tvm
import numpy
from tvm import comm_reducer
from tvm.ir_pass import Simplify, CanonicalSimplify, Equal

def test_simplify():
    """Not yet working, mock design.

    Smoke test: build a small loop nest with a guarded load/store and run
    CanonicalSimplify over it. Nothing about the simplified statement is
    asserted yet; the test only checks that the pass runs without raising.
    """
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n, ), dtype)
    i = tvm.var('i')
    j = tvm.var('j')

    # Loop nest: i with min 2 and extent n, j with min 0 and extent n.
    # Body:      if i + 2 < n: Ab[(j + 1) * 4 - 4 * j + i] = Ab[i + 4] + 1
    # The store index is algebraically i + 4, which gives the simplifier
    # something to cancel.
    stmt = tvm.make.For(
        i, 2, n, 0, 0,
        tvm.make.For(j, 0, n, 0, 0,
                     tvm.make.IfThenElse(
                         tvm.make.LT(i + 2, n),
                         tvm.make.Store(Ab.data,
                                        tvm.make.Load(dtype, Ab.data, i + 4) + 1,
                                        (j + 1) * 4 - 4 * j + i),
                         None)))
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)


def test_basic():
    """CanonicalSimplify keeps ``m - 1`` printed as ``(m - 1)``."""
    m = tvm.var('m')
    # Wrap the expression in an Evaluate statement and read it back via .value.
    ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
    assert str(ret.value) == "(m - 1)"


def test_bound():
    """Simplify uses a known variable range to drop a redundant modulus.

    With ``m`` bounded by ``Range(0, 10)``, ``m % 10`` should simplify to ``m``.
    """
    m = tvm.var('m')
    vrange = tvm.convert({m: tvm.Range(tvm.const(0, "int32"), tvm.const(10, "int32"))})
    ret = tvm.ir_pass.Simplify(m % 10, vrange)
    # Compare structurally rather than relying on `==` between expression
    # nodes, matching how the other tests in this file check results.
    assert tvm.ir_pass.Equal(ret, m)


def test_canonical():
    """CanonicalSimplify cancels identical terms and orders sums deterministically."""
    x = tvm.var("x")
    z = tvm.const(3, "int32")

    # Identical division terms cancel, even when the divisor is itself a
    # constant sub-expression (3*3 and 3+3).
    ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z))
    assert tvm.ir_pass.Equal(ret, 0)

    ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z))
    assert tvm.ir_pass.Equal(ret, 0)

    # Make sure terms are ordered based on their top operators
    # (e.g., / always precedes %), so operand order in the input does not matter.
    ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3)
    ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3)
    assert tvm.ir_pass.Equal(ret1, ret2)

    # When top operators match, terms are ordered by their string representation.
    ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3)
    ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
    assert tvm.ir_pass.Equal(ret1, ret2)


def test_simplify_combiner():
    """Check that Simplify cleans up the combiners of commutative reductions.

    Covers:
    * choosing one branch of a Select-based combiner from variable ranges,
    * folding no-op terms in combiner and identity expressions,
    * dropping reducer components the selected result does not depend on,
    * keeping components that contain side-effecting calls.
    """
    dummy = tvm.var('dummy')

    prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))

    # Acts as a sum when dummy < 0 and as a product otherwise; only a known
    # range for `dummy` lets Simplify resolve the Select.
    sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0,
                                                            x + y, x*y),
                               lambda t0: tvm.expr.Select(dummy < 0,
                                                          tvm.const(0, t0), tvm.const(1, t0)))

    # Two-component (sum, product) reducer. The product identity is spelled
    # 5 - 4 so there is a constant for Simplify to fold.
    sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0],
                                              x[1]*y[1]),
                                lambda t0, t1: (tvm.const(0, t0),
                                                tvm.const(5, t0) - tvm.const(4, t0)))

    # Same as sum_and_prod, padded with no-op terms (0*x[0], + y[0] - y[0],
    # 5 - 5) that Simplify is expected to remove.
    sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0],
                                               x[1]*y[1] + 0*x[0] + y[0] - y[0]),
                                 lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0),
                                                 tvm.const(1, t1)))

    # Five components with cross-dependencies in the combiner:
    #   0 uses {0}; 1 uses {0, 1}; 2 uses {0, 2}; 3 uses {1, 2};
    #   4 is a constant and uses nothing.
    some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0],
                                               x[0] + y[0] + x[1] + y[1],
                                               x[0]*y[2] + y[0]*x[2],
                                               x[1] + y[2],
                                               4.0),
                                 lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
                                                             tvm.const(1, t1),
                                                             tvm.const(2, t2),
                                                             tvm.const(3, t3),
                                                             tvm.const(4, t4)))

    k = tvm.reduce_axis((0, 10), name="k")
    A = tvm.placeholder((10,), name='A')

    # Test that SimplifyCombiner makes use of vranges
    vrange = {dummy: tvm.Range(-10, -5)}
    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
    vrange = {dummy: tvm.Range(5, 10)}
    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))

    assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
    assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k))

    assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
    assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k))

    # Expected remaining sources for component j: the transitive closure of
    # the dependencies listed above.
    reference_simplified_sources = [[A[0]],
                                    [A[0], A[1]],
                                    [A[0], A[2]],
                                    [A[0], A[1], A[2], A[3]],
                                    [A[4]]]
    for j, expected_sources in enumerate(reference_simplified_sources):
        # Here we use the j-th component of the result, so only it and the components it
        # depends on are left.
        simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])

        # zip() stops at the shorter input, so check the component count
        # explicitly; otherwise missing or leftover components would go unnoticed.
        assert len(simplified.source) == len(expected_sources)
        # Check that the remaining components are the expected ones.
        for lhs, rhs in zip(simplified.source, expected_sources):
            assert Equal(lhs, rhs)

    def side_effect(*xs):
        """Wrap ``xs`` in a "dummy" intrinsic call; the checks below treat it as side-effecting."""
        return tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)

    # Test that components with side effects are not removed
    assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]),
                 sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
    assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]),
                 tvm.sum(side_effect(A[k]), k))


def test_simplify_reduce():
    """Simplify should fold reductions whose body or axis list is trivial."""
    k = tvm.reduce_axis((0, 10), name="k")
    j = tvm.reduce_axis((-5, 3), name="j")
    A = tvm.placeholder((10,), name='A')

    # Each entry pairs an input expression with the result Simplify must give.
    # Assumes reduce_axis((lo, hi)) spans [lo, hi).
    expectations = [
        # k / 10 is 0 for every k in [0, 10).
        (tvm.sum(k/10, k),
         tvm.sum(tvm.const(0, "int32"), k)),
        # A reduction over no axes is just its body.
        (tvm.sum(A[3], []),
         A[3]),
        # k + j <= 9 + 2 = 11, so the Select condition is always true.
        (tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
         tvm.sum(k + j, [k, j])),
    ]
    for expr, expected in expectations:
        assert Equal(Simplify(expr), expected)


if __name__ == "__main__":
    # Run every test directly when the file is executed as a script.
    test_bound()
    test_basic()
    test_simplify()
    test_canonical()
    test_simplify_combiner()
    test_simplify_reduce()