"""Test dispatcher.
The dispatcher can choose which template to use according
to the parameters of the workload."""

from collections import namedtuple
from tvm import autotvm
from tvm.autotvm.task import dispatcher, DispatchContext

# Lightweight config record holding a template key and a fallback flag.
SimpleConfig = namedtuple('SimpleConfig', ['template_key', 'is_fallback'])

def test_dispatch():
    """Check that the dispatcher calls the template picked by the context.

    Two templates are registered under the keys "im2col" and "spatial_pack";
    the active context picks between them from the workload values.
    """
    @dispatcher
    def my_dispatcher(a, b):
        # The workload is simply the pair of arguments.
        return a, b

    @my_dispatcher.register("im2col")
    def _im2col(config, a, b):
        # Returns the first argument so the test can tell which template ran.
        return a

    @my_dispatcher.register("spatial_pack")
    def _spatial_pack(config, a, b):
        # Returns the second argument so the test can tell which template ran.
        return b

    class SimpleDispatcher(DispatchContext):
        """Context that selects a template key from the sum of the workload."""

        def query(self, target, workload):
            first, second = workload
            if first + second > 2:
                chosen_key = "spatial_pack"
            else:
                chosen_key = "im2col"
            return SimpleConfig(template_key=chosen_key, is_fallback=False)

    context = SimpleDispatcher()
    with context:
        # 1 + 0 <= 2, so this will call im2col
        small_result = my_dispatcher(1, 0)
        assert small_result == 1

        # 1 + 100 > 2, so this will call spatial pack
        large_result = my_dispatcher(1, 100)
        assert large_result == 100

def test_fallback():
    """Check that the config seen inside a template is marked as fallback.

    The template is invoked directly, with no surrounding context set up
    by this test.
    """
    @autotvm.template
    def simple_template(a, b):
        config = autotvm.get_config()
        assert config.is_fallback

    simple_template(2, 3)


if __name__ == "__main__":
    # Run every test in this file, in order, when executed as a script.
    for run_test in (test_dispatch, test_fallback):
        run_test()