import tvm

@tvm.target.generic_func
def mygeneric(data):
    """Default implementation, used when no registered target key matches.

    Returns ``data`` plus one; target-specific overloads registered below
    add different offsets so tests can tell which one ran.
    """
    increment = 1
    return data + increment

@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    """Overload of ``mygeneric`` for the "cuda" and "gpu" keys: adds two."""
    offset = 2
    return data + offset

@mygeneric.register("rocm")
def rocm_func(data):
    """Overload of ``mygeneric`` for the "rocm" key: adds three."""
    offset = 3
    return data + offset

@mygeneric.register("cpu")
def cpu_func(data):
    """Overload of ``mygeneric`` for the "cpu" key: adds ten.

    Previously this was also named ``rocm_func``, which rebound the
    module-level name and hid the real ROCm overload defined above; the
    registration itself happens at decoration time, so renaming does not
    change dispatch.
    """
    return data + 10


def test_target_dispatch():
    """Check that ``mygeneric`` dispatches on the keys of the active target.

    Each expected value identifies which implementation handled the call:
    2 = default, 3 = "cuda"/"gpu", 4 = "rocm", 11 = "cpu".
    """
    with tvm.target.cuda():
        assert mygeneric(1) == 3

    with tvm.target.rocm():
        assert mygeneric(1) == 4

    # A target built from a string must dispatch like the helper constructor.
    with tvm.target.create("cuda"):
        assert mygeneric(1) == 3

    # arm_cpu has no registration of its own; 11 shows it reaches the "cpu" overload.
    with tvm.target.arm_cpu():
        assert mygeneric(1) == 11

    # metal has no registration of its own; 3 shows it reaches the "gpu" overload.
    with tvm.target.create("metal"):
        assert mygeneric(1) == 3

    # Leaving every ``with`` scope must clear the active target ...
    assert tvm.target.current_target() is None
    # ... and with no target active the undecorated default body runs.
    assert mygeneric(1) == 2


def test_target_string_parse():
    """Check parsing of a target string and the device-specific helpers."""
    parsed = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn")

    assert parsed.target_name == "cuda"
    assert parsed.options == ['-model=unknown', '-libs=cublas,cudnn']
    assert parsed.keys == ['cuda', 'gpu']
    assert parsed.libs == ['cublas', 'cudnn']

    # The same target built via the helper constructor must stringify identically.
    expected = tvm.target.cuda(options="-libs=cublas,cudnn")
    assert str(parsed) == str(expected)

    # Each helper constructor should tag its target with the matching device name.
    helpers = (
        (tvm.target.intel_graphics, "intel_graphics"),
        (tvm.target.mali, "mali"),
        (tvm.target.arm_cpu, "arm_cpu"),
    )
    for make_target, device in helpers:
        assert make_target().device_name == device

if __name__ == "__main__":
    # Run every test in this module when executed directly as a script.
    for _test in (test_target_dispatch, test_target_string_parse):
        _test()