# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_dtype_bound():
    analyzer = tvm.arith.Analyzer()

    x = tvm.var("x", dtype="int64")
    bd = analyzer.const_int_bound(x)
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == bd.POS_INF

    x = tvm.var("x", dtype="int8")
    bd = analyzer.const_int_bound(x)
    assert bd.min_value == -128
    assert bd.max_value == 127

    x = tvm.var("x", dtype="uint8")
    bd = analyzer.const_int_bound(x)
    assert bd.min_value == 0
    assert bd.max_value == 255


def test_cast_bound():
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x", dtype="int8")
    bd = analyzer.const_int_bound((x % 3).astype("uint32"))
    assert bd.min_value == 0
    assert bd.max_value == 2

    bd = analyzer.const_int_bound(
        (x % 3).astype("float32").astype("int32"))
    assert bd.min_value == -2
    assert bd.max_value == 2


def test_add_sub_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x", "int64"), tvm.var("y", "int64")
    bd = analyzer.const_int_bound(x + y)
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == bd.POS_INF

    analyzer.update(x, tvm.arith.ConstIntBound(0, 4))
    analyzer.update(y, tvm.arith.ConstIntBound(1, 10))
    bd = analyzer.const_int_bound(x + y)
    assert bd.min_value == 1
    assert bd.max_value == 14

    bd = analyzer.const_int_bound(x - y)
    assert bd.min_value == -10
    assert bd.max_value == 3

    analyzer.update(x, tvm.arith.ConstIntBound(0, bd.POS_INF), override=True)
    bd = analyzer.const_int_bound(x - y)
    assert bd.min_value == -10
    assert bd.max_value == bd.POS_INF

    bd = analyzer.const_int_bound(1 - x)
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == 1


def test_mul_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, tvm.arith.ConstIntBound(-2, 4))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(x * y + 20)
    assert bd.min_value == 0
    assert bd.max_value == 60

    analyzer.update(x, tvm.arith.ConstIntBound(-3, 4), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True)
    bd = analyzer.const_int_bound(x * y)
    assert bd.min_value == -32
    assert bd.max_value == 24

    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(-8, 2), override=True)
    bd = analyzer.const_int_bound(x * y)
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == bd.POS_INF


def test_div_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(x / y)
    assert bd.min_value == -2

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
    bd = analyzer.const_int_bound(x / y)
    assert bd.min_value == -4
    assert bd.max_value == 9

    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
    bd = analyzer.const_int_bound(x / y)
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == bd.POS_INF


def test_mod_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(x % y)
    assert bd.min_value == -9
    assert bd.max_value == 4

    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
    bd = analyzer.const_int_bound(x % y)
    assert bd.min_value == -9
    assert bd.max_value == 9

    analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
    bd = analyzer.const_int_bound(x % y)
    assert bd.min_value == 0
    assert bd.max_value == 9


def test_floordiv_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    fld = tvm.floordiv
    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(fld(x, y))
    assert bd.min_value == -9 // 4

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
    bd = analyzer.const_int_bound(fld(x, y))
    assert bd.min_value == -4
    assert bd.max_value == 9

    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
    bd = analyzer.const_int_bound(fld(x, y))
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == bd.POS_INF


def test_floormod_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    flm = tvm.floormod

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(flm(x, y))
    assert bd.min_value == 0
    assert bd.max_value == 9

    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
    bd = analyzer.const_int_bound(flm(x, y))
    assert bd.min_value == 0
    assert bd.max_value == 9

    analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
    bd = analyzer.const_int_bound(flm(x, y))
    assert bd.min_value == 0
    assert bd.max_value == 9


def test_min_max_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 11))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(tvm.min(x, y))
    assert bd.min_value == -9
    assert bd.max_value == 10

    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
    bd = analyzer.const_int_bound(tvm.min(x, y))
    assert bd.min_value == bd.NEG_INF
    assert bd.max_value == 10

    bd = analyzer.const_int_bound(tvm.max(x, y))
    assert bd.min_value == 4
    assert bd.max_value == bd.POS_INF

    analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
    bd = analyzer.const_int_bound(tvm.max(x, y))
    assert bd.min_value == 4
    assert bd.max_value == bd.POS_INF


def test_select_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 11))
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))

    bd = analyzer.const_int_bound(
        tvm.expr.Select(x > 1, (y < 0).astype("int32"), y + 1))
    assert bd.min_value == 0
    assert bd.max_value == 11


def test_shift_and_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    analyzer.update(x, tvm.arith.ConstIntBound(-9, 11))
    analyzer.update(y, tvm.arith.ConstIntBound(2, 10))

    bd = analyzer.const_int_bound(x >> y)
    assert bd.min_value == -3
    assert bd.max_value == 2

    bd = analyzer.const_int_bound(x & y)
    assert bd.min_value == 0
    assert bd.max_value == 10

    analyzer.update(x, tvm.arith.ConstIntBound(10, 11), override=True)
    bd = analyzer.const_int_bound(x & y)
    assert bd.min_value == 0
    assert bd.max_value == 10


def test_mix_index_bound():
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1))
    analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1))
    bd = analyzer.const_int_bound((x % 8) + (x / 8) * 8)
    assert bd.min_value == 0
    assert bd.max_value == 24 - 1

    bd = analyzer.const_int_bound(y + x * 3)
    assert bd.min_value == 0
    assert bd.max_value == 24 * 3 - 1

    bd = analyzer.const_int_bound((x % 7) + (x / 7) * 7)
    assert bd.min_value == 0
    assert bd.max_value == (23 // 7) * 7 + 6


if __name__ == "__main__":
    test_dtype_bound()
    test_cast_bound()
    test_add_sub_bound()
    test_mul_bound()
    test_div_bound()
    test_mod_bound()
    test_floordiv_bound()
    test_floormod_bound()
    test_min_max_bound()
    test_select_bound()
    test_shift_and_bound()
    test_mix_index_bound()