Commit a0062582 by eqy Committed by Tianqi Chen

[RELAY][AUTOTVM] Extract tuning tasks from Relay programs (#2181)

parent 3cf910c8
......@@ -14,3 +14,4 @@ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBe
from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
from .relay_integration import extract_from_program, extract_from_multiple_program
......@@ -7,208 +7,13 @@ import warnings
import logging
from ... import tensor, placeholder, create_schedule, target as _target
from ... import target as _target
from ..util import get_const_tuple
from .task import create, register
from .task import create
from .topi_integration import TaskExtractEnv
logger = logging.getLogger('autotvm')
def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tensor.Tensor):
ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
else:
ret.append(t)
return tuple(ret)
def deserialize_args(args):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tuple) and t[0] == 'TENSOR':
ret.append(placeholder(shape=t[1], dtype=t[2]))
else:
ret.append(t)
return ret
# Task extractor for nnvm graph
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
def __init__(self):
import topi
import nnvm
# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
self.symbol2topi = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}
# topi compute -> autotvm task name
self.topi_to_task = {
topi.nn.conv2d: "topi_nn_conv2d",
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
}
self.topi_to_schedule = {
topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
topi.generic.schedule_conv2d_nhwc],
topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
topi.generic.schedule_depthwise_conv2d_nhwc],
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
}
self._register_tracing()
self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())
def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""
@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
"Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
_local_scope(topi_compute)
# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""
@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)
def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi
# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw")
def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_depthwise_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_group_conv2d_nchw")
def _topi_nn_group_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.group_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_group_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_conv2d_transpose_nchw")
def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
s = topi.generic.schedule_conv2d_transpose_nchw([C])
return s, [A, W, C]
@register("topi_nn_dense")
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias = args
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
return s, [data, weight, bias, C]
return s, [data, weight, C]
def reset(self, wanted_topi_funcs):
"""Reset task collections
Parameters
----------
wanted_topi_funcs: List of function
The topi function to be extracted
"""
self.task_collection = []
self.wanted_topi_funcs = wanted_topi_funcs
def get_tasks(self):
"""Get collected tasks
Returns
-------
tasks: List of tuple(name, args)
A list of tasks extracted from the nnvm graph
"""
return self.task_collection
@staticmethod
def get():
"""Get the single instance of TaskExtractEnv
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
return TaskExtractEnv.current
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
""" Extract tuning tasks from a nnvm graph.
......@@ -237,13 +42,24 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
collected tasks
"""
import nnvm.compiler
import nnvm
import topi
env = TaskExtractEnv.get()
#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
SYMBOL2TOPI = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}
topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
if sym_name in SYMBOL2TOPI:
topi_funcs.extend(SYMBOL2TOPI[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
......@@ -297,13 +113,24 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
collected tasks
"""
import nnvm.compiler
import nnvm
import topi
env = TaskExtractEnv.get()
#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
SYMBOL2TOPI = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}
topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
if sym_name in SYMBOL2TOPI:
topi_funcs.extend(SYMBOL2TOPI[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
......
# pylint: disable=unused-variable,invalid-name
"""
Decorator and utilities for the integration with TOPI and Relay
99.9% copy-paste of implementation by @MerryMercy
"""
import threading
import warnings
import logging
from ... import tensor, placeholder, target as _target
from .task import create
from .topi_integration import TaskExtractEnv
logger = logging.getLogger('autotvm')
def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tensor.Tensor):
ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
else:
ret.append(t)
return tuple(ret)
def deserialize_args(args):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tuple) and t[0] == 'TENSOR':
ret.append(placeholder(shape=t[1], dtype=t[2]))
else:
ret.append(t)
return ret
def extract_from_program(func, params, ops, target, target_host=None):
""" Extract tuning tasks from a relay program.
This function collects tuning tasks by building the program
with a "tracing" target and tracing all the calls to topi.
Parameters
----------
func: relay.expr.Function
The func to tune
params: dict of str to numpy array
The associated parameters of the program
ops: List of relay op
List of relay ops to be tuned
dtype: str or dict of str to str
The input types to the program
target: tvm.target.Target
The compilation target
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
env = TaskExtractEnv.get()
import tvm.relay.op
from tvm import relay
import topi
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
}
topi_funcs = []
for op_name in ops:
if op_name in OP2TOPI:
topi_funcs.extend(OP2TOPI[op_name])
else:
warnings.warn("Op %s is not tunable, ignored" % op_name)
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=relay.build, args=(func,
tracing_target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
return tasks
def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
""" Extract tuning tasks from multiple relay programs.
This function is the multiple program version of extract_from_program
Parameters
----------
funcs: List of relay.expr.Function
The list of functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
ops: List of relay op
List of relay ops to be tuned
target: tvm.target.Target
The compilation target
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
env = TaskExtractEnv.get()
import tvm.relay.op
from tvm import relay
import topi
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
}
topi_funcs = []
for op_name in ops:
if op_name in OP2TOPI:
topi_funcs.extend(OP2TOPI[op_name])
else:
warnings.warn("Op %s is not tunable, ignored" % op_name)
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
for func, param in zip(funcs, params):
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=relay.build, args=(func,
tracing_target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
return tasks
......@@ -11,16 +11,202 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
from ... import _api_internal, tensor
from .task import args_to_workload, dispatcher
from ... import _api_internal, tensor, placeholder, create_schedule
from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple
# A table that records all registered dispatcher for all targets
_REGISTED_DISPATHCER = {
}
def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tensor.Tensor):
ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
else:
ret.append(t)
return tuple(ret)
def deserialize_args(args):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tuple) and t[0] == 'TENSOR':
ret.append(placeholder(shape=t[1], dtype=t[2]))
else:
ret.append(t)
return ret
# Task extractor for nnvm graph, relay program
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
def __init__(self):
import topi
# topi compute -> autotvm task name
self.topi_to_task = {
topi.nn.conv2d: "topi_nn_conv2d",
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
}
self.topi_to_schedule = {
topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
topi.generic.schedule_conv2d_nhwc],
topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
topi.generic.schedule_depthwise_conv2d_nhwc],
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
}
self._register_tracing()
self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())
def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""
@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
"Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
_local_scope(topi_compute)
# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""
@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)
def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi
# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw")
def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_depthwise_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_group_conv2d_nchw")
def _topi_nn_group_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.group_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_group_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_conv2d_transpose_nchw")
def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
s = topi.generic.schedule_conv2d_transpose_nchw([C])
return s, [A, W, C]
@register("topi_nn_dense")
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias = args
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
return s, [data, weight, bias, C]
return s, [data, weight, C]
def reset(self, wanted_topi_funcs):
"""Reset task collections
Parameters
----------
wanted_topi_funcs: List of function
The topi function to be extracted
"""
self.task_collection = []
self.wanted_topi_funcs = wanted_topi_funcs
def get_tasks(self):
"""Get collected tasks
Returns
-------
tasks: List of tuple(name, args)
A list of tasks extracted from the nnvm graph
"""
return self.task_collection
@staticmethod
def get():
"""Get the single instance of TaskExtractEnv
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
return TaskExtractEnv.current
def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
"""Register a tunable template for a topi compute function.
......
"""Test task extraction for autotvm"""
import tvm.relay.testing
from tvm import relay
from tvm import autotvm
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
if name == 'resnet-18':
net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'dcgan':
net, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
input_shape = (batch_size, 100)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape
def test_task_extraction():
target = 'llvm'
net, params, input_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 12
net, params, input_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params,
ops=(relay.op.nn.dense,))
assert len(tasks) == 1
net, params, input_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 13
net, params, input_shape = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 20
net, params, input_shape = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params,
ops=(relay.op.nn.conv2d_transpose,))
assert len(tasks) == 4
if __name__ == '__main__':
test_task_extraction()
......@@ -2,7 +2,7 @@
"""Conv2D schedule on x86"""
import tvm
from tvm import autotvm
from tvm.autotvm.task.nnvm_integration import deserialize_args
from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
......
......@@ -4,7 +4,7 @@ import tvm
from tvm import autotvm
from tvm.autotvm.task import get_config
from tvm.autotvm.task.space import SplitEntity
from tvm.autotvm.task.nnvm_integration import deserialize_args
from tvm.autotvm.task.topi_integration import deserialize_args
from .. import generic, tag
from ..nn.pad import pad
from ..util import get_const_tuple
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment