# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np


def test_const():
    """Check that an explicitly typed int32 constant is an IntImm of dtype int32."""
    x = tvm.const(1, "int32")
    assert x.dtype == tvm.int32
    assert isinstance(x, tvm.expr.IntImm)


def test_scalar_dtype_inference():
    """Check dtype inference of tvm.const and tvm.convert for scalar inputs.

    For Python bools and NumPy scalar types the inferred dtype must equal
    the dtype NumPy assigns to the same value. Plain Python ints and floats
    must be inferred as 'int32' and 'float32' respectively.
    """
    # bool(1) replaces the former np.bool(1): np.bool was only an alias of
    # the builtin bool and no longer exists in NumPy 1.24+.
    scalars = [
        True, bool(1),
        np.uint8(1), np.uint16(1), np.uint32(1), np.uint64(1),
        np.int8(1), np.int16(1), np.int32(1), np.int64(1),
        np.float16(1), np.float32(1), np.float64(1),
    ]

    # Both entry points are expected to follow the same inference rules.
    for make_expr in (tvm.const, tvm.convert):
        for data in scalars:
            assert make_expr(data).dtype == str(np.array(data).dtype)
        assert make_expr(1).dtype == 'int32'
        assert make_expr(1.0).dtype == 'float32'


def test_make():
    """Check that tvm.max / tvm.min of a constant and a var build Max / Min nodes."""
    one = tvm.const(1, "int32")
    var = tvm.var("x")
    # The sum is built but its result is not inspected.
    total = one + var
    largest = tvm.max(one, var)
    assert isinstance(largest, tvm.expr.Max)
    smallest = tvm.min(one, var)
    assert isinstance(smallest, tvm.expr.Min)


def test_ir():
    """Check that wrapping the sum of two int32 constants yields an Evaluate stmt."""
    lhs = tvm.const(1, "int32")
    rhs = tvm.make.IntImm('int32', 1)
    evaluated = tvm.make.Evaluate(lhs + rhs)
    assert isinstance(evaluated, tvm.stmt.Evaluate)


def test_ir2():
    """Check that tvm.make.Store builds a Store stmt that keeps its buffer var."""
    n = tvm.var("n")
    buf = tvm.var("array", tvm.handle)
    store = tvm.make.Store(buf, n + 1, 1)
    assert isinstance(store, tvm.stmt.Store)
    assert store.buffer_var == buf


def test_let():
    """Check that tvm.make.LetStmt builds a LetStmt node.

    The statement binds var ``x`` to 10 around a body that evaluates
    ``x + 1``.
    """
    x = tvm.var('x')
    stmt = tvm.make.LetStmt(x, 10, tvm.make.Evaluate(x + 1))
    assert isinstance(stmt, tvm.stmt.LetStmt)


def test_cast():
    """Check astype on a float32 var.

    Casting to a scalar dtype must give a Cast node; casting to a 4-lane
    vector dtype must give a Broadcast node with lanes == 4.
    """
    src = tvm.var('x', dtype="float32")

    as_int = src.astype("int32")
    assert isinstance(as_int, tvm.expr.Cast)

    as_vec = src.astype("float32x4")
    assert isinstance(as_vec, tvm.expr.Broadcast)
    assert as_vec.lanes == 4


def test_attr():
    """Check node attribute access.

    An AttrStmt must expose the node it was built with, a converted int
    must expose its value, and an unknown field must raise AttributeError.
    """
    var_x = tvm.var('x')
    var_y = tvm.var('y')
    body = tvm.make.Evaluate(var_x + 1)
    attr_stmt = tvm.make.AttrStmt(var_y, "stride", 10, body)
    assert attr_stmt.node == var_y

    one = tvm.convert(1)
    assert one.value == 1
    try:
        one.no_field
    except AttributeError:
        pass
    else:
        # Reaching here means the missing field lookup did not raise.
        assert False


def test_basic():
    """Check the string form of the sum of two vars."""
    lhs = tvm.var('a')
    rhs = tvm.var('b')
    expected = '(%s + %s)' % (lhs.name, rhs.name)
    assert str(lhs + rhs) == expected


def test_stmt():
    """Check that tvm.make.For builds a For stmt around an Evaluate body."""
    body = tvm.make.Evaluate(0)
    loop = tvm.make.For(tvm.var('i'), 0, 1,
                        tvm.stmt.For.Serial, 0,
                        body)
    assert isinstance(loop, tvm.stmt.For)


def test_dir():
    """Smoke test: calling dir() on a var must not raise."""
    dir(tvm.var('x'))


def test_dtype():
    """Check that a var defaults to int32 and a comparison of vars is bool."""
    lhs = tvm.var('x')
    assert lhs.dtype == 'int32'
    rhs = tvm.var('y')
    comparison = lhs > rhs
    assert comparison.dtype == 'bool'


def test_any():
    """Check tvm.any.

    Python ``or`` on an expression and tvm.any() with no arguments must
    raise ValueError. With arguments, the conditions are joined with
    ``||``, nesting to the left.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    z = tvm.var('z')

    try:
        _ = x or x
    except ValueError:
        pass
    else:
        assert False

    try:
        tvm.any()
    except ValueError:
        pass
    else:
        assert False

    xn, yn, zn = x.name, y.name, z.name

    single = tvm.any(x < y)
    assert str(single) == '(%s < %s)' % (xn, yn)

    pair = tvm.any(x < y, x > z)
    assert str(pair) == '((%s < %s) || (%s > %s))' % (xn, yn, xn, zn)

    triple = tvm.any(x < y, y > z + 1, x < z * 2)
    assert str(triple) == \
        '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
            xn, yn, yn, zn, xn, zn)


def test_all():
    """Check tvm.all.

    Python ``and`` on an expression and tvm.all() with no arguments must
    raise ValueError. With arguments, the conditions are joined with
    ``&&``, nesting to the left.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    z = tvm.var('z')

    try:
        _ = x and x
    except ValueError:
        pass
    else:
        assert False

    try:
        tvm.all()
    except ValueError:
        pass
    else:
        assert False

    xn, yn, zn = x.name, y.name, z.name

    single = tvm.all(x < y)
    assert str(single) == '(%s < %s)' % (xn, yn)

    pair = tvm.all(x < y, x > z)
    assert str(pair) == '((%s < %s) && (%s > %s))' % (xn, yn, xn, zn)

    triple = tvm.all(x < y, y > z + 1, x < z * 2)
    assert str(triple) == \
        '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
            xn, yn, yn, zn, xn, zn)


def test_bitwise():
    """Check the printed form of bitwise/shift operators and vector dtypes.

    Each operator on two vars must print as the matching intrinsic call,
    and shifts involving vector operands must keep the vector dtype.
    """
    x = tvm.var('x')
    y = tvm.var('y')

    printed_forms = [
        (x << y, 'shift_left(x, y)'),
        (x >> y, 'shift_right(x, y)'),
        (x & y, 'bitwise_and(x, y)'),
        (x | y, 'bitwise_or(x, y)'),
        (x ^ y, 'bitwise_xor(x, y)'),
        (~x, 'bitwise_not(x)'),
    ]
    for expr, expected in printed_forms:
        assert str(expr) == expected

    vec_shift_scalar = tvm.const(1, "int8x2") >> 1
    assert vec_shift_scalar.dtype == "int8x2"

    scalar_shift_vec = x >> tvm.const(1, "int32x2")
    assert scalar_shift_vec.dtype == "int32x2"

    vec_shift_vec = tvm.var("z", "int8x2") << tvm.const(1, "int8x2")
    assert vec_shift_vec.dtype == "int8x2"


def test_isnan():
    """Check tvm.isnan on float32, float16, int32 and int8x2 vars.

    The assertions pin the printed expression for the scalar cases and
    the result dtype for the float32 and vector cases.
    """
    f32 = tvm.var('x', 'float32')
    f32_check = tvm.isnan(f32)
    assert str(f32_check) == 'isnan(x)'
    assert str(tvm.isnan(f32).dtype) == 'bool'

    f16 = tvm.var('y', 'float16')
    assert str(tvm.isnan(f16)) == 'isnan(float32(y))'

    i32 = tvm.var('z', 'int32')
    assert str(tvm.isnan(i32)) == '(bool)0'

    i8x2 = tvm.var('k', 'int8x2')
    assert str(tvm.isnan(i8x2).dtype) == 'uint1x2'


def test_equality():
    """Check the truthiness of == and != results.

    ``==`` between two different vars must be falsy, and ``!=`` of that
    result with itself must also be falsy.
    """
    first = tvm.var('a')
    second = tvm.var('b')
    are_equal = (first == second)
    assert not are_equal
    differs_from_self = (are_equal != are_equal)
    assert not differs_from_self


def test_equality_string_imm():
    """Check comparison between a Python str and a StringImm built from it."""
    x = 'a'
    y = tvm.make.StringImm(x)
    # The stored value must round-trip to the original string.
    assert x == y.value
    # Smoke check only: comparing a str with the node itself must not raise.
    # NOTE(review): the result is not asserted because nothing in this file
    # shows what the comparison returns -- confirm before tightening.
    x == y


if __name__ == "__main__":
    # Run the tests directly when this file is executed as a script.
    test_cast()
    test_attr()
    test_const()
    test_scalar_dtype_inference()
    test_make()
    test_ir()
    test_ir2()
    test_basic()
    test_stmt()
    test_let()
    test_dir()
    test_dtype()
    test_any()
    test_all()
    test_bitwise()
    test_isnan()
    test_equality()
    test_equality_string_imm()