# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_const_saveload_json():
    """Round-trip a constant expression through JSON and compare the output."""
    one = tvm.const(1, "int32")
    ten = tvm.const(10, "int32")
    total = one + ten
    doubled = total + total
    serialized = tvm.save_json(doubled)
    restored = tvm.load_json(serialized)
    # Serializing the reloaded node must give the same JSON text as
    # serializing the node it was created from.
    assert tvm.save_json(restored) == tvm.save_json(doubled)


def test_make_smap():
    """Save and reload an array that wraps a string-keyed map of expressions."""
    lhs = tvm.const(1, "int32")
    rhs = tvm.const(10, "int32")
    added = tvm.expr.Add(lhs, rhs)
    str_map = tvm.convert({"z": added, "x": lhs})
    wrapped = tvm.convert([str_map])
    loaded = tvm.load_json(tvm.save_json(wrapped))
    assert len(loaded) == 1
    # The `a` field of the reloaded Add node is compared against the
    # reloaded "x" entry of the same map.
    assert loaded[0]["z"].a == loaded[0]["x"]


def test_make_node():
    """Create nodes by type key with tvm.make.node and check their fields."""
    imm = tvm.make.node("IntImm", dtype="int32", value=10)
    assert isinstance(imm, tvm.expr.IntImm)
    assert imm.value == 10

    # Build a second Tensor node from the fields of a placeholder tensor.
    source = tvm.placeholder((10, ), name='A')
    tensor_fields = dict(
        shape=source.shape,
        dtype=source.dtype,
        op=source.op,
        value_index=source.value_index,
    )
    rebuilt = tvm.make.node("Tensor", **tensor_fields)
    assert rebuilt.op == source.op
    assert rebuilt.value_index == source.value_index


def test_make_attrs():
    """Check attrs-node validation, field values, and a DictAttrs round-trip.

    Covers three things:
    * constructing "attrs.TestAttrs" with an unknown key or with axis=100
      must raise tvm.TVMError carrying the expected message text;
    * a valid "attrs.TestAttrs" exposes the given fields, and ``axis``
      reads as 10 when it is not passed;
    * a "DictAttrs" node keeps its fields after save_json/load_json.
    """
    # An unrecognized field must be rejected, and the error must name it.
    try:
        tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
    except tvm.TVMError as err:
        assert "unknown_key" in str(err)
    else:
        assert False, "expected TVMError for unknown_key"

    # axis=100 must be rejected with a message mentioning the upper bound.
    try:
        tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
    except tvm.TVMError as err:
        assert "upper bound" in str(err)
    else:
        assert False, "expected TVMError for axis=100"

    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4))
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
    assert x.axis == 10

    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
    assert dattr.x.value == 1
    assert dattr.name.value == "xyz"
    # Assert on the reloaded node: previously the round-trip result was bound
    # to a misspelled, unused name, so the save/load path was never checked.
    loaded = tvm.load_json(tvm.save_json(dattr))
    assert loaded.name.value == "xyz"



def test_make_sum():
    """Serialize a reduction compute and check the combiner on both copies."""
    data = tvm.placeholder((2, 10), name='A')
    red_axis = tvm.reduce_axis((0, 10), "k")

    def row_sum(i):
        # Sum row `i` of the placeholder along the reduction axis.
        return tvm.sum(data[i, red_axis], axis=red_axis)

    summed = tvm.compute((2,), row_sum, name="B")
    reloaded = tvm.load_json(tvm.save_json(summed))
    # Both the original tensor and its reloaded copy must carry a combiner.
    for tensor in (summed, reloaded):
        assert tensor.op.body[0].combiner is not None


def test_env_func():
    """Exercise EnvFunc lookup, JSON round-trip, and use as an attrs field."""
    func_name = "test.env_func"

    @tvm.register_func(func_name)
    def _add_one(x):
        return x + 1

    # The global-function lookup is performed but its result is not used.
    _ = tvm.get_global_func(func_name)
    env_func = tvm.get_env_func(func_name)
    assert env_func.name == "test.env_func"

    # Round-trip the EnvFunc inside a list and make sure it is still callable.
    restored = tvm.load_json(tvm.save_json([env_func]))[0]
    assert restored.name == env_func.name
    assert restored(1) == 2
    assert restored.func(1) == 2

    attrs = tvm.make.node(
        "attrs.TestAttrs", name="xx", padding=(3,4), func=restored)
    assert attrs.name == "xx"
    for index, expected in enumerate((3, 4)):
        assert attrs.padding[index].value == expected
    assert attrs.axis == 10

    # The func field must come back as an EnvFunc after a JSON round-trip.
    attrs = tvm.load_json(tvm.save_json(attrs))
    assert isinstance(attrs.func, tvm.container.EnvFunc)
    assert attrs.func(10) == 11


if __name__ == "__main__":
    # Run every test in this file, in a fixed order, when executed as a script.
    for run_test in (
        test_env_func,
        test_make_attrs,
        test_make_node,
        test_make_smap,
        test_const_saveload_json,
        test_make_sum,
    ):
        run_test()