Unverified commit 70e11d32 by Haichen Shen, committed by GitHub

[Autotvm] Fix autotvm customized template (#5034)

* init

* fix template

* tweak naming
parent 681df4fc
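
This commit renames the public AutoTVM decorator `register_customized_task` to `template` and makes the underlying registration helpers (`_register_task_compute`, `_register_task_schedule`, `_register_customized_task`) private. A minimal sketch of the renamed entry point (the task name and knob below are illustrative, not part of this commit):

.. code-block:: python

    from tvm import te, autotvm

    @autotvm.template("example/vector_add")  # formerly @autotvm.register_customized_task
    def vector_add(N):
        A = te.placeholder((N,), name='A')
        B = te.placeholder((N,), name='B')
        C = te.compute((N,), lambda i: A[i] + B[i], name='C')
        s = te.create_schedule(C.op)

        # define a tunable knob and apply it to the schedule
        cfg = autotvm.get_config()
        cfg.define_knob("factor", [8, 16, 32])
        xo, xi = s[C].split(C.op.axis[0], factor=cfg["factor"].val)
        return s, [A, B, C]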
@@ -42,7 +42,7 @@ from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
 from .task import get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, register_customized_task, \
+    register_topi_compute, register_topi_schedule, template, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE
@@ -42,7 +42,7 @@ def get_infer_layout(task_name):
         return topi.nn.depthwise_conv2d_infer_layout
     raise ValueError("Cannot find infer layout for task %s" % task_name)
 
-@autotvm.register_customized_task("layout_transform")
+@autotvm.template("layout_transform")
 def layout_transform(*args):
     """Autotvm layout transform template."""
     cfg = get_config()
......
@@ -22,8 +22,7 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
-from .task import Task, create, get_config, args_to_workload, \
-    register_customized_task
+from .task import Task, create, get_config, args_to_workload, template
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
......
@@ -186,25 +186,35 @@ class Task(object):
 TASK_TABLE = {}
 
-class TopiTemplate(object):
-    """Topi template that holds the topi compute and schedule function"""
+class TaskTemplate(object):
+    """
+    A task template is used to create a tunable AutoTVM task.
+
+    It can be defined by a pair of compute and schedule functions using
+    `_register_task_compute` and `_register_task_schedule`,
+    or by a customized task creation function that is more flexible using
+    `_register_customized_task`.
+
+    Note that when a customized function is registered, the compute and
+    schedule functions are ignored.
+    """
     def __init__(self):
-        self.compute = None
-        self.schedule = None
-        self.customized_func = None
+        self.fcompute = None
+        self.fschedule = None
+        self.fcustomized = None
 
     def __call__(self, *args, **kwargs):
         args = deserialize_args(args)
-        if self.customized_func is None:
+        if self.fcustomized is None:
             return self._default_func(*args, **kwargs)
-        assert callable(self.customized_func)
-        return self.customized_func(*args, **kwargs)
+        assert callable(self.fcustomized)
+        return self.fcustomized(*args, **kwargs)
 
     def _default_func(self, *args, **kwargs):
-        assert callable(self.compute) and callable(self.schedule)
-        out = self.compute(*args, **kwargs)
+        assert callable(self.fcompute) and callable(self.fschedule)
+        out = self.fcompute(*args, **kwargs)
         arg_bufs = [out] + self.get_inputs(out)
-        s = self.schedule([out])
+        s = self.fschedule([out])
         return s, arg_bufs
 
     def get_inputs(self, out):
@@ -218,7 +228,7 @@ class TopiTemplate(object):
             queue.extend(t.op.input_tensors)
         return inputs
 
-def register_task_compute(name, func=None):
+def _register_task_compute(name, func=None):
     """Register compute function to autotvm task
 
     Parameters
@@ -237,17 +247,17 @@ def register_task_compute(name, func=None):
     """
     def _do_reg(f):
         if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
+            TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
-        if tmpl.compute is not None:
+        if tmpl.fcompute is not None:
             raise ValueError("Compute is already registered in autoTVM task %s" % name)
-        tmpl.compute = f
+        tmpl.fcompute = f
         return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def register_task_schedule(name, func=None):
+def _register_task_schedule(name, func=None):
     """Register schedule function to autotvm task
 
     Parameters
@@ -266,24 +276,19 @@ def register_task_schedule(name, func=None):
     """
     def _do_reg(f):
         if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
+            TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
-        if tmpl.schedule is not None:
+        if tmpl.fschedule is not None:
             raise ValueError("Schedule is already registered in autoTVM task %s" % name)
-        tmpl.schedule = f
+        tmpl.fschedule = f
         return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def register_customized_task(name, func=None):
+def _register_customized_task(name, func=None):
     """Register a customized function to AutoTVM task.
-
-    In most cases, you can just use register_topi_compute and register_topi_schedule
-    with the same task name to define an AutoTVM task. However, you can also
-    create a customized AutoTVM task that defines a tunable template or performs
-    extra layout transform before invoking compute/schedule function.
 
     Parameters
     ----------
     name: str
@@ -297,6 +302,37 @@ def register_customized_task(name, func=None):
     -------
     decorator: callable
         A decorator
+    """
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TaskTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.fcustomized is not None:
+            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
+        tmpl.fcustomized = f
+        return f
+    if func:
+        return _do_reg(func)
+    return _do_reg
+
+
+def template(task_name, func=None):
+    """Decorate a function as a tunable schedule template.
+
+    Parameters
+    ----------
+    task_name: str
+        The task name
+
+    func: None or callable
+        A callable template function.
+        If it is None, return a decorator.
+        If it is callable, decorate this function.
+
+    Returns
+    -------
+    func: callable
+        The decorated function
+
     Examples
     --------
@@ -304,7 +340,7 @@ def register_customized_task(name, func=None):
     .. code-block:: python
 
-        @autotvm.register_customized_task("matmul")
+        @autotvm.template("matmul")
         def matmul(N, L, M, dtype):
             A = te.placeholder((N, L), name='A', dtype=dtype)
             B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -331,17 +367,22 @@ def register_customized_task(name, func=None):
             return s, [A, B, C]
     """
-    def _do_reg(f):
-        if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
-        tmpl = TASK_TABLE[name]
-        if tmpl.customized_func is not None:
-            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
-        tmpl.customized_func = f
-        return f
+    def _decorate(f):
+        def wrapper(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            workload = args_to_workload(args, task_name)
+            tgt = _target.Target.current()
+            cfg = DispatchContext.current.query(tgt, workload)
+            with ApplyConfig(cfg):
+                return f(*args, **kwargs)
+        _register_customized_task(task_name, f)
+        return wrapper
     if func:
-        return _do_reg(func)
-    return _do_reg
+        return _decorate(func)
+    return _decorate
 
 def create(task_name, args, target, target_host=None):
     """Create a tuning task and initialize its search space
......
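
With the change above, a function decorated with `template` is registered under its task name, and every call to the returned wrapper queries the current `DispatchContext` for a config and runs the body under `ApplyConfig`. A rough usage sketch against the `testing/matmul` template registered later in this diff:

.. code-block:: python

    from tvm import autotvm

    # Instantiate a tuning task from a registered template; create() runs
    # the template once to populate the search space.
    task = autotvm.task.create("testing/matmul", args=(64, 64, 64, 'float32'),
                               target='llvm')
    print(task.config_space)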
@@ -30,8 +30,8 @@ import tvm.te._ffi_api
 from tvm import target as _target
 from tvm.te import tensor
-from .task import args_to_workload, DispatchContext, \
-    register_task_compute, register_task_schedule, serialize_args
+from .task import args_to_workload, serialize_args, DispatchContext, \
+    _register_task_compute, _register_task_schedule
 
 # Task extractor for relay program
@@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
     def _decorate(topi_compute):
-        @register_task_compute(task_name)
+        @_register_task_compute(task_name)
         def wrapper(*args, **kwargs):
             """wrapper function for topi compute"""
             assert not kwargs, "Do not support kwargs in template function call"
@@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
     def _decorate(topi_schedule):
-        @register_task_schedule(task_name)
+        @_register_task_schedule(task_name)
         def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs)
......
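
For TOPI operators the public decorators are still `register_topi_compute` and `register_topi_schedule`; only the helpers they delegate to gained the underscore prefix. A hedged sketch of the paired registration pattern (the op name and bodies are hypothetical; see the depthwise_conv2d file cited in the docstrings for real usage):

.. code-block:: python

    from tvm import te, autotvm

    @autotvm.register_topi_compute("add_one.example")
    def add_one_compute(cfg, data):
        # cfg is the per-workload config injected by the wrapper
        return te.compute(data.shape, lambda *i: data(*i) + 1, name='add_one')

    @autotvm.register_topi_schedule("add_one.example")
    def add_one_schedule(cfg, outs):
        # outs are the output tensors produced by the registered compute
        return te.create_schedule([x.op for x in outs])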
@@ -26,7 +26,7 @@ from tvm import te
 from tvm import autotvm
 from tvm.autotvm.tuner import RandomTuner
 
-@autotvm.register_customized_task("testing/conv2d_no_batching")
+@autotvm.template("testing/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     """An example template for testing"""
     assert N == 1, "Only consider batch_size = 1 in this template"
......
@@ -37,7 +37,7 @@ class DummyRunner(Runner):
     def get_build_kwargs(self):
         return {}
 
-@autotvm.register_customized_task("testing/matmul")
+@autotvm.template("testing/matmul")
 def matmul(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):
     return s, [A, B, C]
 
-@autotvm.register_customized_task("testing/bad_matmul")
+@autotvm.template("testing/bad_matmul")
 def bad_matmul(N, L, M, dtype):
     if 'bad_device' in tvm.target.Target.current().keys:
         A = te.placeholder((N, L), name='A', dtype=dtype)
......
@@ -22,7 +22,7 @@ from tvm import autotvm
 
 def test_fallback():
-    @autotvm.register_customized_task("testing/dispatch/fallback")
+    @autotvm.template("testing/dispatch_fallback")
     def simple_template(a, b):
         cfg = autotvm.get_config()
         assert cfg.is_fallback
......
@@ -79,7 +79,7 @@ from tvm import autotvm
 # can be very large (at the level of 10^9 for some input shapes)
 #
 
-@autotvm.register_customized_task("tutorial/conv2d_no_batching")
+@autotvm.template("tutorial/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     assert N == 1, "Only consider batch_size = 1 in this template"
......
@@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
 # In autotvm, we can define a tunable parameter, or a "knob", for such a value.
 
 # Matmul V1: List candidate values
-@autotvm.register_customized_task("tutorial/matmul_v1")  # 1. use a decorator
+@autotvm.template("tutorial/matmul_v1")  # 1. use a decorator
 def matmul_v1(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
 # When the high-level API cannot meet your requirements, you can always fall
 # back to the low-level API.
 
-@autotvm.register_customized_task("tutorial/matmul")
+@autotvm.template("tutorial/matmul")
 def matmul(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
......
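
The tutorial templates above are tuned the same way as before; only the decorator name changes. A minimal end-to-end sketch for the `tutorial/matmul` template (trial count and log file name are arbitrary):

.. code-block:: python

    from tvm import autotvm

    task = autotvm.task.create("tutorial/matmul",
                               args=(512, 512, 512, 'float32'), target='llvm')

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=5))

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=10,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('matmul.log')])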
@@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
 #
 # We use AutoTVM to search for the best configurations in this schedule.
 
-@autotvm.register_customized_task("tutorial/test_gemm")
+@autotvm.template("tutorial/auto_tensorcore/test_gemm")
 def test_gemm(N, L, M, dtype, layout):
     if (layout == "NN"):
         shape_a = (N, L)
@@ -265,7 +265,7 @@ elif dtype == 'int4' or dtype == 'int1':
     assert(major == 7 and minor == 5 and layout == 'TN')
 
 def tune_and_evaluate(M, N, L, dtype, layout):
-    task = autotvm.task.create("tutorial/test_gemm", args=(N, L, M, dtype, layout),
+    task = autotvm.task.create("tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout),
                                target='cuda')
     print(task.config_space)
......
@@ -310,7 +310,7 @@ def register_vta_tuning_tasks():
     # init autotvm env to register VTA operator
     TaskExtractEnv()
 
-    @autotvm.register_customized_task("conv2d_packed.vta")
+    @autotvm.template("conv2d_packed.vta")
     def _topi_nn_conv2d(*args, **kwargs):
        assert not kwargs, "Do not support kwargs in template function call"
        A, W = args[:2]
......