# test_autotvm_dispatch_context.py
"""Test the autotvm dispatcher.

The dispatcher chooses which template to use according
to the parameters of the workload."""

from collections import namedtuple
from tvm import autotvm
from tvm.autotvm.task import dispatcher, DispatchContext

# Lightweight config object returned by the test dispatch context below.
# `template_key` names the registered template to run ("im2col" or
# "spatial_pack" in test_dispatch); `is_fallback` flags a fallback config.
SimpleConfig = namedtuple('SimpleConfig', ('template_key', 'is_fallback'))

def test_dispatch():
    """Check that a dispatcher runs the template chosen by the active context.

    ``my_dispatcher`` has two registered templates: "im2col" returns ``a``
    and "spatial_pack" returns ``b``. ``SimpleDispatcher`` selects
    "spatial_pack" when ``a + b > 2`` and "im2col" otherwise, so the value
    returned by ``my_dispatcher`` reveals which template actually ran.
    """
    @dispatcher
    def my_dispatcher(a, b):
        # Presumably this tuple is the ``workload`` handed to the context's
        # ``query`` (SimpleDispatcher unpacks it as ``a, b``) — confirm
        # against the dispatcher implementation.
        return (a, b)

    @my_dispatcher.register("im2col")
    def _im2col(cfg, a, b):
        return a

    @my_dispatcher.register("spatial_pack")
    def _spatial_pack(cfg, a, b):
        return b

    class SimpleDispatcher(DispatchContext):
        """Dispatch context that picks a template key from the workload values."""

        def query(self, target, workload):
            a, b = workload
            tkey = "spatial_pack" if a + b > 2 else "im2col"
            # Not a fallback: this context deliberately chose the template.
            return SimpleConfig(tkey, False)

    with SimpleDispatcher():
        # 1 + 0 <= 2, so "im2col" is chosen and returns a
        assert my_dispatcher(1, 0) == 1

        # 1 + 100 > 2, so "spatial_pack" is chosen and returns b
        assert my_dispatcher(1, 100) == 100

def test_fallback():
    """Check that a template run outside any tuning context gets a fallback config."""

    @autotvm.template
    def simple_template(a, b):
        # No dispatch context is entered in this test, so the config seen
        # by the template is expected to be marked as a fallback.
        current_cfg = autotvm.get_config()
        is_fallback = current_cfg.is_fallback
        assert is_fallback

    template_args = (2, 3)
    simple_template(*template_args)


if __name__ == "__main__":
    # Allow running the tests directly, without a test runner.
    test_dispatch()
    test_fallback()