# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te

def check_throws(f):
    try:
        f()
    except tvm.error.TVMError:
        pass
    else:
        raise AssertionError("Should have raised an exception but didn't.")


def test_const_fold():
    def check(f, *args):
        x = f(*[tvm.tir.const(x, "int32") for x in args])
        y = f(*args)
        if not isinstance(x, (tvm.tir.IntImm,)) or x.value != int(y):
            raise ValueError("check error: %s vs %s " % (x, y))

    tmod = tvm.tir.truncmod
    check(lambda x, y: x + y, 3, 4)
    check(lambda x, y: x * y, 3, 12)
    check(lambda x, y: x * y - 10, 3, 12)
    check(lambda x, y: x - tmod(y, 10), 3, 12)
    check(lambda x, y: x // y + 10, 100, 12)
    check(lambda x, y: x & y + 10, 112, 128)
    check(lambda x, y: x > y, 112, 128)
    check(lambda x, y: x < y, 112, 128)
    check(lambda x, y: x <= y, 112, 128)
    check(lambda x, y: x >= y, 112, 128)
    check(lambda x, y: (x | y) ^ 10, 112, 128)


def test_const_fold2():
    x = te.var("x")
    tmod = tvm.tir.truncmod
    tdiv = tvm.tir.truncdiv
    assert (x + 0).same_as(x)
    assert (0 + x).same_as(x)
    assert (x - 0).same_as(x)
    assert tmod(x, 1).value == 0
    assert (x * 1).same_as(x)
    assert (1 * x).same_as(x)
    assert isinstance(tdiv(1, x), tvm.tir.Div)

def test_const_fold3():
    # Test that using ints with logic operations is forbidden
    x = te.var("x")
    for val in [0, 1]:
        for func in [tvm.tir.all, tvm.tir.any]:
            check_throws(lambda: func(tvm.tir.const(val, 'uint1'), x))
            check_throws(lambda: func(x, tvm.tir.const(val, 'uint1')))

    # Test const folding when both arguments are const
    for tvm_func, py_func in [(tvm.tir.all, lambda a, b: a and b), (tvm.tir.any, lambda a, b: a or b)]:
        for v1 in [0, 1]:
            for v2 in [0, 1]:
                assert tvm.ir.structural_equal(tvm_func(tvm.tir.const(v1, 'uint1'), tvm.tir.const(v2, 'uint1')),
                                         tvm.tir.const(py_func(v1, v2), 'uint1'))

    x = te.var("x", 'uint1')
    true = tvm.tir.const(1, 'uint1')
    false = tvm.tir.const(0, 'uint1')

    assert tvm.tir.all(x, true).same_as(x)
    assert tvm.tir.all(true, x).same_as(x)
    assert tvm.tir.any(x, false).same_as(x)
    assert tvm.tir.any(false, x).same_as(x)

    assert tvm.tir.all(x, false).same_as(false)
    assert tvm.tir.all(false, x).same_as(false)
    assert tvm.tir.any(x, true).same_as(true)
    assert tvm.tir.any(true, x).same_as(true)


def test_const_fold4():
    x1 = tvm.tir.const(4, "int32")
    x2 = x1 + 5
    tdiv = tvm.tir.truncdiv
    assert isinstance(x2, tvm.tir.IntImm) and x2.value == 9
    x3 = tdiv(x2, 3)
    assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3
    x4 = x3 + 0.55
    assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
    x5 = te.ceil(x4)
    assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
    x6 = x5.astype('int')
    assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
    y = (te.round((tvm.tir.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
    assert isinstance(y, tvm.tir.IntImm) and y.value == 6


def test_binary_dtype_match():
    def verify_general_dtype_support(f, is_conditional=False):
        rules = [[('bool', 'int32'), 'int32'],
                 [('int32', 'float32'), 'float32'],
                 [('int32', 'int64'), 'int64'],
                 [('uint32', 'int32'), 'int32']]
        for (lhs_dtype, rhs_dtype), out_dtype in rules:
            lhs = te.var('lhs', dtype=lhs_dtype)
            rhs = te.var('rhs', dtype=rhs_dtype)
            out = f(lhs, rhs)
            if not is_conditional:
                assert out.dtype == out_dtype
            else:
                assert out.dtype == 'bool'
            if hasattr(out, 'a'):
                assert out.a.dtype == out_dtype
                assert out.b.dtype == out_dtype
            elif hasattr(out, 'args'):
                # CallOp
                assert out.args[0].dtype == out_dtype
                assert out.args[1].dtype == out_dtype
            else:
                raise ValueError('Unknown binary op format!')

    def verify_callop_float_only(f):
        for lhs_dtype in ['int32', 'float32', 'float64']:
            for rhs_dtype in ['int32', 'float32', 'float64']:
                lhs = te.var('lhs', dtype=lhs_dtype)
                rhs = te.var('rhs', dtype=rhs_dtype)
                if 'float' not in lhs_dtype and 'float' not in rhs_dtype:
                    check_throws(lambda: f(lhs, rhs))
                elif 'float' in lhs_dtype and 'float' in rhs_dtype and lhs_dtype != rhs_dtype:
                    check_throws(lambda: f(lhs, rhs))
                elif 'float' in lhs_dtype:
                    out = f(lhs, rhs)
                    assert out.dtype == lhs_dtype
                    assert out.args[0].dtype == lhs_dtype
                    assert out.args[1].dtype == lhs_dtype
                else:
                    out = f(lhs, rhs)
                    assert out.dtype == rhs_dtype
                    assert out.args[0].dtype == rhs_dtype
                    assert out.args[1].dtype == rhs_dtype

    verify_general_dtype_support(lambda a, b: a + b)
    verify_general_dtype_support(lambda a, b: a * b)
    verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True)
    verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True)
    verify_callop_float_only(lambda a, b: te.power(a, b))


def test_if_then_else():
    cases = [[(te.var('cond', dtype='bool'), 'bool', 'int32'), 'int32'],
             [(True, 'int32', 'float32'), 'float32'],
             [(False, 'int32', 'int64'), 'int64'],
             [(te.var('cond', dtype='bool'), 'uint32', 'int32'), 'int32'],
             [(te.var('cond', dtype='int32'), 'uint32', 'int32'), 'int32']]
    for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
        lhs = te.var('lhs', dtype=lhs_dtype)
        rhs = te.var('rhs', dtype=rhs_dtype)
        if cond is True or cond is False:
            out = tvm.tir.if_then_else(cond, lhs, rhs)
            out2 = tvm.tir.if_then_else(not cond, rhs, lhs)
            out3 = tvm.tir.if_then_else(not cond, lhs, rhs)
            assert tvm.ir.structural_equal(out, out2) == 1
            if cond:
                assert tvm.ir.structural_equal(out, lhs.astype(out_dtype)) == 1
                assert tvm.ir.structural_equal(out3, rhs.astype(out_dtype)) == 1
            else:
                assert tvm.ir.structural_equal(out, rhs.astype(out_dtype)) == 1
                assert tvm.ir.structural_equal(out3, lhs.astype(out_dtype)) == 1
        elif cond.dtype == 'bool':
            out = tvm.tir.if_then_else(cond, lhs, rhs)
            assert out.dtype == out_dtype
            assert out.args[1].dtype == out_dtype
            assert out.args[2].dtype == out_dtype
        elif cond.dtype != 'bool':
            check_throws(lambda: tvm.tir.if_then_else(cond, lhs, rhs))
        else:
            raise ValueError('Unknown combinations')


if __name__ == "__main__":
    test_const_fold()
    test_const_fold2()
    test_const_fold3()
    test_const_fold4()
    test_binary_dtype_match()
    test_if_then_else()