# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te

@tvm.target.generic_func
def mygeneric(data):
    """Default body of the generic function used by the dispatch tests.

    Returns ``data + 1``. Target-specific overrides registered on this
    generic function add different offsets, so the returned value reveals
    which implementation handled the call.
    """
    incremented = data + 1
    return incremented

@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    """Override of ``mygeneric`` registered for the "cuda" and "gpu" keys.

    Returns ``data + 2``.
    """
    shifted = data + 2
    return shifted

@mygeneric.register("rocm")
def rocm_func(data):
    """Override of ``mygeneric`` registered for the "rocm" key.

    Returns ``data + 3``.
    """
    shifted = data + 3
    return shifted

@mygeneric.register("cpu")
def cpu_func(data):
    """Override of ``mygeneric`` registered for the "cpu" key.

    Returns ``data + 10``.

    Named ``cpu_func`` so it no longer redefines (and shadows) the
    module-level ``rocm_func`` registered for "rocm"; the registration key
    and the returned value are unchanged.
    """
    return data + 10


def test_target_dispatch():
    """Check that ``mygeneric`` dispatches on the target currently in scope.

    Each case enters a target scope and compares ``mygeneric(1)`` against
    the value produced by the implementation expected to handle it:
    3 for the "cuda"/"gpu" override, 4 for "rocm", and 11 for "cpu".
    Afterwards no target may remain current.
    """
    # Factories (not instances) so every target is built right before its
    # scope is entered, one at a time.
    cases = (
        (lambda: tvm.target.cuda(), 3),
        (lambda: tvm.target.rocm(), 4),
        (lambda: tvm.target.create("cuda"), 3),
        # arm_cpu is expected to land on the "cpu" override (1 + 10).
        (lambda: tvm.target.arm_cpu(), 11),
        # metal is expected to land on the ["cuda", "gpu"] override
        # (1 + 2) -- presumably via a shared "gpu" key; verify against
        # the target definition.
        (lambda: tvm.target.create("metal"), 3),
    )

    for make_target, expected in cases:
        with make_target():
            assert mygeneric(1) == expected

    # Leaving every ``with`` block must have cleared the current target.
    assert tvm.target.Target.current() is None


def test_target_string_parse():
    """Check the attributes of a target built from an option string.

    Parses ``"cuda -model=unknown -libs=cublas,cudnn"`` and verifies the
    resulting name, options, keys and libs, then compares its string form
    with the one produced by ``tvm.target.cuda``. Finally checks the
    ``device_name`` reported by a few predefined target constructors.
    """
    parsed = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn")

    assert parsed.target_name == "cuda"
    assert parsed.options == ['-model=unknown', '-libs=cublas,cudnn']
    assert parsed.keys == ['cuda', 'gpu']
    assert parsed.libs == ['cublas', 'cudnn']

    parsed_text = str(parsed)
    reference_text = str(tvm.target.cuda(options="-libs=cublas,cudnn"))
    assert parsed_text == reference_text

    device_checks = (
        (tvm.target.intel_graphics, "intel_graphics"),
        (tvm.target.mali, "mali"),
        (tvm.target.arm_cpu, "arm_cpu"),
    )
    for make_target, expected_device in device_checks:
        assert make_target().device_name == expected_device

if __name__ == "__main__":
    # Run the tests directly, in definition order, when executed as a script.
    for run_test in (test_target_dispatch, test_target_string_parse):
        run_test()