Commit ccde31f1 by 黎明灰烬 Committed by Wuwei Lin

AutoTVM: selecting tuning templates when extracting task (#4338)

* AutoTVM: selecting tuning templates when extracting task

Make the procedure of trying new templates easier.

Test: tests/python/relay/test_autotvm_task_extraction.py

* Use dict to match key for topi ops

* fix lint issue

* be more pythonic :)
parent 0a9f7e9a
...@@ -54,7 +54,8 @@ def _lower(func, ...@@ -54,7 +54,8 @@ def _lower(func,
return grc.codegen(mod["main"]) return grc.codegen(mod["main"])
def extract_from_program(func, params, ops, target, target_host=None): def extract_from_program(func, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from a relay program. """ Extract tuning tasks from a relay program.
This function is the single program version of extract_from_multiple_program. This function is the single program version of extract_from_multiple_program.
...@@ -71,16 +72,21 @@ def extract_from_program(func, params, ops, target, target_host=None): ...@@ -71,16 +72,21 @@ def extract_from_program(func, params, ops, target, target_host=None):
The compilation target The compilation target
target_host: tvm.target.Target target_host: tvm.target.Target
The host compilation target The host compilation target
template_keys: dict of topi op to str
The tuning template keys map for schedules, default to None.
Example: {topi.nn.conv2d: 'direct'}
Returns Returns
------- -------
task: Array of autotvm.task.Task task: Array of autotvm.task.Task
collected tasks collected tasks
""" """
return extract_from_multiple_program([func], [params], ops, target, target_host) return extract_from_multiple_program([func], [params], ops, target, target_host,
template_keys=template_keys)
def extract_from_multiple_program(funcs, params, ops, target, target_host=None): def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from multiple relay programs. """ Extract tuning tasks from multiple relay programs.
This function collects tuning tasks by building a list of programs This function collects tuning tasks by building a list of programs
...@@ -98,6 +104,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None): ...@@ -98,6 +104,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
The compilation target The compilation target
target_host: tvm.target.Target target_host: tvm.target.Target
The host compilation target The host compilation target
template_keys: dict of topi op to str
The tuning template keys map for schedules, default to None.
Example: {topi.nn.conv2d: 'direct'}
Returns Returns
------- -------
...@@ -146,15 +155,26 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None): ...@@ -146,15 +155,26 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
logger.disabled = old_state logger.disabled = old_state
# convert *topi op to template key* map to *task name to template key* map
task_name_to_keys = {}
if template_keys is not None:
for op in template_keys.keys():
if op in env.topi_to_task:
task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
else:
logger.warning("Invalid template key, fallback to direct")
task_name_to_keys[env.topi_to_task[op]] = 'direct'
# create tasks for target # create tasks for target
tasks = [] tasks = []
for task_name, args in env.get_tasks(): for task_name, args in env.get_tasks():
try: try:
key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
tsk = create(task_name, args, tsk = create(task_name, args,
target=target, target_host=target_host, target=target, target_host=target_host,
template_key='direct') template_key=key)
tasks.append(tsk) tasks.append(tsk)
except topi.InvalidShapeError: except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation") logger.warning("Invalid shape during AutoTVM task creation")
return tasks return tasks
...@@ -79,5 +79,51 @@ def test_task_extraction(): ...@@ -79,5 +79,51 @@ def test_task_extraction():
ops=(relay.op.nn.conv2d,)) ops=(relay.op.nn.conv2d,))
assert len(tasks) == 31 assert len(tasks) == 31
def test_template_key_provided():
"""test task extraction using non-'direct' template_key"""
target = 'llvm'
import topi
template_keys = {
# topi.nn.conv2d - is left blank to test fallback logic
topi.nn.dense: 'direct_nopack',
topi.nn.depthwise_conv2d_nchw: 'direct',
}
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
template_keys=template_keys)
for task in tasks:
if 'dense' in task.name:
assert task.config_space.template_key == 'direct_nopack'
else:
assert task.config_space.template_key == 'direct'
def test_template_key_empty():
"""test task extraction using empty template_key"""
target = 'llvm'
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
template_keys=None)
for task in tasks:
assert task.config_space.template_key == 'direct'
def test_template_key_default():
"""test task extraction without template_key"""
target = 'llvm'
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
for task in tasks:
assert task.config_space.template_key == 'direct'
if __name__ == '__main__': if __name__ == '__main__':
test_task_extraction() test_task_extraction()
test_template_key_provided()
test_template_key_empty()
test_template_key_default()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment