# test_arith_simplify.py: tests for tvm.ir_pass.CanonicalSimplify and for
# constant folding of TVM arithmetic expressions.
import tvm

def csimplify(z):
    """Run CanonicalSimplify on expression *z* and return the simplified expression.

    The expression is wrapped in an ``Evaluate`` statement before being passed
    to ``CanonicalSimplify``; the ``value`` of the resulting statement is
    returned.
    """
    wrapped = tvm.make.Evaluate(z)
    simplified_stmt = tvm.ir_pass.CanonicalSimplify(wrapped)
    return simplified_stmt.value

def test_simplify():
    """Check that CanonicalSimplify merges like terms and trivial div/mod."""
    x = tvm.var('n')
    # 4x - 2x: expected to become a product whose constant operand is 2.
    z = x * 4 - x * 2
    zz = csimplify(z)
    assert zz.b.value == 2

    # 2*(x/4) - (x/4): expected to become x / 4.
    z = (x / 4) * 2 - (x / 4)
    zz = csimplify(z)
    assert zz.a == x and zz.b.value == 4

    # 3*(x%4) + (x%4): expected to become (x % 4) * 4.
    z = (x % 4) * 3 + (x % 4)
    zz = csimplify(z)
    assert zz.b.value == 4
    zz = zz.a
    assert zz.a == x and zz.b.value == 4

    n = tvm.var('n')
    # Any value modulo +1 or -1 is 0.
    assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % (-1)), tvm.const(0, "int32"))
    assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32"))
    # Division by 1 is the identity.
    assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n)
    # Only checks that simplifying n / -1 runs; the result is discarded.
    tvm.ir_pass.CanonicalSimplify(n / (-1))
    # TODO: this is not true in the current implementation; enable once
    # n / -1 canonicalizes to -n.
    #  assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
    #                           tvm.ir_pass.CanonicalSimplify(-n))


def test_simplify_div():
    """Check CanonicalSimplify on integer division, with and without variable ranges."""
    x = tvm.var('x')
    # 16 + 48*x is an exact multiple of 16, so the quotient is always 1 + 3*x.
    assert tvm.ir_pass.CanonicalSimplify((16 + 48*x)/16 - (1 + (x*3))).value == 0
    # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x < 0,
    # (17+48*x)/16 != 1+3*x. The division must be left in place.
    r = tvm.ir_pass.CanonicalSimplify((17 + 48*x)/16)
    assert r.b.value == 16
    assert tvm.ir_pass.CanonicalSimplify(r.a - (17 + 48*x)).value == 0
    # However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified.
    assert tvm.ir_pass.CanonicalSimplify(
        (17 + 48*x)/16 - (1 + (x*3)), {x: tvm.Range(0, 10)}).value == 0

    # (17+47*x)/16 has no linear closed form even for x >= 0, so the
    # division must survive simplification unchanged.
    r = tvm.ir_pass.CanonicalSimplify((17 + 47*x)/16, {x: tvm.Range(0, 10)})
    assert r.b.value == 16
    assert tvm.ir_pass.CanonicalSimplify(r.a - (17 + 47*x)).value == 0

    # With x >= 4, 8*x - 17 >= 15 is non-negative and 8*x - 17 == 8*(x-3) + 7,
    # so (8*x - 17)/8 simplifies to x - 3.
    r = tvm.ir_pass.CanonicalSimplify((8*x - 17)/8, {x: tvm.Range(4, 10)})
    assert tvm.ir_pass.CanonicalSimplify(r - (x - 3)).value == 0


def test_simplify_mod():
    """Check CanonicalSimplify on `%` index expressions.

    Covers a modulo index inside an ir_builder loop body, and the effect of
    supplying (or omitting) non-negative ranges for the variables involved.
    """
    ib = tvm.ir_builder.create()
    n = tvm.var('n')
    j = tvm.var('j')
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 16, name="i") as i:
        # (n*4 + j*2)*8 == 32*n + 16*j is a multiple of 16; the assertion below
        # expects the load index to simplify to (1 + i) % 16.
        A[i] = A[((n * 4 + j * 2) * 8 + i + 1) % 16]
    body = ib.get()
    stmt = tvm.ir_pass.CanonicalSimplify(body)
    diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
    assert diff.value == 0

    # If we can't prove that j + n*32 is non-negative, we can't prove that
    # (j + n*32) % 16 is j % 16.
    index = tvm.ir_pass.CanonicalSimplify(
        (j + n * 32) % 16, {j: tvm.Range(0, 6)})
    assert index != j
    # With n >= 0 as well, j + n*32 is non-negative and 0 <= j < 16, so the
    # result is expected to be j itself.
    index = tvm.ir_pass.CanonicalSimplify(
        (j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
    assert index == j


def test_simplify_minmax():
    """Both max(x, 1) - max(x, 1) and min(x, 1) - min(x, 1) simplify to 0."""
    x = tvm.var('x')
    for op in (tvm.max, tvm.min):
        difference = op(x, 1) - op(x, 1)
        assert tvm.ir_pass.CanonicalSimplify(difference).value == 0

def test_mul():
    """Check that x * x - x * x simplifies to the constant 0."""
    x = tvm.var('x')
    e = x * x - x * x
    es = tvm.ir_pass.CanonicalSimplify(e)
    assert es.value == 0


def test_modular():
    """Split a flattened index into quotient and remainder by 16.

    idx = (ry + y) * 16 + (rx + x). With the ranges below, rx + x is at most
    2 + 13 = 15 < 16, so idx // 16 should simplify to ry + y and idx % 16 to
    rx + x.
    """
    rx = tvm.var("rx")
    ry = tvm.var("ry")
    y = tvm.var("y")
    x = tvm.var("x")

    def i32_const(value):
        """Return *value* as an int32 TVM constant."""
        return tvm.const(value, "int32")

    vmap = {rx: tvm.Range(i32_const(0), i32_const(3)),
            ry: tvm.Range(i32_const(0), i32_const(3)),
            y: tvm.Range(i32_const(0), i32_const(2)),
            x: tvm.Range(i32_const(0), i32_const(14))}

    idx = ry * 16 + rx + y * 16 + x
    z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
    z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
    assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
    assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0


def test_const_propagation():
    """Check that arithmetic on TVM constants folds eagerly at construction.

    No simplification pass is invoked; each intermediate result is expected to
    already be an IntImm/FloatImm carrying the folded value.
    """
    x1 = tvm.const(4, "int32")
    x2 = x1 + 5
    assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
    x3 = x2 / 3
    assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
    # Adding a Python float promotes the folded constant to FloatImm.
    x4 = x3 + 0.5
    assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5
    x5 = tvm.ceil(x4)
    assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
    # Casting a folded float constant back to an integer type also folds.
    x6 = x5.astype('int')
    assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4
    y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
    assert isinstance(y, tvm.expr.IntImm) and y.value == 6


if __name__ == "__main__":
    # Run every test directly when executed as a script (pytest discovers
    # the test_* functions on its own).
    test_simplify_div()
    test_simplify_mod()
    test_modular()
    test_simplify()
    test_mul()
    test_simplify_minmax()
    test_const_propagation()