Commit b7beb1eb by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Allow fallback for template & Fix bugs in tuners (#1615)

* support fallback & fix bugs in tuners & clean topi test

* update task extraction

* update task extraction

* fix arm tutorial

* Update tune_nnvm_arm.py
parent 729224b1
...@@ -239,8 +239,9 @@ def build(graph, target=None, shape=None, dtype="float32", ...@@ -239,8 +239,9 @@ def build(graph, target=None, shape=None, dtype="float32",
raise ValueError("Target is not set in env or passed as argument.") raise ValueError("Target is not set in env or passed as argument.")
target = tvm.target.create(target) target = tvm.target.create(target)
# if not inside an autotvm config dispatch context, load pre-tuned parameters from TopHub # If current dispatch context is fallback context (the default root context),
if autotvm.task.DispatchContext.current is None: # then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(target) tophub_context = autotvm.tophub.context(target)
else: else:
tophub_context = autotvm.util.EmptyContext() tophub_context = autotvm.util.EmptyContext()
......
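For context, a minimal sketch of how this check behaves from the user side; the log file name below is hypothetical:

import nnvm.compiler
import nnvm.testing
from tvm import autotvm

net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)
shape = {'data': (1, 3, 224, 224)}

# No dispatch context is entered, so DispatchContext.current is the root
# FallbackContext and build() loads pre-tuned parameters from TopHub.
graph, lib, params = nnvm.compiler.build(net, target='llvm', shape=shape, params=params)

# An explicit context is entered, so the TopHub lookup is skipped and the
# records from the (hypothetical) log file are used instead.
with autotvm.apply_history_best('tuned.log'):
    graph, lib, params = nnvm.compiler.build(net, target='llvm', shape=shape, params=params)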
"""Test task extraction for autotvm"""
import nnvm.testing
import nnvm.compiler
from tvm import autotvm
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if name == 'resnet-18':
net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
elif name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'squeezenet v1.1':
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
elif name == 'vgg-16':
net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
elif name == 'dcgan':
net, params = nnvm.testing.dcgan.get_workload(batch_size=batch_size)
input_shape = (batch_size, 100)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape
def test_task_extraction():
target = 'llvm'
dtype = 'float32'
net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_graph(net, target=target,
shape={'data': input_shape}, dtype=dtype,
symbols=(nnvm.sym.conv2d,))
assert len(tasks) == 12
net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_graph(net, target=target,
shape={'data': input_shape}, dtype=dtype,
symbols=(nnvm.sym.dense,))
assert len(tasks) == 1
net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_graph(net, target=target,
shape={'data': input_shape}, dtype=dtype,
symbols=(nnvm.sym.conv2d, nnvm.sym.dense))
assert len(tasks) == 13
net, params, input_shape, out_shape = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_graph(net, target=target,
shape={'data': input_shape}, dtype=dtype,
symbols=(nnvm.sym.conv2d, nnvm.sym.dense))
assert len(tasks) == 20
net, params, input_shape, out_shape = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_graph(net, target=target,
shape={'data': input_shape}, dtype=dtype,
symbols=(nnvm.sym.conv2d_transpose,))
assert len(tasks) == 4
if __name__ == '__main__':
test_task_extraction()
...@@ -25,5 +25,6 @@ from . import tophub ...@@ -25,5 +25,6 @@ from . import tophub
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \ from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
ApplyHistoryBest as apply_history_best register_topi_compute, register_topi_schedule, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best
from .env import GLOBAL_SCOPE from .env import GLOBAL_SCOPE
...@@ -89,8 +89,9 @@ def measure_option(measure_func, ...@@ -89,8 +89,9 @@ def measure_option(measure_func,
callable: customized build function for other backends (e.g. VTA). callable: customized build function for other backends (e.g. VTA).
See measure/measure_methods.py::default_build_func for example. See measure/measure_methods.py::default_build_func for example.
check_correctness: bool, optional
Whether to check correctness after measurement. This will use the llvm cpu target to
generate the reference output.
replay_db : Database, optional replay_db : Database, optional
The database that we retrieve saved MeasureResult from. The database that we retrieve saved MeasureResult from.
......
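A hedged usage sketch for the option documented above; local measurement and the exact keyword names are assumptions that may differ slightly in this version:

from tvm import autotvm

measure_option = autotvm.measure_option(
    measure_func='local',      # measure on the local machine
    number=5,                  # number of runs per measurement
    check_correctness=True)    # compare outputs against an llvm cpu reference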
...@@ -83,7 +83,7 @@ def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10): ...@@ -83,7 +83,7 @@ def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
The priority of this request; a larger value means higher priority.
timeout: float, optional
The timeout of this check (in seconds).
If it times out, a RuntimeError will be raised.
""" """
def _check(): def _check():
remote = request_remote(device_key, tracker_addr, priority) remote = request_remote(device_key, tracker_addr, priority)
...@@ -281,11 +281,11 @@ def rpc(key, ...@@ -281,11 +281,11 @@ def rpc(key,
results: List of MeasureResult results: List of MeasureResult
The results for input_pack The results for input_pack
""" """
remote = request_remote(key, (host, port), priority, session_timeout) remote_args = (key, (host, port), priority, session_timeout)
res = _measure_common(input_pack, build_func, build_kwargs, number, repeat, res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input, ref_output, ref_input, ref_output,
remote) remote_args)
return res return res
fmeasure.pack_size = pack_size fmeasure.pack_size = pack_size
...@@ -294,7 +294,7 @@ def rpc(key, ...@@ -294,7 +294,7 @@ def rpc(key,
def _measure_common(input_pack, build_func, build_kwargs, number, repeat, def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input=None, ref_output=None, remote=None): ref_input=None, ref_output=None, remote_args=None):
"""Measure the time cost for a pack of inputs. """Measure the time cost for a pack of inputs.
(Note: A pack is a list of inputs which will be measured inside a same RPC session) (Note: A pack is a list of inputs which will be measured inside a same RPC session)
...@@ -318,8 +318,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, ...@@ -318,8 +318,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
Reference input for checking correctness
ref_output: Array of np.ndarray, optional
Reference output for checking correctness
remote_args: Tuple, optional
The arguments passed to request_remote. If not None, the measurement will run on remote RPC devices.
Returns Returns
------- -------
...@@ -327,7 +327,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, ...@@ -327,7 +327,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
The list of results of measurement. The list of results of measurement.
""" """
res_pack = [] res_pack = []
tmp_dir = util.tempdir() if remote else None tmp_dir = util.tempdir() if remote_args else None
assert len(input_pack) == 1, "Only supports a single input in a pack for now"
for inp in input_pack: for inp in input_pack:
tic = time.time() tic = time.time()
...@@ -360,31 +361,36 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, ...@@ -360,31 +361,36 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
tstamp - tic, tstamp)) tstamp - tic, tstamp))
continue continue
# upload built module
if remote:
remote.upload(tmp_dir.relpath(filename))
func = remote.load_module(filename)
ctx = remote.context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
else:
ctx = context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
# measure time # measure time
errno = MeasureErrorNo.NO_ERROR errno = MeasureErrorNo.NO_ERROR
try: try:
# upload built module
if remote_args:
remote = request_remote(*remote_args)
remote.upload(tmp_dir.relpath(filename))
func = remote.load_module(filename)
ctx = remote.context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
else:
ctx = context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
# set input
if ref_input: if ref_input:
args = [nd.array(x, ctx=ctx) for x in ref_input] args = [nd.array(x, ctx=ctx) for x in ref_input]
else: else:
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx) args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
for x in arg_bufs] for x in arg_bufs]
costs = time_f(*args).results costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs) costs = list(costs)
costs.sort() costs.sort()
costs = tuple(costs[1:-1]) costs = tuple(costs[1:-1])
# check correctness of output
if ref_output: if ref_output:
for expected, real in zip(ref_output, args): for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4): if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
......
...@@ -9,7 +9,7 @@ of typical tasks of interest. ...@@ -9,7 +9,7 @@ of typical tasks of interest.
from .task import Task, create, register, template, get_config, args_to_workload from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, dispatcher from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, FallbackContext, dispatcher
from .topi_integration import register_topi_compute, register_topi_schedule from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph from .nnvm_integration import extract_from_graph
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
from tvm import target as _target from tvm import target as _target
from .space import ConfigSpace from .space import FallbackConfigEntity
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
...@@ -34,9 +34,36 @@ class DispatchContext(object): ...@@ -34,9 +34,36 @@ class DispatchContext(object):
""" """
current = None current = None
def __init__(self):
self._old_ctx = DispatchContext.current
def query(self, target, workload): def query(self, target, workload):
""" """
Query the context to get the specific config for a template.
If the result cannot be found inside this context, the query is
forwarded to the upper (parent) contexts.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
ret = self._query_inside(target, workload)
if ret is None:
ret = self._old_ctx.query(target, workload)
return ret
def _query_inside(self, target, workload):
"""
Query the context to get the specific config for a template.
This function only queries the config inside this context.
Parameters Parameters
---------- ----------
...@@ -117,17 +144,17 @@ def dispatcher(fworkload): ...@@ -117,17 +144,17 @@ def dispatcher(fworkload):
def dispatch_func(func, *args, **kwargs):
    """The wrapped dispatch function"""
    tgt = _target.current_target()
    workload = func(*args, **kwargs)
    cfg = DispatchContext.current.query(tgt, workload)
    if cfg.is_fallback and not cfg.template_key:
        # first try 'direct' template
        if 'direct' in dispatch_dict:
            return dispatch_dict['direct'](cfg, *args, **kwargs)
        # otherwise pick a random template
        for v in dispatch_dict.values():
            return v(cfg, *args, **kwargs)
    else:
        return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
fdecorate = decorate(fworkload, dispatch_func) fdecorate = decorate(fworkload, dispatch_func)
fdecorate.register = register fdecorate.register = register
...@@ -135,7 +162,7 @@ def dispatcher(fworkload): ...@@ -135,7 +162,7 @@ def dispatcher(fworkload):
class ApplyConfig(DispatchContext):
    """Apply a deterministic config entity for all queries.
Parameters Parameters
---------- ----------
...@@ -147,7 +174,7 @@ class ApplyConfig(DispatchContext): ...@@ -147,7 +174,7 @@ class ApplyConfig(DispatchContext):
self._config = config self._config = config
self.workload = None self.workload = None
def _query_inside(self, target, workload):
    """Override query"""
    self.workload = workload
    return self._config
...@@ -164,20 +191,12 @@ class ApplyHistoryBest(DispatchContext): ...@@ -164,20 +191,12 @@ class ApplyHistoryBest(DispatchContext):
If it is a str, it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator over records.
default: ConfigEntity, optional
The default config to return when no history records
allow_fallback: bool
Whether allow to use a fallback configuration if cannot find
tuned result.
""" """
def __init__(self, records, default=None, allow_fallback=False): def __init__(self, records):
super(ApplyHistoryBest, self).__init__() super(ApplyHistoryBest, self).__init__()
self.best_by_targetkey = {} self.best_by_targetkey = {}
self.best_by_model = {} self.best_by_model = {}
self._default = default
self._allow_fallback = allow_fallback
self.fallback = {}
if records: if records:
self.load(records) self.load(records)
...@@ -234,7 +253,7 @@ class ApplyHistoryBest(DispatchContext): ...@@ -234,7 +253,7 @@ class ApplyHistoryBest(DispatchContext):
logger.debug("Finish loading %d records", counter) logger.debug("Finish loading %d records", counter)
def _query_inside(self, target, workload):
if target is None: if target is None:
raise RuntimeError("Need a target context to find the history best. " raise RuntimeError("Need a target context to find the history best. "
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`" "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
...@@ -254,20 +273,50 @@ class ApplyHistoryBest(DispatchContext): ...@@ -254,20 +273,50 @@ class ApplyHistoryBest(DispatchContext):
if key in self.best_by_targetkey: if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config return self.best_by_targetkey[key][0].config
return None
class FallbackContext(DispatchContext):
"""
A fallback dispatch context.
Any tunable template can be called under this context.
This is the root context.
"""
def __init__(self):
super(FallbackContext, self).__init__()
self.memory = {}
self.silent = False
def _query_inside(self, target, workload):
key = (str(target), workload)
if key in self.memory:
return self.memory[key]
if not self.silent:
    logger.warning(
        "Cannot find config for target=%s, workload=%s. A fallback configuration "
        "is used, which may cause a significant performance regression.", target, workload)
cfg = FallbackConfigEntity()

# cache this config
self.memory[key] = cfg
return cfg
def clear_cache(self, target, workload):
"""Clear fallback cache. Pass the same argument as _query_inside to this function
to clean the cache.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
"""
key = (str(target), workload)
if key in self.memory:
del self.memory[key]
DispatchContext.current = FallbackContext()
...@@ -7,11 +7,10 @@ import warnings ...@@ -7,11 +7,10 @@ import warnings
import logging import logging
from ... import tensor, placeholder, target as _target from ... import tensor, placeholder, create_schedule, target as _target
from ..util import get_const_tuple from ..util import get_const_tuple
from .task import create, register from .task import create, register
from .dispatcher import ApplyHistoryBest
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
...@@ -56,40 +55,68 @@ class TaskExtractEnv: ...@@ -56,40 +55,68 @@ class TaskExtractEnv:
import topi import topi
import nnvm import nnvm
# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
self.symbol2topi = {
    nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
    nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
    nnvm.sym.dense: [topi.nn.dense],
}

# topi compute -> autotvm task name
self.topi_to_task = {
    topi.nn.conv2d: "topi_nn_conv2d",
    topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
    topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
    topi.nn.dense: "topi_nn_dense",
}

self.topi_to_schedule = {
    topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
                     topi.generic.schedule_conv2d_nhwc],
    topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
                                    topi.generic.schedule_depthwise_conv2d_nhwc],
    topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
    topi.nn.dense: [topi.generic.schedule_dense],
}

self._register_tracing()
self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())

def _register_tracing(self):
    """Register tracing functions to track the topi function calls"""
    # register topi compute for "tracing" target
    for topi_compute in self.topi_to_task:
        def _local_scope(compute_func):
            """start a scope to hold the local function in the for loop"""
            @compute_func.register("tracing", )
            def _tracing_topi_compute(*args, **kwargs):
                assert not kwargs, "Do not support extracting tuning tasks when " \
                                   "kwargs is used in a TOPI function call. " \
                                   "Please modify it to use only positional args."

                if compute_func in self.wanted_topi_funcs:  # record this call
                    key = (self.topi_to_task[compute_func], serialize_args(args))
                    if key not in self.task_collection:
                        self.task_collection.append(key)

                return compute_func.fdefault(*args)
        _local_scope(topi_compute)

    # register topi schedule for "tracing" target
    for topi_compute in self.topi_to_task:
        for topi_schedule in self.topi_to_schedule[topi_compute]:
            def _local_scope_(schedule_func):
                """start a scope to hold the local function in the for loop"""
                @schedule_func.register("tracing", )
                def _tracing_topi_compute(outs):
                    outs = [outs] if isinstance(outs, tensor.Tensor) else outs
                    return create_schedule([x.op for x in outs])
            _local_scope_(topi_schedule)
def _register_topi_task(self): def _register_topi_task(self):
"""register tuning wrapper for topi function""" """register tuning wrapper for topi function"""
...@@ -125,17 +152,47 @@ class TaskExtractEnv: ...@@ -125,17 +152,47 @@ class TaskExtractEnv:
s = topi.generic.schedule_conv2d_transpose_nchw([C]) s = topi.generic.schedule_conv2d_transpose_nchw([C])
return s, [A, W, C] return s, [A, W, C]
def reset(self): @register("topi_nn_dense")
"""Reset task collections""" def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias = args
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
return s, [data, weight, bias, C]
return s, [data, weight, C]
def reset(self, wanted_topi_funcs):
"""Reset task collections
Parameters
----------
wanted_topi_funcs: List of function
The topi functions whose tasks will be extracted
"""
self.task_collection = [] self.task_collection = []
self.wanted_topi_funcs = wanted_topi_funcs
def get_tasks(self): def get_tasks(self):
"""Get collected tasks""" """Get collected tasks
Returns
-------
tasks: List of tuple(name, args)
A list of tasks extracted from the nnvm graph
"""
return self.task_collection return self.task_collection
@staticmethod @staticmethod
def get(): def get():
"""Get the single instance of TaskExtractEnv""" """Get the single instance of TaskExtractEnv
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current: if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv() TaskExtractEnv.current = TaskExtractEnv()
return TaskExtractEnv.current return TaskExtractEnv.current
...@@ -144,8 +201,8 @@ class TaskExtractEnv: ...@@ -144,8 +201,8 @@ class TaskExtractEnv:
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
""" Extract tuning tasks from a nnvm graph. """ Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph
with a "tracing" target and tracing all the calls to topi.
Parameters Parameters
---------- ----------
...@@ -158,7 +215,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): ...@@ -158,7 +215,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
target: tvm.target.Target target: tvm.target.Target
The compilation target The compilation target
symbols : Array of nnvm.symbol symbols : Array of nnvm.symbol
Array of nnvm symbols to be tuned
target_host: tvm.target.Target target_host: tvm.target.Target
The host compilation target The host compilation target
...@@ -179,16 +236,16 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): ...@@ -179,16 +236,16 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
warnings.warn("Symbol %s is not tunable, ignored" % sym_name) warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
# run compiler to collect all TOPI calls during compilation # run compiler to collect all TOPI calls during compilation
env.reset() env.reset(topi_funcs)
# disable logger temporarily # disable logger temporarily
old_state = logger.disabled old_state = logger.disabled
logger.disabled = True logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
logger.disabled = old_state logger.disabled = old_state
......
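For reference, the extracted tasks are then tuned one by one, roughly as sketched below; the trial count, measurement settings, and log file name are illustrative, and some keyword names may differ in this version:

import nnvm
import nnvm.testing
from tvm import autotvm

net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)
tasks = autotvm.task.extract_from_graph(
    net, target='llvm', shape={'data': (1, 3, 224, 224)}, dtype='float32',
    symbols=(nnvm.sym.conv2d,))

measure_option = autotvm.measure_option(measure_func='local', number=5)
for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=20,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('tuning.log')])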
...@@ -567,15 +567,16 @@ class ConfigSpace(object): ...@@ -567,15 +567,16 @@ class ConfigSpace(object):
""" """
def __init__(self): def __init__(self):
# private dict to provide sugar # private dict to provide sugar
self.space_map = OrderedDict() # name -> space self.space_map = OrderedDict() # name -> space
self._collect = True self._collect = True
self._length = None self._length = None
self._entity_map = OrderedDict() self._entity_map = OrderedDict() # name -> entity
self._constraints = [] self._constraints = []
self.errors = [] self.errors = []
self.template_key = None self.template_key = None
self.code_hash = None self.code_hash = None
self.flop = 0 self.flop = 0
self.is_fallback = False
@staticmethod @staticmethod
def axis(var): def axis(var):
...@@ -607,6 +608,15 @@ class ConfigSpace(object): ...@@ -607,6 +608,15 @@ class ConfigSpace(object):
If is 'candidate', try the listed candidates.
kwargs: dict
Extra arguments for the policy.
See the examples below for how to use filter.
Examples
--------
>>> # use custom candidates
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
>>> # use a filter that only accepts the split scheme whose inner most tile is less then 4
>>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
""" """
axes = [axis] axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs) return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
...@@ -889,3 +899,45 @@ class ConfigEntity(ConfigSpace): ...@@ -889,3 +899,45 @@ class ConfigEntity(ConfigSpace):
def __repr__(self): def __repr__(self):
return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key, return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
self.code_hash, self.index) self.code_hash, self.index)
class FallbackConfigEntity(ConfigSpace):
"""The config entity created to support fallback"""
def __init__(self):
super(FallbackConfigEntity, self).__init__()
self.is_fallback = True
def fallback_split(self, name, constraints):
"""Fallback a split knob
Parameters
----------
name: str
name of the knob
constraints: List of int
The maximum tile size for every dimension. Value `-1` means no constraint.
Examples
--------
If you use cfg.define_split('tile_0', 128, num_outputs=3),
Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4]
If you use cfg.define_split('tile_0', 49, num_outputs=3),
Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1]
"""
space = self.space_map[name]
assert len(constraints) == space.num_outputs
indices = np.arange(space.num_outputs)
# '-1' means no constraint
constraints = [x if x != -1 else 1e10 for x in constraints]
for entity in reversed(space.entities):
if all([entity.size[i] <= constraints[i] for i in indices]):
self._entity_map[name] = entity
return
raise RuntimeError("Cannot find feasible fallback split entity for node: " + name)
def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
...@@ -206,7 +206,7 @@ def args_to_workload(x): ...@@ -206,7 +206,7 @@ def args_to_workload(x):
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value return x.value
elif x is None: elif x is None:
return None return 0
else: else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use' raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types only' % type(x)) 'primitive types only' % type(x))
......
...@@ -28,7 +28,7 @@ def _alias(name): ...@@ -28,7 +28,7 @@ def _alias(name):
return table.get(name, name) return table.get(name, name)
def context(target, extra_files=None, allow_fallback=False): def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters. """Return the dispatch context with pre-tuned parameters.
The corresponding downloaded *.log files under tophub root path will be loaded. The corresponding downloaded *.log files under tophub root path will be loaded.
Users can also add their own files in argument `extra_files`. Users can also add their own files in argument `extra_files`.
...@@ -39,12 +39,9 @@ def context(target, extra_files=None, allow_fallback=False): ...@@ -39,12 +39,9 @@ def context(target, extra_files=None, allow_fallback=False):
The compilation target The compilation target
extra_files: list of str, optional extra_files: list of str, optional
Extra log files to load Extra log files to load
allow_fallback: bool
Whether allow to use a fallback configuration if cannot find
tuned result.
""" """
rootpath = AUTOTVM_TOPHUB_ROOT_PATH rootpath = AUTOTVM_TOPHUB_ROOT_PATH
best_context = ApplyHistoryBest([], allow_fallback=allow_fallback) best_context = ApplyHistoryBest([])
if isinstance(target, str): if isinstance(target, str):
target = _target.create(target) target = _target.create(target)
......
...@@ -86,13 +86,9 @@ class GATuner(Tuner): ...@@ -86,13 +86,9 @@ class GATuner(Tuner):
# cross over # cross over
indices = np.arange(len(genes)) indices = np.arange(len(genes))
scores += 1e-8
scores /= np.max(scores)
probs = scores / np.sum(scores)
tmp_genes = [] tmp_genes = []
for _ in range(self.pop_size): for _ in range(self.pop_size):
p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs) p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
......
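A quick numeric check of the normalization above (the scores array is made up); the added epsilon keeps the selection probabilities well defined even when every measurement failed:

import numpy as np

scores = np.array([0.0, 0.0, 0.0])   # e.g. all candidates failed to compile
scores += 1e-8                        # avoid dividing by zero
scores /= np.max(scores)
probs = scores / np.sum(scores)       # -> uniform [1/3, 1/3, 1/3]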
...@@ -8,7 +8,7 @@ import gc ...@@ -8,7 +8,7 @@ import gc
import numpy as np import numpy as np
from .tuner import Tuner from .tuner import Tuner
from ..env import GLOBAL_SCOPE
class FeatureCache(object): class FeatureCache(object):
"""Feature cache manager for cache sharing between different cost models""" """Feature cache manager for cache sharing between different cost models"""
...@@ -119,11 +119,9 @@ class CostModel(object): ...@@ -119,11 +119,9 @@ class CostModel(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def spawn_base_model(self):
    """Clone a base model with the same parameters.
    The base model is used to fit history data in transfer learning.

    Returns
    -------
...@@ -221,7 +219,9 @@ class ModelBasedTuner(Tuner): ...@@ -221,7 +219,9 @@ class ModelBasedTuner(Tuner):
break break
self.trial_pt += 1 self.trial_pt += 1
if self.trial_pt >= len(self.trials): # trial list is empty, choose randomly if self.trial_pt >= len(self.trials) - int(0.05 * self.plan_size):
# if the trial list is empty or
# the tuner is doing the last 5% trials (e-greedy), choose randomly
index = np.random.randint(len(self.space)) index = np.random.randint(len(self.space))
while index in self.visited: while index in self.visited:
index = np.random.randint(len(self.space)) index = np.random.randint(len(self.space))
...@@ -264,18 +264,16 @@ class ModelBasedTuner(Tuner): ...@@ -264,18 +264,16 @@ class ModelBasedTuner(Tuner):
self.train_ct += 1 self.train_ct += 1
def load_history(self, data_set):
    # set in_tuning as True to make the feature extraction consistent
    GLOBAL_SCOPE.in_tuning = True

    # fit base model
    base_model = self.cost_model.spawn_base_model()
    success = base_model.fit_log(data_set, self.plan_size)
    if not success:
        GLOBAL_SCOPE.in_tuning = False
        return
# use base model to select initial points # use base model to select initial points
if not self.trials: if not self.trials:
...@@ -285,6 +283,7 @@ class ModelBasedTuner(Tuner): ...@@ -285,6 +283,7 @@ class ModelBasedTuner(Tuner):
self.trial_pt = 0 self.trial_pt = 0
self.cost_model.load_basemodel(base_model) self.cost_model.load_basemodel(base_model)
GLOBAL_SCOPE.in_tuning = False
def has_next(self): def has_next(self):
return len(self.visited) < len(self.space) return len(self.visited) < len(self.space)
......
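From user code, the transfer-learning path exercised here is typically driven as sketched below; the log file name is hypothetical, `task` and `measure_option` are assumed to be defined as in the extraction example above, and `load_from_file` is assumed to be the record loader in this version:

from tvm import autotvm
from tvm.autotvm.record import load_from_file   # assumed record loader

tuner = autotvm.tuner.XGBTuner(task)
# spawn a base cost model, fit it on previously measured records, and use it
# to pick the initial batch of trial points for the new task
tuner.load_history(load_from_file('previous_tuning.log'))
tuner.tune(n_trial=100, measure_option=measure_option)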
...@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): ...@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
new_scores = model.predict(new_points) new_scores = model.predict(new_points)
ac_prob = np.exp((new_scores - scores) / t) ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
ac_index = np.random.random(len(ac_prob)) < ac_prob ac_index = np.random.random(len(ac_prob)) < ac_prob
points[ac_index] = new_points[ac_index] points[ac_index] = new_points[ac_index]
......
...@@ -31,6 +31,10 @@ class Tuner(object): ...@@ -31,6 +31,10 @@ class Tuner(object):
self.best_measure_pair = None self.best_measure_pair = None
self.best_iter = 0 self.best_iter = 0
# time to leave
self.ttl = None
self.n_trial = None
def has_next(self): def has_next(self):
"""Whether has next untried config in the space """Whether has next untried config in the space
...@@ -76,7 +80,7 @@ class Tuner(object): ...@@ -76,7 +80,7 @@ class Tuner(object):
measure_option: dict
The options for how to measure generated code.
You should use the return value of autotvm.measure_option for this argument.
early_stopping: int, optional
Early stop the tuning when no better config is found in this number of trials.
callbacks: List of callable callbacks: List of callable
A list of callback functions. The signature of callback function is A list of callback functions. The signature of callback function is
...@@ -87,6 +91,8 @@ class Tuner(object): ...@@ -87,6 +91,8 @@ class Tuner(object):
measure_batch = create_measure_batch(self.task, measure_option) measure_batch = create_measure_batch(self.task, measure_option)
n_parallel = getattr(measure_batch, 'n_parallel', 1) n_parallel = getattr(measure_batch, 'n_parallel', 1)
early_stopping = early_stopping or 1e9 early_stopping = early_stopping or 1e9
self.n_trial = n_trial
old_level = logger.level old_level = logger.level
GLOBAL_SCOPE.in_tuning = True GLOBAL_SCOPE.in_tuning = True
...@@ -127,11 +133,12 @@ class Tuner(object): ...@@ -127,11 +133,12 @@ class Tuner(object):
for callback in callbacks: for callback in callbacks:
callback(self, inputs, results) callback(self, inputs, results)
if i > self.best_iter + early_stopping: self.ttl = min(early_stopping + self.best_iter, n_trial) - i
if i >= self.best_iter + early_stopping:
logger.debug("Early stopped. Best iter: %d.", self.best_iter) logger.debug("Early stopped. Best iter: %d.", self.best_iter)
break break
if error_ct > 50: if error_ct > 150:
logger.warning("Too many errors happen in the tuning. Now is in debug mode") logger.warning("Too many errors happen in the tuning. Now is in debug mode")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
else: else:
......
...@@ -31,8 +31,12 @@ class XGBoostCostModel(CostModel): ...@@ -31,8 +31,12 @@ class XGBoostCostModel(CostModel):
If is 'curve', use sampled curve feature (relation feature). If is 'curve', use sampled curve feature (relation feature).
Note on choosing feature type:
For single task tuning, 'itervar' and 'knob' are good.
'itervar' is more accurate but 'knob' is much faster.
There are some constraints on 'itervar'; if you meet
problems with feature extraction when using 'itervar',
you can switch to 'knob'.
For cross-shape tuning (e.g. many convolutions with different shapes),
'itervar' and 'curve' have better transferability,
while 'knob' is faster.
...@@ -46,8 +50,11 @@ class XGBoostCostModel(CostModel): ...@@ -46,8 +50,11 @@ class XGBoostCostModel(CostModel):
The number of threads. The number of threads.
log_interval: int, optional
If not None, the cost model will print a training log every `log_interval` iterations.
upper_model: XGBoostCostModel, optional
The upper model used in transfer learning
""" """
def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25): def __init__(self, task, feature_type, loss_type, num_threads=4, log_interval=25,
upper_model=None):
super(XGBoostCostModel, self).__init__() super(XGBoostCostModel, self).__init__()
if xgb is None: if xgb is None:
...@@ -109,35 +116,51 @@ class XGBoostCostModel(CostModel): ...@@ -109,35 +116,51 @@ class XGBoostCostModel(CostModel):
else: else:
raise RuntimeError("Invalid feature type " + feature_type) raise RuntimeError("Invalid feature type " + feature_type)
if upper_model:  # share the same feature cache with the upper model
    self.feature_cache = upper_model.feature_cache
else:
    self.feature_cache = FeatureCache()
self.upper_model = upper_model
self.feature_extra_ct = 0
self.pool = None
self.base_model = None

self._sample_size = 0
self._reset_pool(self.space, self.target, self.task)

def _reset_pool(self, space, target, task):
    """reset the processing pool for feature extraction"""
    if self.upper_model:  # the base model reuses the upper model's pool
        self.upper_model._reset_pool(space, target, task)
        return

    self._close_pool()

    # use global variables to pass common arguments
    global _extract_space, _extract_target, _extract_task
    _extract_space = space
    _extract_target = target
    _extract_task = task
    self.pool = multiprocessing.Pool(self.num_threads)
def _close_pool(self):
if self.pool:
self.pool.terminate()
self.pool.join()
self.pool = None
def _get_pool(self):
if self.upper_model:
return self.upper_model._get_pool()
return self.pool
def _base_model_discount(self): def _base_model_discount(self):
return 1.0 / (2 ** (self._sample_size / 50.0)) return 1.0 / (2 ** (self._sample_size / 64.0))
def fit(self, xs, ys, plan_size): def fit(self, xs, ys, plan_size):
tic = time.time() tic = time.time()
self._reset_pool() self._reset_pool(self.space, self.target, self.task)
x_train = self._get_feature(xs) x_train = self._get_feature(xs)
y_train = np.array(ys) y_train = np.array(ys)
...@@ -150,8 +173,12 @@ class XGBoostCostModel(CostModel): ...@@ -150,8 +173,12 @@ class XGBoostCostModel(CostModel):
self._sample_size = len(x_train) self._sample_size = len(x_train)
if self.base_model:
    discount = self._base_model_discount()
    if discount < 0.05:  # discard base model
        self.base_model.upper_model = None
        self.base_model = None
    else:
        dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True))
self.bst = xgb.train(self.xgb_params, dtrain, self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=8000, num_boost_round=8000,
...@@ -172,11 +199,19 @@ class XGBoostCostModel(CostModel): ...@@ -172,11 +199,19 @@ class XGBoostCostModel(CostModel):
def fit_log(self, records, plan_size):
    tic = time.time()

    # filter data, only pick the data with the same task
    data = []
    for inp, res in records:
        if inp.task.name == self.task.name and \
                inp.config.template_key == self.task.config_space.template_key:
            data.append((inp, res))
    logger.debug("XGB load %d entries from history log file", len(data))
# extract feature
self._reset_pool(self.space, self.target, self.task)
pool = self._get_pool()
if self.fea_type == 'itervar': if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob': elif self.fea_type == 'knob':
...@@ -185,10 +220,21 @@ class XGBoostCostModel(CostModel): ...@@ -185,10 +220,21 @@ class XGBoostCostModel(CostModel):
feature_extract_func = _extract_curve_feature_log feature_extract_func = _extract_curve_feature_log
else: else:
raise RuntimeError("Invalid feature type: " + self.fea_type) raise RuntimeError("Invalid feature type: " + self.fea_type)
res = pool.map(feature_extract_func, data)

# filter out features with different shapes
fea_len = len(self._get_feature([0])[0])

xs, ys = [], []
for x, y in res:
    if len(x) == fea_len:
        xs.append(x)
        ys.append(y)

if len(xs) < 500:  # not enough samples
    return False

xs, ys = np.array(xs), np.array(ys)
x_train = xs x_train = xs
y_train = ys y_train = ys
y_max = np.max(y_train) y_max = np.max(y_train)
...@@ -212,6 +258,8 @@ class XGBoostCostModel(CostModel): ...@@ -212,6 +258,8 @@ class XGBoostCostModel(CostModel):
logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs)) logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
return True
def predict(self, xs, output_margin=False): def predict(self, xs, output_margin=False):
feas = self._get_feature(xs) feas = self._get_feature(xs)
dtest = xgb.DMatrix(feas) dtest = xgb.DMatrix(feas)
...@@ -224,20 +272,12 @@ class XGBoostCostModel(CostModel): ...@@ -224,20 +272,12 @@ class XGBoostCostModel(CostModel):
def load_basemodel(self, base_model):
    self.base_model = base_model
    self.base_model._close_pool()
    self.base_model.upper_model = self

def spawn_base_model(self):
    return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
                            self.num_threads, self.log_interval, self)
def _get_feature(self, indexes): def _get_feature(self, indexes):
"""get features for indexes, run extraction if we do not have cache for them""" """get features for indexes, run extraction if we do not have cache for them"""
...@@ -251,7 +291,7 @@ class XGBoostCostModel(CostModel): ...@@ -251,7 +291,7 @@ class XGBoostCostModel(CostModel):
need_extract = [x for x in indexes if x not in fea_cache] need_extract = [x for x in indexes if x not in fea_cache]
if need_extract: if need_extract:
pool = self.pool if self.upper_model is None else self.upper_model.pool pool = self._get_pool()
feas = pool.map(self.feature_extract_func, need_extract) feas = pool.map(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas): for i, fea in zip(need_extract, feas):
fea_cache[i] = fea fea_cache[i] = fea
...@@ -261,6 +301,9 @@ class XGBoostCostModel(CostModel): ...@@ -261,6 +301,9 @@ class XGBoostCostModel(CostModel):
ret[i, :] = fea_cache[ii] ret[i, :] = fea_cache[ii]
return ret return ret
def __del__(self):
self._close_pool()
_extract_space = None _extract_space = None
_extract_target = None _extract_target = None
......
...@@ -20,8 +20,12 @@ class XGBTuner(ModelBasedTuner): ...@@ -20,8 +20,12 @@ class XGBTuner(ModelBasedTuner):
If is 'curve', use sampled curve feature (relation feature). If is 'curve', use sampled curve feature (relation feature).
Note on choosing feature type:
For single task tuning, 'itervar' and 'knob' are good.
'itervar' is more accurate but 'knob' is much faster.
There are some constraints on 'itervar'; if you meet
problems with feature extraction when using 'itervar',
you can switch to 'knob'.
For cross-shape tuning (e.g. many convolutions with different shapes),
'itervar' and 'curve' have better transferability,
while 'knob' is faster.
...@@ -32,8 +36,7 @@ class XGBTuner(ModelBasedTuner): ...@@ -32,8 +36,7 @@ class XGBTuner(ModelBasedTuner):
If is 'rank', use pairwise rank loss to train cost model. If is 'rank', use pairwise rank loss to train cost model.
The cost model predicts relative rank score. The cost model predicts relative rank score.
num_threads: int, optional
The number of threads.
optimizer: str or ModelOptimizer, optional
If is 'sa', use a default simulated annealing optimizer. If is 'sa', use a default simulated annealing optimizer.
Otherwise it should be a ModelOptimizer object. Otherwise it should be a ModelOptimizer object.
diversity_filter_ratio: int or float, optional diversity_filter_ratio: int or float, optional
...@@ -45,7 +48,7 @@ class XGBTuner(ModelBasedTuner): ...@@ -45,7 +48,7 @@ class XGBTuner(ModelBasedTuner):
If is 0, output nothing. If is 0, output nothing.
Otherwise, output debug information every `verbose` iterations. Otherwise, output debug information every `verbose` iterations.
""" """
def __init__(self, task, plan_size=32, def __init__(self, task, plan_size=64,
feature_type='itervar', loss_type='rank', num_threads=None, feature_type='itervar', loss_type='rank', num_threads=None,
optimizer='sa', diversity_filter_ratio=None, log_interval=50): optimizer='sa', diversity_filter_ratio=None, log_interval=50):
cost_model = XGBoostCostModel(task, cost_model = XGBoostCostModel(task,
...@@ -62,3 +65,9 @@ class XGBTuner(ModelBasedTuner): ...@@ -62,3 +65,9 @@ class XGBTuner(ModelBasedTuner):
super(XGBTuner, self).__init__(task, cost_model, optimizer, super(XGBTuner, self).__init__(task, cost_model, optimizer,
plan_size, diversity_filter_ratio) plan_size, diversity_filter_ratio)
def tune(self, *args, **kwargs): # pylint: disable=arguments-differ
super(XGBTuner, self).tune(*args, **kwargs)
# manually close pool to avoid multiprocessing issues
self.cost_model._close_pool()
...@@ -8,8 +8,8 @@ from ..autotvm.tophub import list_packages, download_package ...@@ -8,8 +8,8 @@ from ..autotvm.tophub import list_packages, download_package
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--download", type=str, nargs='+', parser.add_argument("-d", "--download", type=str, nargs='+',
help="Target to download. Use 'all' to download for all targets") help="The targets to download. Use 'all' to download for all targets")
parser.add_argument("-l", "--list", action='store_true', help="List available packages") parser.add_argument("-l", "--list", action='store_true', help="List available packages")
args = parser.parse_args() args = parser.parse_args()
...@@ -21,8 +21,7 @@ if __name__ == '__main__': ...@@ -21,8 +21,7 @@ if __name__ == '__main__':
print("-" * 41) print("-" * 41)
for target, info in info: for target, info in info:
print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000))) print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000)))
elif args.download:
if args.download:
info = list_packages() info = list_packages()
all_targets = [x[0] for x in info] all_targets = [x[0] for x in info]
if 'all' in args.download: if 'all' in args.download:
...@@ -34,3 +33,5 @@ if __name__ == '__main__': ...@@ -34,3 +33,5 @@ if __name__ == '__main__':
if t not in all_targets: if t not in all_targets:
print("Warning : cannot find tuned parameters of " + t + ". (ignored)") print("Warning : cannot find tuned parameters of " + t + ". (ignored)")
download_package(t) download_package(t)
else:
parser.print_help()
...@@ -263,6 +263,7 @@ def override_native_generic_func(func_name): ...@@ -263,6 +263,7 @@ def override_native_generic_func(func_name):
"Keyword arguments cannot be used when invoking generic_func %s" % func_name) "Keyword arguments cannot be used when invoking generic_func %s" % func_name)
return generic_func_node(*args) return generic_func_node(*args)
fresult = decorate(fdefault, dispatch_func) fresult = decorate(fdefault, dispatch_func)
fresult.fdefault = fdefault
fresult.register = register fresult.register = register
return fresult return fresult
return fdecorate return fdecorate
......
...@@ -3,34 +3,48 @@ The dispatcher can choose which template to use according ...@@ -3,34 +3,48 @@ The dispatcher can choose which template to use according
to the parameters of workload""" to the parameters of workload"""
from collections import namedtuple
from tvm import autotvm
from tvm.autotvm.task import dispatcher, DispatchContext

SimpleConfig = namedtuple('SimpleConfig', ('template_key', 'is_fallback'))

def test_dispatch():
    @dispatcher
    def my_dispatcher(a, b):
        return (a, b)

    @my_dispatcher.register("im2col")
    def _im2col(cfg, a, b):
        return a

    @my_dispatcher.register("spatial_pack")
    def _spatial_pack(cfg, a, b):
        return b

    class SimpleDispatcher(DispatchContext):
        def query(self, target, workload):
            a, b = workload
            tkey = "spatial_pack" if a + b > 2 else "im2col"
            cfg = SimpleConfig(tkey, False)
            return cfg

    with SimpleDispatcher():
        # this will call im2col
        assert my_dispatcher(1, 0) == 1

        # this will call spatial pack
        assert my_dispatcher(1, 100) == 100
def test_fallback():
@autotvm.template
def simple_template(a, b):
cfg = autotvm.get_config()
assert cfg.is_fallback
simple_template(2, 3)
if __name__ == "__main__": if __name__ == "__main__":
test_dispatch() test_dispatch()
test_fallback()
"""Test space definition primitives""" """Test space definition primitives"""
import tvm import tvm
from tvm.autotvm.task.space import ConfigSpace from tvm.autotvm.task.space import ConfigSpace, FallbackConfigEntity
def gemm_func(cfg, N): def gemm_func(cfg, N):
A = tvm.placeholder((N, N), name='A') A = tvm.placeholder((N, N), name='A')
...@@ -26,5 +26,18 @@ def test_split(): ...@@ -26,5 +26,18 @@ def test_split():
assert len(cfg) == 64 assert len(cfg) == 64
assert len(cfg.space_map['tile_y']) == 8 assert len(cfg.space_map['tile_y']) == 8
# test fallback
cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
cfg.fallback_split('tile_n', [-1, 8, 4])
assert cfg['tile_n'].size == [4, 8, 4]
cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
cfg.fallback_split('tile_n', [-1, 8, 4])
assert cfg['tile_n'].size == [7, 7, 1]
if __name__ == '__main__': if __name__ == '__main__':
test_split() test_split()
...@@ -12,7 +12,7 @@ from test_autotvm_common import get_sample_task, get_sample_records ...@@ -12,7 +12,7 @@ from test_autotvm_common import get_sample_task, get_sample_records
def test_fit(): def test_fit():
task, target = get_sample_task() task, target = get_sample_task()
records = get_sample_records(n=100) records = get_sample_records(n=500)
base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank') base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
base_model.fit_log(records, plan_size=32) base_model.fit_log(records, plan_size=32)
...@@ -20,8 +20,8 @@ def test_fit(): ...@@ -20,8 +20,8 @@ def test_fit():
upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank') upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
upper_model.load_basemodel(base_model) upper_model.load_basemodel(base_model)
xs = np.arange(100) xs = np.arange(10)
ys = np.arange(100) ys = np.arange(10)
upper_model.fit(xs, ys, plan_size=32) upper_model.fit(xs, ys, plan_size=32)
......
...@@ -27,7 +27,14 @@ def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype): ...@@ -27,7 +27,14 @@ def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
@autotvm.task.dispatcher @autotvm.task.dispatcher
def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype): def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype):
"""TOPI compute callback. Mark this function as a dispatcher, so """TOPI compute callback. Mark this function as a dispatcher, so
this template can assign config according to workload""" this template can assign config according to workload
Returns
-------
workload: Tuple
Dispatcher will use this workload to query corresponding config.
Then use cfg.template_key to call a registered template.
"""
return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype) return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
@conv2d_arm_cpu.register(['direct']) @conv2d_arm_cpu.register(['direct'])
...@@ -70,8 +77,10 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs): ...@@ -70,8 +77,10 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile): def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
assert layout == "NCHW", "Only support NCHW" assert layout == "NCHW", "Only support NCHW"
# create workload according to raw arguments
wkl = _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)

out_dtype = out_dtype or data.dtype
N, CI, IH, IW = get_const_tuple(data.shape) N, CI, IH, IW = get_const_tuple(data.shape)
if len(kernel.shape) == 4: if len(kernel.shape) == 4:
pre_packed = False pre_packed = False
...@@ -113,6 +122,18 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n ...@@ -113,6 +122,18 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
# ==================================================================== # ====================================================================
if cfg.is_fallback:
if num_tile == 2:
cfg.fallback_split('tile_co', [-1, 8])
cfg.fallback_split('tile_oh', [-1, 2])
cfg.fallback_split('tile_ow', [-1, 8])
else:
cfg.fallback_split('tile_co', [-1, 16, 4])
cfg.fallback_split('tile_oh', [-1, 1, 1])
cfg.fallback_split('tile_ow', [-1, 1, 4])
cfg['ann_reduce'].anns = ['unroll', 'unroll']
cfg['ann_spatial'].anns = ['none', 'unroll', 'vec']
VC = cfg["tile_co"].size[-1] VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1] VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1] VW = cfg["tile_ow"].size[-1]
...@@ -145,8 +166,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n ...@@ -145,8 +166,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
output = tvm.compute(oshape, lambda n, co, h, w: output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_conv2d_output', name='output_unpack', tag='spatial_conv2d_output',
attrs={'workload': _conv_arg_to_workload(data, kernel, strides, padding, attrs={'workload': wkl})
layout, out_dtype)})
return output return output
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
...@@ -212,6 +232,10 @@ def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype): ...@@ -212,6 +232,10 @@ def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size) return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size): def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
# create workload according to raw arguments
wkl = _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout,
out_dtype, tile_size)
N, CI, IH, IW = get_const_tuple(data.shape) N, CI, IH, IW = get_const_tuple(data.shape)
if len(kernel.shape) == 4: if len(kernel.shape) == 4:
pre_computed = False pre_computed = False
...@@ -333,10 +357,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_ ...@@ -333,10 +357,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
output = tvm.compute((N, K, H, W), lambda n, k, h, w: output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
name='output', tag='winograd_conv2d_output', name='output', tag='winograd_conv2d_output',
attrs={'workload': _winograd_conv_arg_to_workload( attrs={'workload': wkl})
data, kernel, strides, padding, layout, out_dtype, tile_size)})
# we have to manually assign effective GFLOP for winogard # we have to manually assign effective GFLOP for winograd
cfg.add_flop(2 * N * K * H * W * KH * KW * C) cfg.add_flop(2 * N * K * H * W * KH * KW * C)
return output return output
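The formula passed to cfg.add_flop above equals the FLOP count of an equivalent direct convolution (two operations per output element per kernel tap), so the GFLOPS reported while tuning the winograd template stay comparable with the direct one. A quick sanity check of the arithmetic, with layer sizes assumed purely for illustration:

# Assumed sizes (roughly a ResNet-18 3x3, 56x56, 64-channel layer):
N, K, H, W, KH, KW, C = 1, 64, 56, 56, 3, 3, 64
effective_flop = 2 * N * K * H * W * KH * KW * C
print(effective_flop)  # 231211008, i.e. about 0.23 GFLOP for this layer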
...@@ -358,30 +381,29 @@ def _schedule_winograd(cfg, s, output, last): ...@@ -358,30 +381,29 @@ def _schedule_winograd(cfg, s, output, last):
kernel, G = U.op.input_tensors kernel, G = U.op.input_tensors
s[G].compute_inline() s[G].compute_inline()
eps, nu, k, c, kk, = s[U].op.axis eps, nu, k, c, kk, = s[U].op.axis
r_kh, r_kw = s[U].op.reduce_axis
s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
s[U].unroll(eps)
s[U].unroll(nu)
s[U].unroll(r_kh)
s[U].unroll(r_kw)
s[U].vectorize(kk)
if autotvm.GLOBAL_SCOPE.in_tuning: if autotvm.GLOBAL_SCOPE.in_tuning:
# kernel transformation will be pre-computed during compilation, so we skip # kernel transformation will be pre-computed during compilation, so we skip
# this part to make tuning records correct # this part to make tuning records correct
s[U].pragma(k, 'debug_skip_region') s[U].pragma(eps, 'debug_skip_region')
else: else:
r_kh, r_kw = s[U].op.reduce_axis
s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
for axis in [eps, nu, r_kh, r_kw]:
s[U].unroll(axis)
s[U].vectorize(kk)
s[U].parallel(k) s[U].parallel(k)
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# transform image # transform image
DD = s.cache_read(d, 'global', [V]) DD = s.cache_read(d, 'global', [V])
s[B].compute_inline() s[B].compute_inline()
eps, nu, b, c, bb = s[V].op.axis eps, nu, b, c, bb = s[V].op.axis
r_eps, r_nu = s[V].op.reduce_axis r_eps, r_nu = s[V].op.reduce_axis
s[V].reorder(b, c, eps, nu, r_eps, r_nu, bb) s[V].reorder(b, c, eps, nu, r_eps, r_nu, bb)
s[V].unroll(eps) for axis in [eps, nu, r_eps, r_nu]:
s[V].unroll(nu) s[V].unroll(axis)
s[V].unroll(r_eps)
s[V].unroll(r_nu)
s[DD].compute_at(s[V], c) s[DD].compute_at(s[V], c)
s[V].vectorize(bb) s[V].vectorize(bb)
s[V].parallel(b) s[V].parallel(b)
...@@ -405,10 +427,8 @@ def _schedule_winograd(cfg, s, output, last): ...@@ -405,10 +427,8 @@ def _schedule_winograd(cfg, s, output, last):
s[A].compute_inline() s[A].compute_inline()
k, b, vh, vw = s[Y].op.axis k, b, vh, vw = s[Y].op.axis
r_eps, r_nu = s[Y].op.reduce_axis r_eps, r_nu = s[Y].op.reduce_axis
s[Y].unroll(vh) for axis in [vh, vw, r_eps, r_nu]:
s[Y].unroll(vw) s[Y].unroll(axis)
s[Y].unroll(r_eps)
s[Y].unroll(r_nu)
# output # output
n, co, h, w = s[last].op.axis n, co, h, w = s[last].op.axis
...@@ -444,6 +464,7 @@ def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_d ...@@ -444,6 +464,7 @@ def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_d
[data, raw_kernel, strides, padding, layout, out_dtype]) [data, raw_kernel, strides, padding, layout, out_dtype])
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@conv2d_winograd_without_weight_transform.register(['arm_cpu']) @conv2d_winograd_without_weight_transform.register(['arm_cpu'])
@autotvm.task.dispatcher @autotvm.task.dispatcher
def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size): def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
...@@ -472,6 +493,7 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): ...@@ -472,6 +493,7 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
return s return s
##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["arm_cpu", "mali"]) @conv2d_alter_layout.register(["arm_cpu", "mali"])
def _alter_conv2d_layout(attrs, inputs, tinfos): def _alter_conv2d_layout(attrs, inputs, tinfos):
"""Alter op layout for pre-computing kernel transformation""" """Alter op layout for pre-computing kernel transformation"""
...@@ -493,18 +515,30 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -493,18 +515,30 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
# query config of this workload # query config of this workload
workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding, workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding,
layout, out_dtype) layout, out_dtype)
cfg = autotvm.task.DispatchContext.current.query(tvm.target.current_target(), workload) cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
context = autotvm.DispatchContext.current
while not isinstance(context, autotvm.FallbackContext):
context = context._old_ctx
context.clear_cache(tvm.target.current_target(), workload)
return None
if cfg.template_key == 'direct': # packing weight tensor if cfg.template_key == 'direct': # packing weight tensor
new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1]) new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
return sym.conv2d(*copy_inputs, **new_attrs) return sym.conv2d(*copy_inputs, **new_attrs)
else: # pre-compute weight transformation in winograd else: # pre-compute weight transformation in winograd
tile_size = 4 if "-device=arm_cpu" in tvm.target.current_target().options:
tile_size = 4
VC = cfg['tile_k'].size[-1]
else:
from ..mali.conv2d import _pick_tile_size
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val
weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size) tile_size=tile_size)
CO, CI, KH, KW = get_const_tuple(tinfos[1].shape) CO, CI, KH, KW = get_const_tuple(tinfos[1].shape)
VC = cfg['tile_k'].size[-1]
weight = sym.reshape(weight, weight = sym.reshape(weight,
shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3]) weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
......
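For the winograd branch of the layout alteration above, the reshape and transpose arithmetic can be checked with shapes assumed purely for illustration (a 3x3 kernel, tile_size 4, CO = CI = 64, VC = 8): the transformed weight has spatial extent KH + tile_size - 1 = 6, and the OIHW weight ends up as a (6, 6, CO // VC, CI, VC) tensor after the transpose.

# Shapes assumed for illustration only:
CO, CI, KH, KW, tile_size, VC = 64, 64, 3, 3, 4, 8
reshaped = (KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)
transposed = tuple(reshaped[i] for i in [0, 1, 2, 4, 3])
print(reshaped)    # (6, 6, 8, 8, 64)
print(transposed)  # (6, 6, 8, 64, 8)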
...@@ -14,16 +14,21 @@ autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct', ...@@ -14,16 +14,21 @@ autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
# register customized schedule for arm cpu. # register customized schedule for arm cpu.
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct') @autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct')
def schedule_depthwise_conv2d_nchw_(cfg, outs): def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d """Schedule depthwise conv2d
Parameters Parameters
---------- ----------
cfg: ConfigEntity cfg: ConfigEntity
The configuration of this tempalte The configuration of this template
outs: Array of Tensor outs: Array of Tensor
The computation graph description of depthwise convolution2d The computation graph description of depthwise convolution2d
in the format of an array of tensors. in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nchw.
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
...@@ -38,6 +43,11 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs): ...@@ -38,6 +43,11 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
cfg.define_split('tile_h', h, num_outputs=2) cfg.define_split('tile_h', h, num_outputs=2)
cfg.define_split('tile_w', w, num_outputs=2) cfg.define_split('tile_w', w, num_outputs=2)
if cfg.is_fallback:
cfg.fallback_split('tile_c', [-1, 8])
cfg.fallback_split('tile_h', [-1, 2])
cfg.fallback_split('tile_w', [-1, 8])
# pack data to vector form [n, c, h, w] -> [n, C, h, w, VC] # pack data to vector form [n, c, h, w] -> [n, C, h, w, VC]
A0 = s.cache_read(data_pad, "global", C) A0 = s.cache_read(data_pad, "global", C)
_, c, h, w = s[A0].op.axis _, c, h, w = s[A0].op.axis
......
...@@ -29,7 +29,7 @@ def schedule_injective(outs): ...@@ -29,7 +29,7 @@ def schedule_injective(outs):
elif len(s[x].op.axis) >= 3: elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused) s[x].parallel(fused)
else: elif len(s[x].op.axis) >= 1:
s[x].parallel(s[x].op.axis[0]) s[x].parallel(s[x].op.axis[0])
return s return s
......
"""Common utility for topi test"""
def get_all_backend():
"""return all supported target
Returns
-------
targets: list
A list of all supported targets
"""
return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
'llvm -device=arm_cpu']
import os
import numpy as np import numpy as np
import tvm import tvm
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple from topi.util import get_const_tuple
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
def generate_quantized_np(shape, bits, out_dtype): def generate_quantized_np(shape, bits, out_dtype):
...@@ -16,23 +13,23 @@ def generate_quantized_np(shape, bits, out_dtype): ...@@ -16,23 +13,23 @@ def generate_quantized_np(shape, bits, out_dtype):
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa): activation_bits, weight_bits, dorefa):
in_height = in_width = in_size in_height = in_width = in_size
input_type='uint32' input_type = 'uint32'
out_dtype='int32' out_dtype = 'int32'
with tvm.target.create('llvm'): with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa) out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nchw([B]) s = topi.generic.schedule_bitserial_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
def get_ref_data(): def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa: if dorefa:
w_ = np.copy(w_np).astype(out_dtype) w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']): for x in np.nditer(w_, op_flags=['readwrite']):
...@@ -61,16 +58,16 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, ...@@ -61,16 +58,16 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa) layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
def get_ref_data(): def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa: if dorefa:
w_ = np.copy(w_np).astype(out_dtype) w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']): for x in np.nditer(w_, op_flags=['readwrite']):
...@@ -109,4 +106,4 @@ def test_bitserial_conv2d(): ...@@ -109,4 +106,4 @@ def test_bitserial_conv2d():
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
if __name__ == "__main__": if __name__ == "__main__":
test_bitserial_conv2d() test_bitserial_conv2d()
\ No newline at end of file
...@@ -4,10 +4,6 @@ import numpy as np ...@@ -4,10 +4,6 @@ import numpy as np
import tvm import tvm
import topi import topi
import topi.testing import topi.testing
from topi.util import get_const_tuple
from tvm.contrib import util
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
def generate_quantized_np(shape, bits, out_dtype): def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0) np.random.seed(0)
...@@ -17,20 +13,19 @@ def generate_quantized_np(shape, bits, out_dtype): ...@@ -17,20 +13,19 @@ def generate_quantized_np(shape, bits, out_dtype):
# Verify that certain special instructions from the tensorize pass exist # Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa): activation_bits, weight_bits, dorefa):
in_height = in_width = in_size in_height = in_width = in_size
input_type='uint32' input_type = 'uint32'
out_dtype='int32' out_dtype = 'int32'
with tvm.target.arm_cpu('rasp3b'): with tvm.target.arm_cpu('rasp3b'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa) layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))
func = tvm.build(s, [A, W, B], target)
assembly = func.get_source('asm') assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly) matches = re.findall("vpadal", assembly)
...@@ -47,7 +42,6 @@ def test_bitserial_conv2d(): ...@@ -47,7 +42,6 @@ def test_bitserial_conv2d():
stride = 1 stride = 1
pad = 1 pad = 1
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
......
...@@ -28,7 +28,7 @@ def verify_binary_dense(batch, in_dim, out_dim): ...@@ -28,7 +28,7 @@ def verify_binary_dense(batch, in_dim, out_dim):
a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype) a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype)
b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype) b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype)
c_np = np.dot(a_np, b_np.T) c_np = np.dot(a_np, b_np.T)
return (a_np, b_np, c_np) return a_np, b_np, c_np
a_np, b_np, c_np = get_ref_data() a_np, b_np, c_np = get_ref_data()
......
"""Test code for broadcasting operators.""" """Test code for broadcasting operators."""
import os from common import get_all_backend
import numpy as np import numpy as np
import tvm import tvm
import topi import topi
...@@ -8,6 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast): ...@@ -8,6 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
B = fbcast(A, out_shape) B = fbcast(A, out_shape)
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -21,16 +22,11 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast): ...@@ -21,16 +22,11 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
out_npy = np.broadcast_to(data_npy, out_shape) out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx) data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx) out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1): foo(data_nd, out_nd)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("vulkan") for target in get_all_backend():
check_device("opencl") check_device(target)
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("nvptx")
check_device("sdaccel") check_device("sdaccel")
...@@ -45,9 +41,10 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, ...@@ -45,9 +41,10 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
B = (tvm.var("B", dtype=dtype) if rhs_shape is None B = (tvm.var("B", dtype=dtype) if rhs_shape is None
else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype)) else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
C = ftopi(A, B) C = ftopi(A, B)
if (isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr)): if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
assert(isinstance(C, tvm.expr.Expr)) assert(isinstance(C, tvm.expr.Expr))
return return
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -82,12 +79,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, ...@@ -82,12 +79,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
foo(lhs_nd, rhs_nd, out_nd) foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("opencl") for target in get_all_backend():
check_device("vulkan") check_device(target)
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("nvptx")
check_device("sdaccel") check_device("sdaccel")
def test_broadcast_to(): def test_broadcast_to():
......
...@@ -5,6 +5,7 @@ import topi ...@@ -5,6 +5,7 @@ import topi
from topi.util import get_const_tuple from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from common import get_all_backend
def verify_clip(N, a_min, a_max, dtype): def verify_clip(N, a_min, a_max, dtype):
A = tvm.placeholder((N, N), dtype=dtype, name='A') A = tvm.placeholder((N, N), dtype=dtype, name='A')
...@@ -34,7 +35,7 @@ def verify_clip(N, a_min, a_max, dtype): ...@@ -34,7 +35,7 @@ def verify_clip(N, a_min, a_max, dtype):
f(a, b) f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'opencl', 'sdaccel']: for device in get_all_backend():
check_device(device) check_device(device)
def test_clip(): def test_clip():
......
"""Example code to do conv2d."""
import os
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding):
in_height = in_width = in_size
with tvm.target.arm_cpu():
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), 'NCHW', 'float32')
s = topi.generic.schedule_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d.verify_conv2d")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], "llvm")
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_conv2d():
with autotvm.tophub.context(tvm.target.arm_cpu('rasp3b'), allow_fallback=True):
verify_conv2d(1, 56, 64, 64, 3, 1, 1)
if __name__ == "__main__":
test_conv2d()
...@@ -43,14 +43,12 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -43,14 +43,12 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=128, func1 = tvm.build(s1, [A, W, B], device)
unroll_explicit=(device != "cuda")): func2 = tvm.build(s2, [A, W, C], device)
func1 = tvm.build(s1, [A, W, B], device) func1(a, w, b)
func2 = tvm.build(s2, [A, W, C], device) func2(a, w, c)
func1(a, w, b) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
func2(a, w, c) np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device) check_device(device)
......
"""Example code to do convolution.""" """Example code to do convolution."""
import os
import numpy as np import numpy as np
import tvm import tvm
from tvm import autotvm
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): from common import get_all_backend
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw") @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
def get_ref_data(): def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype) a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0) if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data() a_np, w_np, b_np, c_np = get_ref_data()
...@@ -38,66 +48,103 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -38,66 +48,103 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
dW = topi.nn.dilate(W, (1, 1, dilation, dilation)) dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW') C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype)
C = topi.nn.relu(B) if add_bias:
s1 = topi.generic.schedule_conv2d_nchw([B]) C = topi.add(C, bias)
s2 = topi.generic.schedule_conv2d_nchw([C]) if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
no_unroll_explicit = device in ["cuda", "nvptx", "rocm"] if add_bias:
with tvm.build_config(auto_unroll_max_step=1400, func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
unroll_explicit=not no_unroll_explicit): func(a, w, b, c)
func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) else:
func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func1(a, w, b) func(a, w, c)
func2(a, w, c) np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) for device in get_all_backend():
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device) check_device(device)
def test_conv2d_nchw(): def test_conv2d_nchw():
autotvm.DispatchContext.current.silent = True
# ResNet18 workloads # ResNet18 workloads
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
# ResNet50 workloads
verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0) # bias, relu
verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0)
verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0)
verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0)
verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0)
verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0)
verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0)
verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0)
# Vgg16 workloads
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
# Super resolution workloads
verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2)
verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
# dilation = 2 # dilation = 2
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2)
# weird workloads
verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1)
verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2)
# inception v3 workloads
verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0)
verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0)
verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0)
verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2)
verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1)
verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1)
verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0)
verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0)
# verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0)
# verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0)
# verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0)
# verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3)
# verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0)
# verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0)
# verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3)
# verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1)
verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0)
verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1)
verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0)
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d_nchw() test_conv2d_nchw()
...@@ -6,14 +6,13 @@ import topi.testing ...@@ -6,14 +6,13 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple from topi.util import get_const_tuple
from common import get_all_backend
def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W') W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding, A.dtype)
C = topi.nn.relu(B)
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
...@@ -36,22 +35,23 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, ...@@ -36,22 +35,23 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_transpose_nchw([B]) s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C]) s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=128,
unroll_explicit=(device != "cuda")):
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in get_all_backend():
check_device(device) check_device(device)
......
...@@ -6,13 +6,12 @@ import topi.testing ...@@ -6,13 +6,12 @@ import topi.testing
from topi.util import get_const_tuple from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from common import get_all_backend
def verify_dense(batch, in_dim, out_dim, use_bias=True): def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = tvm.placeholder((batch, in_dim), name='A') A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B') B = tvm.placeholder((out_dim, in_dim), name='B')
C = tvm.placeholder((out_dim,), name='C') C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D)
dtype = A.dtype dtype = A.dtype
# use memoize to pickle the test data for next time use # use memoize to pickle the test data for next time use
...@@ -36,6 +35,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): ...@@ -36,6 +35,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D)
s = topi.generic.schedule_dense(D) s = topi.generic.schedule_dense(D)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
...@@ -45,13 +46,15 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): ...@@ -45,13 +46,15 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
f(a, b, c, d) f(a, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: for device in get_all_backend():
check_device(device) check_device(device)
def test_dense(): def test_dense():
verify_dense(1, 1024, 1000, use_bias=True) verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False) verify_dense(1, 1024, 1000, use_bias=False)
verify_dense(2, 1024, 1000, use_bias=True)
if __name__ == "__main__": if __name__ == "__main__":
test_dense() test_dense()
...@@ -2,11 +2,10 @@ import tvm ...@@ -2,11 +2,10 @@ import tvm
import topi import topi
import topi.testing import topi.testing
import numpy as np import numpy as np
from scipy import signal
from topi.util import get_const_tuple from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
from common import get_all_backend
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1): def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
in_width = in_height in_width = in_height
...@@ -18,10 +17,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -18,10 +17,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter') DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -30,6 +25,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -30,6 +25,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule # schedule
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d) s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift) s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
...@@ -88,12 +87,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -88,12 +87,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
check_device("opencl") for device in get_all_backend():
check_device("cuda") check_device(device)
check_device("metal")
check_device("rocm")
check_device("vulkan")
check_device("nvptx")
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1): def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
...@@ -107,11 +102,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -107,11 +102,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter') DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -121,6 +111,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -121,6 +111,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d) s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift) s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu) s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
...@@ -180,12 +175,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -180,12 +175,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
check_device("opencl") for device in get_all_backend():
check_device("cuda") check_device(device)
check_device("metal")
check_device("rocm")
check_device("vulkan")
check_device("nvptx")
def test_depthwise_conv2d(): def test_depthwise_conv2d():
print("testing nchw") print("testing nchw")
......
...@@ -312,7 +312,9 @@ def tune_and_evaluate(): ...@@ -312,7 +312,9 @@ def tune_and_evaluate():
# upload module to device # upload module to device
print("Upload...") print("Upload...")
remote = autotvm.measure.request_remote(device_key, timeout=10000) remote = autotvm.measure.request_remote(device_key,
tracker_addr=('localhost', 9190),
timeout=10000)
remote.upload(tmp.relpath(filename)) remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename) rlib = remote.load_module(filename)
...@@ -333,7 +335,6 @@ def tune_and_evaluate(): ...@@ -333,7 +335,6 @@ def tune_and_evaluate():
# We do not run the tuning in our webpage server since it takes too long. # We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself. # Uncomment the following line to run by yourself.
# tune_and_evaluate() # tune_and_evaluate()
###################################################################### ######################################################################
......