Commit ccde31f1 by 黎明灰烬 Committed by Wuwei Lin

AutoTVM: selecting tuning templates when extracting task (#4338)

* AutoTVM: selecting tuning templates when extracting task

Make the procedure of trying new templates easier.

Test: tests/python/relay/test_autotvm_task_extraction.py

* Use dict to match key for topi ops

* fix lint issue

* be more pythonic :)
parent 0a9f7e9a
......@@ -54,7 +54,8 @@ def _lower(func,
return grc.codegen(mod["main"])
def extract_from_program(func, params, ops, target, target_host=None):
def extract_from_program(func, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from a relay program.
This function is the single program version of extract_from_multiple_program.
......@@ -71,16 +72,21 @@ def extract_from_program(func, params, ops, target, target_host=None):
The compilation target
target_host: tvm.target.Target
The host compilation target
template_keys: dict of topi op to str
The tuning template keys map for schedules, default to None.
Example: {topi.nn.conv2d: 'direct'}
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
return extract_from_multiple_program([func], [params], ops, target, target_host)
return extract_from_multiple_program([func], [params], ops, target, target_host,
template_keys=template_keys)
def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from multiple relay programs.
This function collects tuning tasks by building a list of programs
......@@ -98,6 +104,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
The compilation target
target_host: tvm.target.Target
The host compilation target
template_keys: dict of topi op to str
The tuning template keys map for schedules, default to None.
Example: {topi.nn.conv2d: 'direct'}
Returns
-------
......@@ -146,15 +155,26 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
logger.disabled = old_state
# convert *topi op to template key* map to *task name to template key* map
task_name_to_keys = {}
if template_keys is not None:
for op in template_keys.keys():
if op in env.topi_to_task:
task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
else:
logger.warning("Invalid template key, fallback to direct")
task_name_to_keys[env.topi_to_task[op]] = 'direct'
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
try:
key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
template_key=key)
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
logger.warning("Invalid shape during AutoTVM task creation")
return tasks
......@@ -79,5 +79,51 @@ def test_task_extraction():
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 31
def test_template_key_provided():
"""test task extraction using non-'direct' template_key"""
target = 'llvm'
import topi
template_keys = {
# topi.nn.conv2d - is left blank to test fallback logic
topi.nn.dense: 'direct_nopack',
topi.nn.depthwise_conv2d_nchw: 'direct',
}
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
template_keys=template_keys)
for task in tasks:
if 'dense' in task.name:
assert task.config_space.template_key == 'direct_nopack'
else:
assert task.config_space.template_key == 'direct'
def test_template_key_empty():
"""test task extraction using empty template_key"""
target = 'llvm'
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
template_keys=None)
for task in tasks:
assert task.config_space.template_key == 'direct'
def test_template_key_default():
"""test task extraction without template_key"""
target = 'llvm'
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
for task in tasks:
assert task.config_space.template_key == 'direct'
if __name__ == '__main__':
test_task_extraction()
test_template_key_provided()
test_template_key_empty()
test_template_key_default()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment