Unverified commit 70e11d32 by Haichen Shen, committed by GitHub

[Autotvm] Fix autotvm customized template (#5034)

* init

* fix template

* tweak naming
parent 681df4fc
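
In short, this patch renames the public decorator for hand-written AutoTVM templates from autotvm.register_customized_task to autotvm.template, and turns the lower-level registrars (register_task_compute, register_task_schedule, register_customized_task) into private, underscore-prefixed helpers. As a minimal migration sketch, mirroring the matmul templates touched by this diff, only the decorator line changes:

    from tvm import te, autotvm

    # before this commit: @autotvm.register_customized_task("testing/matmul")
    @autotvm.template("testing/matmul")
    def matmul(N, L, M, dtype):
        A = te.placeholder((N, L), name='A', dtype=dtype)
        B = te.placeholder((L, M), name='B', dtype=dtype)

        k = te.reduce_axis((0, L), name='k')
        C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
        s = te.create_schedule(C.op)

        # two tunable split knobs over the spatial axes, as in the tutorial
        cfg = autotvm.get_config()
        y, x = s[C].op.axis
        cfg.define_split("tile_y", y, num_outputs=2)
        cfg.define_split("tile_x", x, num_outputs=2)
        yo, yi = cfg["tile_y"].apply(s, C, y)
        xo, xi = cfg["tile_x"].apply(s, C, x)
        s[C].reorder(yo, xo, k, yi, xi)
        return s, [A, B, C]
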
@@ -42,7 +42,7 @@ from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
 from .task import get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, register_customized_task, \
+    register_topi_compute, register_topi_schedule, template, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE
@@ -42,7 +42,7 @@ def get_infer_layout(task_name):
         return topi.nn.depthwise_conv2d_infer_layout
     raise ValueError("Cannot find infer layout for task %s" % task_name)
 
-@autotvm.register_customized_task("layout_transform")
+@autotvm.template("layout_transform")
 def layout_transform(*args):
     """Autotvm layout transform template."""
     cfg = get_config()
...
@@ -22,8 +22,7 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
 
-from .task import Task, create, get_config, args_to_workload, \
-    register_customized_task
+from .task import Task, create, get_config, args_to_workload, template
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
...
@@ -186,25 +186,35 @@ class Task(object):
 TASK_TABLE = {}
 
-class TopiTemplate(object):
-    """Topi template that holds the topi compute and schedule function"""
+class TaskTemplate(object):
+    """
+    A task template is used to create a tunable AutoTVM task.
+
+    It can be defined by a pair of compute and schedule functions using
+    `_register_task_compute` and `_register_task_schedule`,
+    or by a more flexible customized task-creation function using
+    `_register_customized_task`.
+
+    Note that when a customized function is registered, the compute and
+    schedule functions are ignored.
+    """
     def __init__(self):
-        self.compute = None
-        self.schedule = None
-        self.customized_func = None
+        self.fcompute = None
+        self.fschedule = None
+        self.fcustomized = None
 
     def __call__(self, *args, **kwargs):
         args = deserialize_args(args)
-        if self.customized_func is None:
+        if self.fcustomized is None:
             return self._default_func(*args, **kwargs)
-        assert callable(self.customized_func)
-        return self.customized_func(*args, **kwargs)
+        assert callable(self.fcustomized)
+        return self.fcustomized(*args, **kwargs)
 
     def _default_func(self, *args, **kwargs):
-        assert callable(self.compute) and callable(self.schedule)
-        out = self.compute(*args, **kwargs)
+        assert callable(self.fcompute) and callable(self.fschedule)
+        out = self.fcompute(*args, **kwargs)
         arg_bufs = [out] + self.get_inputs(out)
-        s = self.schedule([out])
+        s = self.fschedule([out])
         return s, arg_bufs
 
     def get_inputs(self, out):
@@ -218,7 +228,7 @@ class TaskTemplate(object):
             queue.extend(t.op.input_tensors)
         return inputs
 
-def register_task_compute(name, func=None):
+def _register_task_compute(name, func=None):
     """Register compute function to autotvm task
 
     Parameters
@@ -237,17 +247,17 @@ def _register_task_compute(name, func=None):
     """
     def _do_reg(f):
         if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
+            TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
-        if tmpl.compute is not None:
+        if tmpl.fcompute is not None:
             raise ValueError("Compute is already registered in autoTVM task %s" % name)
-        tmpl.compute = f
+        tmpl.fcompute = f
         return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def register_task_schedule(name, func=None):
+def _register_task_schedule(name, func=None):
     """Register schedule function to autotvm task
 
     Parameters
@@ -266,24 +276,19 @@ def _register_task_schedule(name, func=None):
     """
     def _do_reg(f):
         if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
+            TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
-        if tmpl.schedule is not None:
+        if tmpl.fschedule is not None:
             raise ValueError("Schedule is already registered in autoTVM task %s" % name)
-        tmpl.schedule = f
+        tmpl.fschedule = f
         return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def register_customized_task(name, func=None):
+def _register_customized_task(name, func=None):
     """Register a customized function to AutoTVM task.
 
-    In most cases, you can just use register_topi_compute and register_topi_schedule
-    with the same task name to define an AutoTVM task. However, you can also
-    create a customized AutoTVM task that defines a tunable template or performs
-    extra layout transform before invoking compute/schedule function.
-
     Parameters
     ----------
     name: str
@@ -297,6 +302,37 @@ def _register_customized_task(name, func=None):
     -------
     decorator: callable
         A decorator
+    """
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TaskTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.fcustomized is not None:
+            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
+        tmpl.fcustomized = f
+        return f
+    if func:
+        return _do_reg(func)
+    return _do_reg
+
+def template(task_name, func=None):
+    """Decorate a function as a tunable schedule template.
+
+    Parameters
+    ----------
+    task_name: str
+        The task name
+
+    func: None or callable
+        A callable template function.
+        If it is None, return a decorator.
+        If it is callable, decorate this function.
+
+    Returns
+    -------
+    func: callable
+        The decorated function
 
     Examples
     --------
@@ -304,7 +340,7 @@ def register_customized_task(name, func=None):
 
     .. code-block:: python
 
-        @autotvm.register_customized_task("matmul")
+        @autotvm.template("matmul")
         def matmul(N, L, M, dtype):
             A = te.placeholder((N, L), name='A', dtype=dtype)
             B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -331,17 +367,22 @@ def register_customized_task(name, func=None):
             return s, [A, B, C]
     """
-    def _do_reg(f):
-        if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
-        tmpl = TASK_TABLE[name]
-        if tmpl.customized_func is not None:
-            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
-        tmpl.customized_func = f
-        return f
+    def _decorate(f):
+        def wrapper(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            workload = args_to_workload(args, task_name)
+            tgt = _target.Target.current()
+            cfg = DispatchContext.current.query(tgt, workload)
+            with ApplyConfig(cfg):
+                return f(*args, **kwargs)
+        _register_customized_task(task_name, f)
+        return wrapper
+
     if func:
-        return _do_reg(func)
-    return _do_reg
+        return _decorate(func)
+    return _decorate
 
 def create(task_name, args, target, target_host=None):
     """Create a tuning task and initialize its search space
...
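
Note what the new template decorator does beyond the rename: the old register_customized_task returned the function unchanged, while template returns a wrapper that computes the workload from the call arguments, queries DispatchContext.current for a matching config (a tuned record under apply_history_best, or a fallback config otherwise), and applies it around the call. So templates can be invoked directly outside a tuning session. A minimal sketch, assuming the matmul template from the sketch above is registered and using this-era APIs tvm.target.create and tvm.lower:

    import tvm

    # calling the decorated template directly now picks up a (fallback) config
    with tvm.target.create("llvm"):
        s, arg_bufs = matmul(8, 8, 8, "float32")
        print(tvm.lower(s, arg_bufs, simple_mode=True))
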
@@ -30,8 +30,8 @@ import tvm.te._ffi_api
 from tvm import target as _target
 from tvm.te import tensor
 
-from .task import args_to_workload, DispatchContext, \
-    register_task_compute, register_task_schedule, serialize_args
+from .task import args_to_workload, serialize_args, DispatchContext, \
+    _register_task_compute, _register_task_schedule
 
 # Task extractor for relay program
@@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
     def _decorate(topi_compute):
-        @register_task_compute(task_name)
+        @_register_task_compute(task_name)
         def wrapper(*args, **kwargs):
             """wrapper function for topi compute"""
             assert not kwargs, "Do not support kwargs in template function call"
@@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
     def _decorate(topi_schedule):
-        @register_task_schedule(task_name)
+        @_register_task_schedule(task_name)
         def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs)
...
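
The TOPI-facing decorators register_topi_compute and register_topi_schedule keep their public names; they now simply delegate to the private _register_task_compute and _register_task_schedule. A hedged sketch of that path (the task name "example_double.x86" and both function bodies are hypothetical, not part of this diff); such pairs are normally invoked through op strategies and task extraction rather than called directly:

    from tvm import te, autotvm

    @autotvm.register_topi_compute("example_double.x86")
    def example_double(cfg, data):
        # cfg is injected by the wrapper; tunable knobs could be defined here
        return te.compute(data.shape, lambda *i: data(*i) * 2, name="double")

    @autotvm.register_topi_schedule("example_double.x86")
    def schedule_example_double(cfg, outs):
        # accept a single tensor or a list, as TOPI schedules conventionally do
        outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
        return te.create_schedule([x.op for x in outs])
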
@@ -26,7 +26,7 @@ from tvm import te
 from tvm import autotvm
 from tvm.autotvm.tuner import RandomTuner
 
-@autotvm.register_customized_task("testing/conv2d_no_batching")
+@autotvm.template("testing/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     """An example template for testing"""
     assert N == 1, "Only consider batch_size = 1 in this template"
...
@@ -37,7 +37,7 @@ class DummyRunner(Runner):
     def get_build_kwargs(self):
         return {}
 
-@autotvm.register_customized_task("testing/matmul")
+@autotvm.template("testing/matmul")
 def matmul(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):
     return s, [A, B, C]
 
-@autotvm.register_customized_task("testing/bad_matmul")
+@autotvm.template("testing/bad_matmul")
 def bad_matmul(N, L, M, dtype):
     if 'bad_device' in tvm.target.Target.current().keys:
         A = te.placeholder((N, L), name='A', dtype=dtype)
...
@@ -22,7 +22,7 @@ from tvm import autotvm
 def test_fallback():
 
-    @autotvm.register_customized_task("testing/dispatch/fallback")
+    @autotvm.template("testing/dispatch_fallback")
     def simple_template(a, b):
         cfg = autotvm.get_config()
         assert cfg.is_fallback
...
@@ -79,7 +79,7 @@ from tvm import autotvm
 # can be very large (at the level of 10^9 for some input shapes)
 #
 
-@autotvm.register_customized_task("tutorial/conv2d_no_batching")
+@autotvm.template("tutorial/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     assert N == 1, "Only consider batch_size = 1 in this template"
...
@@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
 # In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
 
 # Matmul V1: List candidate values
-@autotvm.register_customized_task("tutorial/matmul_v1")  # 1. use a decorator
+@autotvm.template("tutorial/matmul_v1")  # 1. use a decorator
 def matmul_v1(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
 # When the high level API cannot meet your requirement, you can always fall
 # back to use low level API.
 
-@autotvm.register_customized_task("tutorial/matmul")
+@autotvm.template("tutorial/matmul")
 def matmul(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
...
@@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
 #
 # We use AutoTVM to search for best configurations in this schedule.
 
-@autotvm.register_customized_task("tutorial/test_gemm")
+@autotvm.template("tutorial/auto_tensorcore/test_gemm")
 def test_gemm(N, L, M, dtype, layout):
     if (layout == "NN"):
         shape_a = (N, L)
@@ -265,7 +265,7 @@ elif dtype == 'int4' or dtype == 'int1':
     assert(major == 7 and minor == 5 and layout == 'TN')
 
 def tune_and_evaluate(M, N, L, dtype, layout):
-    task = autotvm.task.create("tutorial/test_gemm", args=(N, L, M, dtype, layout),
+    task = autotvm.task.create("tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout),
                                target='cuda')
     print(task.config_space)
...
@@ -310,7 +310,7 @@ def register_vta_tuning_tasks():
     # init autotvm env to register VTA operator
     TaskExtractEnv()
 
-    @autotvm.register_customized_task("conv2d_packed.vta")
+    @autotvm.template("conv2d_packed.vta")
     def _topi_nn_conv2d(*args, **kwargs):
         assert not kwargs, "Do not support kwargs in template function call"
         A, W = args[:2]
...
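
End to end, task creation and tuning are unchanged apart from the new names. A minimal sketch following the tune_simple_template tutorial touched above, assuming the testing/matmul template from the first sketch is registered in the current process (the log filename is arbitrary):

    from tvm import autotvm
    from tvm.autotvm.tuner import RandomTuner

    # create a tuning task from a template registered with @autotvm.template
    task = autotvm.task.create("testing/matmul", args=(8, 8, 8, "float32"), target="llvm")
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=4))

    # try a few random configs and log the results
    tuner = RandomTuner(task)
    tuner.tune(n_trial=4,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file("matmul.log")])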