# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm


def test_cast():
    """Modular information must survive integer casts and a float round trip."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x", dtype="int8")

    # Casting 3 * x to uint32 keeps the "multiple of 3" pattern.
    as_uint = analyzer.modular_set((x * 3).astype("uint32"))
    assert (as_uint.coeff, as_uint.base) == (3, 0)

    # Going through float32 and back to int32 keeps 3 * x + 1.
    round_trip = (x * 3 + 1).astype("float32").astype("int32")
    m = analyzer.modular_set(round_trip)
    assert (m.coeff, m.base) == (3, 1)


def test_add_sub():
    """Check modular sets of sums, differences and a bound variable."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x", "int64")
    y = tvm.var("y", "int64")

    # 6x + 4y: both strides are even, gcd(6, 4) = 2, no offset.
    m = analyzer.modular_set(x * 6 + y * 4)
    assert (m.coeff, m.base) == (2, 0)

    # After binding y := 4x + 1, the expression 1 - y reduces to -4x.
    analyzer.bind(y, x * 4 + 1)
    m = analyzer.modular_set(1 - y)
    assert (m.coeff, m.base) == (4, 0)


def test_mul():
    """(4x + 2) * (6y + 1) = 24xy + 4x + 12y + 2, i.e. 2 modulo 4."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x")
    y = tvm.var("y")
    lhs = x * 4 + 2
    rhs = y * 6 + 1
    product = analyzer.modular_set(lhs * rhs)
    assert (product.coeff, product.base) == (4, 2)


def test_div_shift():
    """Check truncdiv, floordiv and right shift of 4x + 2 by 2.

    The truncdiv case is checked twice: once with the sign of x unknown
    and once after x is bounded to [0, 100].
    """
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x")
    tdiv = tvm.truncdiv
    fld = tvm.floordiv

    # The sign of x is unknown here.
    # truncdiv rounds toward zero, which is not known to coincide with
    # rounding down, so the analyzer is expected to report no
    # information (coeff 1, base 0).
    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
    assert m.coeff == 1
    assert m.base == 0

    # Right shift always rounds down, so the sign of x does not matter.
    m = analyzer.modular_set((x * 4 + 2) >> 1)
    assert m.coeff == 2
    assert m.base == 1

    # floordiv also rounds down and yields the same 2x + 1 pattern.
    m = analyzer.modular_set(fld(x * 4 + 2, 2))
    assert m.coeff == 2
    assert m.base == 1

    # Once x is known to be non-negative, truncdiv behaves like floordiv.
    analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
    assert m.coeff == 2
    assert m.base == 1


def test_min_max_select():
    """min, max and Select each evaluate to one of their operands.

    The expected modular set is therefore the pattern shared by both
    branches.
    """
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x")
    y = tvm.var("y")

    # 3x and 9y are both multiples of 3.
    lo = analyzer.modular_set(tvm.min(x * 3, y * 9))
    assert (lo.coeff, lo.base) == (3, 0)

    # 3x + 1 and 9y + 4 are both congruent to 1 modulo 3.
    hi = analyzer.modular_set(tvm.max(x * 3 + 1, y * 9 + 4))
    assert (hi.coeff, hi.base) == (3, 1)

    # 3x + 1 is 1 (mod 3) but 9y + 2 is 2 (mod 3): nothing in common.
    sel = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
    assert (sel.coeff, sel.base) == (1, 0)


def test_mix_index():
    """Check index-like expressions that mix +, -, *, truncdiv and min."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()
    tdiv = tvm.truncdiv

    # 4a + 6b + 7: gcd(4, 6) = 2 and 7 is odd.
    m = analyzer.modular_set(a * 4 + b * 6 + 7)
    assert m.coeff == 2
    assert m.base == 1

    # (4a + 1)(8b + 3) = 32ab + 12a + 8b + 3, and gcd(32, 12, 8) = 4.
    m = analyzer.modular_set((a * 4 + 1) * (b * 8 + 3))
    assert m.coeff == 4
    assert m.base == 3

    # Dividing by a non-constant divisor is expected to give no information.
    m = analyzer.modular_set(tdiv(a * 4 + 1, b * 8 + 3))
    assert m.coeff == 1
    assert m.base == 0

    # truncdiv(8b, 4) is exactly 2b, so the product is 8ab + 2b.
    m = analyzer.modular_set((a * 4 + 1) * tdiv(b * 8, 4))
    assert m.coeff == 2
    assert m.base == 0

    # (12a + 1) - (21b + 2) = 12a - 21b - 1; gcd(12, 21) = 3 and -1 mod 3 == 2.
    m = analyzer.modular_set((a * 12 + 1) - (b * 3 * 7 + 2))
    assert m.coeff == 3
    assert m.base == 2

    # min(21b, 2) is either a multiple of 21 or the constant 2.
    # The two have no common pattern, so the whole sum carries none.
    m = analyzer.modular_set(a * 12 + tvm.min(b * 3 * 7, 2))
    assert m.coeff == 1
    assert m.base == 0


def test_constraint_scope():
    """Constraints apply only inside their ``constraint_scope`` block."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()
    tmod = tvm.truncmod

    with analyzer.constraint_scope(tmod(b, 4) == 2):
        # b = 4k + 2, so b + 1 = 4k + 3.
        m = analyzer.modular_set(b + 1)
        assert m.coeff == 4
        assert m.base == 3

        with analyzer.constraint_scope(tmod(a, 2) == 1):
            # a = 2j + 1, so b + 2a = 4(k + j) + 4, a multiple of 4.
            m = analyzer.modular_set(b + a * 2)
            assert m.coeff == 4
            assert m.base == 0

        # The constraint on a no longer applies.
        # b + 2a = 4k + 2 + 2a is only known to be even.
        m = analyzer.modular_set(b + a * 2)
        assert m.coeff == 2
        assert m.base == 0

    # Outside every scope nothing is known about b, so nothing about b + 1.
    m = analyzer.modular_set(b + 1)
    assert m.coeff == 1
    assert m.base == 0

def test_intersect():
    """Nested modular constraints on one variable combine via the CRT.

    The combined residue is the one given by the Chinese remainder theorem.
    """
    a = tvm.var("a")
    analyzer = tvm.arith.Analyzer()
    tmod = tvm.truncmod

    # a = 1 (mod 4) and a = 1 (mod 3)  =>  a = 1 (mod 12).
    with analyzer.constraint_scope(tmod(a, 4) == 1):
        with analyzer.constraint_scope(tmod(a, 3) == 1):
            m = analyzer.modular_set(a)
            assert m.coeff == 12
            assert m.base == 1

    # a = 2 (mod 3), a = 3 (mod 5), a = 2 (mod 7)  =>  a = 23 (mod 105).
    with analyzer.constraint_scope(tmod(a, 3) == 2):
        with analyzer.constraint_scope(tmod(a, 5) == 3):
            with analyzer.constraint_scope(tmod(a, 7) == 2):
                m = analyzer.modular_set(a)
                assert m.coeff == 105
                assert m.base == 23


if __name__ == "__main__":
    # Run every test in file order when executed as a script.
    test_cast()
    test_add_sub()
    test_mul()
    test_div_shift()
    test_min_max_select()
    test_mix_index()
    test_constraint_scope()
    test_intersect()