Commit a21904a5 by Cody Hao Yu Committed by eqy

merge extract_from_program and extract_from_multiple_progam (#4173)

parent c3f02c4b
......@@ -52,8 +52,7 @@ def _build(func,
def extract_from_program(func, params, ops, target, target_host=None):
""" Extract tuning tasks from a relay program.
This function collects tuning tasks by building the program
with a "tracing" target and tracing all the calls to topi.
This function is the single program version of extract_from_multiple_program.
Parameters
----------
......@@ -73,66 +72,14 @@ def extract_from_program(func, params, ops, target, target_host=None):
task: Array of autotvm.task.Task
collected tasks
"""
import tvm.relay.op
from tvm import relay
import topi
env = TaskExtractEnv.get()
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}
topi_funcs = []
for op_name in ops:
if op_name in OP2TOPI:
topi_funcs.extend(OP2TOPI[op_name])
else:
warnings.warn("Op %s is not tunable, ignored" % op_name)
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
build_thread = threading.Thread(target=_build,
args=(mod,
target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
warnings.warn("Invalid shape during AutoTVM task creation")
return tasks
return extract_from_multiple_program([func], [params], ops, target, target_host)
def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
""" Extract tuning tasks from multiple relay programs.
This function is the multiple program version of extract_from_program
This function collects tuning tasks by building a list of programs
with a "tracing" target and tracing all the calls to topi.
Parameters
----------
......@@ -152,19 +99,20 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
task: Array of autotvm.task.Task
collected tasks
"""
env = TaskExtractEnv.get()
import tvm.relay.op
from tvm import relay
import topi
env = TaskExtractEnv.get()
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.contrib_deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}
topi_funcs = []
......@@ -185,11 +133,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
build_thread = threading.Thread(target=my_build,
args=(mod,
target,
target_host,
params))
build_thread = threading.Thread(target=_build,
args=(mod, target, target_host, param))
build_thread.start()
build_thread.join()
......
......@@ -37,36 +37,47 @@ def get_network(name, batch_size):
def test_task_extraction():
target = 'llvm'
mod_list = []
params_list = []
mod, params, input_shape = get_network('resnet-18', batch_size=1)
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 12
mod, params, input_shape = get_network('resnet-18', batch_size=1)
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.dense,))
assert len(tasks) == 1
mod, params, input_shape = get_network('resnet-18', batch_size=1)
mod, params, _ = get_network('resnet-18', batch_size=1)
mod_list.append(mod)
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 13
mod, params, input_shape = get_network('mobilenet', batch_size=1)
mod, params, _ = get_network('mobilenet', batch_size=1)
mod_list.append(mod)
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 20
mod, params, input_shape = get_network('dcgan', batch_size=1)
mod, params, _ = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d_transpose,))
assert len(tasks) == 4
tasks = autotvm.task.extract_from_multiple_program([m['main'] for m in mod_list], params_list,
target=target,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 31
if __name__ == '__main__':
test_task_extraction()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment