Commit 6ea74d41 by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Core part of auto-tuning module (#1312)

parent 7e7154f1
...@@ -96,6 +96,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) ...@@ -96,6 +96,7 @@ assign_source_group("Include" ${GROUP_INCLUDE})
file(GLOB COMPILER_SRCS file(GLOB COMPILER_SRCS
src/api/*.cc src/api/*.cc
src/arithmetic/*.cc src/arithmetic/*.cc
src/autotvm/*.cc
src/codegen/*.cc src/codegen/*.cc
src/codegen/stack_vm/*.cc src/codegen/stack_vm/*.cc
src/lang/*.cc src/lang/*.cc
......
tvm.autotvm
-----------
.. automodule:: tvm.autotvm
tvm.autotvm.measure
~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.measure.measure
.. autoclass:: tvm.autotvm.measure.MeasureInput
:members:
.. autoclass:: tvm.autotvm.measure.MeasureResult
:members:
.. autofunction:: tvm.autotvm.measure.measure_option
.. autofunction:: tvm.autotvm.measure.create_measure_batch
tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.tuner
:members:
.. autoclass:: tvm.autotvm.tuner.Tuner
:members:
.. autoclass:: tvm.autotvm.tuner.RandomTuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.GridSearchTuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.GATuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.XGBTuner
:members:
:inherited-members:
.. automodule:: tvm.autotvm.tuner.callback
:members:
tvm.autotvm.task
~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.task
:members:
.. automodule:: tvm.autotvm.task.task
:members:
.. automodule:: tvm.autotvm.task.space
:members:
tvm.autotvm.record
~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.record
:members:
...@@ -14,6 +14,7 @@ Python API ...@@ -14,6 +14,7 @@ Python API
ndarray ndarray
container container
function function
autotvm
graph_runtime graph_runtime
rpc rpc
bridge bridge
......
...@@ -191,6 +191,7 @@ gallery_dirs = ["tutorials", "vta/tutorials"] ...@@ -191,6 +191,7 @@ gallery_dirs = ["tutorials", "vta/tutorials"]
subsection_order = ExplicitOrder( subsection_order = ExplicitOrder(
['../tutorials/language', ['../tutorials/language',
'../tutorials/optimize', '../tutorials/optimize',
'../tutorials/autotvm',
'../tutorials/vta', '../tutorials/vta',
'../tutorials/topi', '../tutorials/topi',
'../tutorials/deployment', '../tutorials/deployment',
......
...@@ -488,7 +488,7 @@ bool VerifyMemory(LoweredFunc func, int device_type); ...@@ -488,7 +488,7 @@ bool VerifyMemory(LoweredFunc func, int device_type);
* *
* "max_local_memory_per_block": Total amount of local memory per block (in bytes). * "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_thread_per_block": Maximum number of threads per block. * "max_threads_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x. * "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y. * "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z. * "max_thread_z": Maximum length of threadIdx.z.
......
"""The auto-tuning module of tvm
This module includes:
* Tuning space definition API
* Efficient auto-tuners
* Tuning result and database support
* Distributed measurement to scale up tuning
"""
from . import database
from . import feature
from . import measure
from . import record
from . import task
from . import tuner
from . import util
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity
from .record import ApplyHistoryBest as apply_history_best
# pylint: disable=consider-using-enumerate,invalid-name
"""
Database of MeasureInput/MeasureResult pair.
This can be used for replaying measurement.
"""
import os
from .record import encode, decode, measure_str_key
class Database(object):
"""
Base class for a record database object.
"""
def load(self, inp, get_all=False):
"""
Load a result based on an input's string key
Parameters
----------
inp: MeasureInput
to be translated into key for RedisDB
get_all: bool, optional
Whether the latest result (or all matching results) should be returned
Returns
-------
rec: MeasureResult if previously saved, otherwise None
"""
raise NotImplementedError()
def save(self, inp, res, extend=False):
"""
Save a result based on an input's string key
Parameters
----------
inp: MeasureInput
to be translated into key for RedisDB
res: MeasureResult
to associate with key
extend:
Whether to extend existing MeasureResults if they exist
"""
raise NotImplementedError()
def filter_inputs(db, measure_inputs, retry=False):
"""
Filter a measure_inputs batch based on saved db results
Parameters
----------
db: Database
database object
measure_inputs: Array of MeasureInput
measure_inputs as expected in measure_batch
retry: bool
whether to retry if the saved result is a failure
Returns
-------
partial_results: Array of MeasureResult
a full list of result, where None denotes no corresponding saved result
unsaved: Array of MeasureInput
a list that only contains unsaved inputs
"""
partial_results = list()
unsaved = list()
for inp in measure_inputs:
res = db.load(inp)
if res is None or (retry and res.error_no != 0):
unsaved.append(inp)
partial_results.append(None)
else:
partial_results.append(res)
return partial_results, unsaved
class RedisDatabase(Database):
"""
Redis version of record database
"""
REDIS_PROD = 15
REDIS_LOCA = 14
REDIS_TEST = 13 # for unit test
REDIS_NIGHT_TEMP = 12 # for nightly report (will be flushed after every workload)
MAGIC_SPLIT = "$"
def __init__(self, db_index=REDIS_PROD):
import redis
if db_index == RedisDatabase.REDIS_TEST:
host = 'localhost'
else:
host = os.environ.get('TVM_FLEET_HOST')
self.db = redis.StrictRedis(host=host, port=6379, db=db_index)
self.db_index = db_index
def set(self, key, value):
self.db.set(key, value)
def get(self, key):
return self.db.get(key)
def load(self, inp, get_all=False):
current = self.get(measure_str_key(inp))
if current is not None:
current = str(current)
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
results = [rec[1] for rec in records]
if get_all:
return results
return max(results, key=lambda result: result.timestamp)
return current
def save(self, inp, res, extend=False):
current = self.get(measure_str_key(inp))
if not extend or current is None:
self.set(measure_str_key(inp),
RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)]))
else:
current = current.split(RedisDatabase.MAGIC_SPLIT)
self.set(measure_str_key(inp),
RedisDatabase.MAGIC_SPLIT.join(current + [encode(inp, res)]))
def filter(self, func):
"""
Dump all of the records for a particular target
Parameters
----------
func: callable
The signature of the function is bool (MeasureInput, Array of MeasureResult)
Returns
-------
list of records (inp, result) matching the target
Examples
--------
get records for a target
>>> db.filter(lambda inp, resulst: "cuda" in inp.target.keys)
"""
matched_records = list()
# may consider filtering in iterator in the future
for key in self.db:
current = self.get(key)
try:
records = [decode(x) for x in current.spilt(RedisDatabase.MAGIC_SPLIT)]
except TypeError: # got a badly formatted/old format record
continue
inps, results = zip(*records)
inp = inps[0]
if not func(inp, results):
continue
result = max(results, key=lambda res: res.timestamp)
matched_records.append((inp, result))
return matched_records
def flush(self):
self.db.flushdb()
class DummyDatabase(RedisDatabase):
"""
A database based on python dictionary for testing.
"""
def __init__(self):
# pylint: disable=super-init-not-called
self.db = {}
def set(self, key, value):
self.db[key] = value
def get(self, key):
return self.db.get(key)
def flush(self):
self.db = {}
"""Global configuration/variable scope for autotvm"""
class AutotvmGlobalScope(object):
current = None
def __init__(self):
self._old = AutotvmGlobalScope.current
AutotvmGlobalScope.current = self
self.cuda_target_arch = None
GLOBAL_SCOPE = AutotvmGlobalScope()
# pylint: disable=invalid-name
"""Extract feature of iter vars
There are two types of feature
1) Itervar feature
This feature is extracted based on loop variables.
Different loop structures will result in different shapes of feature
2) Curve sample feature (relation feature)
This feature is extracted by sampling relation curve.
This feature is invariant of loop structure.
"""
import struct
import numpy as np
from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
def ana_lower(sch, args,
binds=None,
simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
"""
binds, _ = build_module.get_binds(args, binds)
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt)
assert simple_mode
return stmt
try:
_get_buffer_curve_sample_flatten = get_global_func(
"autotvm.feature.GetCurveSampleFeatureFlatten")
_get_itervar_feature = get_global_func("autotvm.feature.GetItervarFeature")
_get_itervar_feature_flatten = get_global_func("autotvm.feature.GetItervarFeatureFlatten")
except ValueError as e:
def raise_error(*args, **kwargs): # pylint: disable=unused-argument
raise RuntimeError("Cannot load autotvm c++ API")
_get_buffer_curve_sample_flatten = _get_itervar_feature = _get_itervar_feature_flatten = \
raise_error
def get_itervar_feature(sch, args, take_log=False):
"""get features of iter vars
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
take_log: bool
whether take log of numerical statics
Returns
-------
features of every axis in the IR, see doc/features.md for detail
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_itervar_feature(stmt, take_log)
# convert tvm node to python type
ret = []
for row in feas:
tmp = []
tmp.append([row[0][0].value, row[0][1]])
for item in row[1:]:
tmp.append([item[0].value] + [x.value for x in item[1:]])
ret.append(tmp)
return ret
def flatten_itervar_feature(fea):
"""flatten features into one-dimensional feature vectors
Parameters
----------
fea: list
return value of get_itervar_feature
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
flatten = []
for axis in fea:
for pair in axis[1:]:
flatten.append(pair[1:])
return np.concatenate(flatten)
def get_itervar_feature_flatten(sch, args, take_log=True):
"""get flatten features of iter vars
this is equivalent to get_itervar_feature + flatten_itervar_feature, but much faster.
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
take_log: bool
whether take log of numerical statics
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_itervar_feature_flatten(stmt, take_log)
feas = struct.unpack('%df' % (len(feas)//4), feas)
return feas
def get_flatten_name(fea):
""" Get names of feature after flatten.
Parameters
----------
fea: list or str
return value of get_itervar_feature or a line of logfile
Returns
-------
feature_names: Array of str
"""
feature_name = {
"_attr_": ["length", "nest_level", "topdown", "bottomup"] +
["ann_%d" % i for i in range(20)],
"_arith_": ["add", "mul", "div"],
"buf_touch": ["stride", "mod", "count", "reuse", "T_count", "T_reuse"],
}
if isinstance(fea, str):
from .record import decode
# flatten line to feature
line = fea
inp, _ = decode(line)
target = _target.create(inp.target)
with target:
s, args = inp.template.instantiate(inp.config)
fea = get_itervar_feature(s, args)
names = []
ct = 0
for row in fea:
var_name = str(row[0][1])
for pair in row[1:]:
key = pair[0]
if key in feature_name:
name_list = feature_name[key]
else:
name_list = feature_name["buf_touch"]
for i in range(len((pair[1:]))):
names.append(".".join(["f%d" % ct, var_name, key, name_list[i]]))
ct += 1
return names
def get_buffer_curve_sample_flatten(sch, args, sample_n=30):
"""
Get flatten curve sample feature (relation feature)
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
sample_n: int
number of sample points along one dimension
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_buffer_curve_sample_flatten(stmt, sample_n, False)
feas = struct.unpack('%df' % (len(feas)//4), feas)
return feas
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo
from .measure import create_measure_batch, measure_option
from .measure_methods import request_remote
from .local_executor import LocalExecutor
from .executor import Future, Executor
""" Abstraction for asynchronous job execution """
class Executor(object):
"""
Base abstract executor interface for asynchronous job submission.
Allows submit asynchronous jobs and returns the Future object.
"""
# timeout for jobs that may hang
DEFAULT_TIMEOUT = 60
def submit(self, func, *args, **kwargs):
"""
Pass task (function, arguments) to the Executor.
Parameters
----------
func : callable
function to be run by a worker
args : list or tuple, optional
arguments passed to the function
kwargs : dict, optional
The keyword arguments
Returns
-------
future : Future
Future object wrapping the task which can be used to
collect the task's result.
"""
raise NotImplementedError()
class Future(object):
"""
Base class of the future object.
The implementations can return object of subclass of this.
This objects encapsulates the asynchronous execution of task
submitted to another thread, or another worker for execution.
Future objects store the state of tasks--can be polled for
result or a blocking call to retrieve the result can be used.
"""
def done(self):
"""
Return True if job was successfully cancelled or finished running.
"""
raise NotImplementedError()
def get(self, timeout=None):
"""
Get the result. This will block until the result is available.
Parameters
----------
timeout : int or float, optional
Maximum number of seconds to wait before it timeouts.
If not specified, it means we block until the result is available.
Returns
-------
result : Any
The result returned by the submitted function.
Raises
------
TimeoutError : if the result call timeouts.
"""
raise NotImplementedError()
class FutureError(RuntimeError):
"""Base error class of all future events"""
pass
# pylint:disable=redefined-builtin
class TimeoutError(FutureError):
"""Error raised when a task is timeout."""
pass
class ExecutionError(FutureError):
"""
Error raised when future execution crashes or failed.
"""
pass
"""Local based implementation of the executor using multiprocessing"""
import signal
from multiprocessing import Process, Queue
try:
from queue import Empty
except ImportError:
from Queue import Empty
import psutil
from . import executor
def kill_child_processes(parent_pid, sig=signal.SIGTERM):
"""kill all child processes recursively"""
try:
parent = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
for process in children:
try:
process.send_signal(sig)
except psutil.NoSuchProcess:
return
def _execute_func(func, queue, args, kwargs):
"""execute function and return the result or exception to a queue"""
try:
res = func(*args, **kwargs)
except Exception as exc: # pylint: disable=broad-except
res = exc
queue.put(res)
def timeout_monitor(queue, timeout, func, args, kwargs):
"""A wrapper to support timeout of a function call"""
# start a new process for timeout (cannot use thread because we have c function)
p = Process(target=_execute_func, args=(func, queue, args, kwargs))
p.start()
p.join(timeout=timeout)
alive = p.is_alive()
kill_child_processes(p.pid)
p.terminate()
p.join()
if alive:
queue.put(executor.TimeoutError())
else:
if queue.empty():
queue.put(executor.ExecutionError("Fatal error in local executor"))
class LocalFuture(executor.Future):
"""Local wrapper for the future
Parameters
----------
process: multiprocessing.Process
process for running this task
queue: multiprocessing.Queue
queue for receiving the result of this task
"""
def __init__(self, process, queue):
self._done = False
self._process = process
self._queue = queue
def done(self):
self._done = self._done or not self._queue.empty()
return self._done
def get(self, timeout=None):
try:
res = self._queue.get(block=True, timeout=timeout)
except Empty:
raise executor.TimeoutError()
if self._process.is_alive():
kill_child_processes(self._process.pid)
self._process.terminate()
self._process.join()
self._queue.close()
self._queue.join_thread()
self._done = True
del self._queue
del self._process
return res
class LocalFutureNoFork(executor.Future):
"""Local wrapper for the future.
This is a none-fork version of LocalFuture.
Use this for the runtime that does not support fork (like cudnn)
"""
def __init__(self, result):
self._result = result
def done(self):
return True
def get(self, timeout=None):
return self._result
class LocalExecutor(executor.Executor):
"""Local executor that runs workers on the same machine with multiprocessing."""
def __init__(self, timeout=None):
self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
def submit(self, func, *args, **kwargs):
"""
Note
----------
By default, the executor will fork a new process for a new job
But some runtime does not support fork (e.g. cuda runtime, cudnn).
In this circumstance, you should set 'fork_new_process' to False in kwargs
"""
fork_new_process = kwargs.pop('fork_new_process', True)
if not fork_new_process:
return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(1)
process = Process(target=timeout_monitor,
args=(queue, self.timeout, func, args, kwargs))
process.start()
return LocalFuture(process, queue)
"""Task is a tunable composition of template functions.
Tuner takes a tunable task and optimizes the joint configuration
space of all the template functions in the task.
This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest.
"""
from .task import Task, create, register, template, get_config
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
"""
Decorator functions for hashing schedule code
code hashing is used to check the consistence of schedule code and the parameters loaded from log
"""
import inspect
import zlib
from tvm import schedule
def attach_code_hash(s):
"""Decorator for attaching a code hash to a schedule
Parameters
----------
s: Schedule
tvm.schedule.Schedule to attach the hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
func(*args, **kwargs)
raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
s.code_hash = hex(raw_hash)[2:]
return wrapper
return decorator
def attach_code_hash_to_arg(arg_idx=1):
"""Decorator for attaching a code hash to a schedule
Parameters
----------
arg_idx: int
index of the argument (expected to be a Schedule) to attach the code
hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
func(*args, **kwargs)
assert isinstance(args[arg_idx], schedule.Schedule)
raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
args[arg_idx].code_hash = hex(raw_hash)[2:]
return wrapper
return decorator
"""
Template dispatcher module.
A dispatcher is a function that can contains multiple behaviors.
Its specific behavior is can be controlled by DispatchContext.
DispatchContext is used in two ways, usually via different implementation
of the DispatchContext base class.
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to set pick the best policy.
"""
from __future__ import absolute_import as _abs
from decorator import decorate
from tvm import target as _target
class DispatchContext(object):
"""
Base class of dispatch context.
DispatchContext enables the target and workload
specific dispatch mechanism for templates.
"""
current = None
def query(self, target, workload):
"""
Query the context to get the specific implementation.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
raise NotImplementedError()
def __enter__(self):
self._old_ctx = DispatchContext.current
DispatchContext.current = self
return self
def __exit__(self, ptype, value, trace):
DispatchContext.current = self._old_ctx
class ApplyConfig(DispatchContext):
"""Apply a specific config entity during query.
Parameters
----------
config : ConfigSpace or ConfigEntity
The specific configuration we care about.
"""
def __init__(self, config):
super(ApplyConfig, self).__init__()
self._config = config
self.workload = None
def query(self, target, workload):
"""Override query"""
self.workload = workload
return self._config
def dispatcher(fworkload):
"""Wrap a workload dispatcher function.
Parameters
----------
fworkload : function
The workload extraction function from arguments.
Returns
-------
fdispatcher : function
A wrapped dispatcher function, which will
dispatch based on DispatchContext and
the current workload.
"""
dispatch_dict = {}
func_name = fworkload.__name__
def register(key, func=None, override=False):
"""Register template function.
Parameters
----------
key : str or List of str
The template key to identify the template
under this dispatcher.
func : function
The function to be registered.
The first argument of the function is always
cfg returned by DispatchContext,
the rest arguments are the same as the fworkload.
override : bool
Whether override existing registration.
Returns
-------
The register function if necessary.
"""
if isinstance(key, str):
key = [key]
def _do_reg(myf):
for x in key:
if x in dispatch_dict and not override:
raise ValueError(
"Key %s is already registered for %s" % (x, func_name))
dispatch_dict[x] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispatch function"""
tgt = _target.current_target()
context = DispatchContext.current
if context is None:
raise RuntimeError("DispatchContext is not initialized")
workload = func(*args, **kwargs)
cfg = context.query(tgt, workload)
return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
fdecorate = decorate(fworkload, dispatch_func)
fdecorate.register = register
return fdecorate
"""
A tuner takes a task as input. It proposes some promising :any:`ConfigEntity`
in the :any:`ConfigSpace` and measure them on the real hardware. Then it
proposed the next batch of :any:`ConfigEntity` according to the measure results.
This tuning loop is repeated.
"""
from . import callback
from .tuner import Tuner
from .gridsearch_tuner import GridSearchTuner, RandomTuner
from .ga_tuner import GATuner
from .xgboost_tuner import XGBTuner
# pylint: disable=consider-using-enumerate,invalid-name
"""Namespace of callback utilities of AutoTVM"""
import numpy as np
from .. import record
def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
The rows of the log are stored in the format of autotvm.record.encode.
Parameters
----------
file_out : File or str
The file to log to.
protocol: str, optional
The log protocol. Can be 'json' or 'pickle'
Returns
-------
callback : callable
Callback function to do the logging.
"""
def _callback(_, inputs, results):
"""Callback implementation"""
if isinstance(file_out, str):
with open(file_out, "a") as f:
for inp, result in zip(inputs, results):
f.write(record.encode(inp, result, protocol) + "\n")
else:
for inp, result in zip(inputs, results):
file_out.write(record.encode(inp, result, protocol) + "\n")
return _callback
def save_tuner_state(prefix, save_every_sample=100):
"""Save the state of tuner
Parameters
----------
prefix : srt
prefix of the filename to store state
save_every_sample: int
save the state every x samples
Returns
-------
callback : function
Callback function to do the auto saving.
"""
def _callback(tuner, inputs, results):
for _, __ in zip(inputs, results):
try:
ct = len(tuner.visited)
except AttributeError:
ct = 0
if ct % save_every_sample == 0:
tuner.save_state(prefix + "_%d.state" % ct)
return _callback
def log_to_redis(host="localhost", port=6379, dbn=11):
"""Record the tuning record to a redis DB.
Parameters
----------
host: str, optional
Host address of redis db
port: int, optional
Port of redis db
dbn: int, optional
which redis db to use, default 11
"""
# import here so only depend on redis when necessary
import redis
red = redis.StrictRedis(host=host, port=port, db=dbn)
def _callback(_, inputs, results):
"""Callback implementation"""
for inp, result in zip(inputs, results):
red.set(inp, result)
return _callback
class Monitor(object):
"""A monitor to collect statistic during tuning"""
def __init__(self):
self.scores = []
self.timestamps = []
def __call__(self, tuner, inputs, results):
for inp, res in zip(inputs, results):
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
self.scores.append(flops)
else:
self.scores.append(0)
self.timestamps.append(res.timestamp)
def reset(self):
self.scores = []
self.timestamps = []
def trial_scores(self):
"""get scores (currently is flops) of all trials"""
return np.array(self.scores)
def trial_timestamps(self):
"""get wall clock time stamp of all trials"""
return np.array(self.timestamps)
# pylint: disable=consider-using-enumerate,invalid-name,abstract-method
"""Tuner with genetic algorithm"""
import numpy as np
from .tuner import Tuner
from .model_based_tuner import knob2point, point2knob
class GATuner(Tuner):
"""Tuner with genetic algorithm.
This tuner does not have a cost model so it always run measurement on real machines.
This tuner expands the :code:`ConfigEntity` as gene.
Parameters
----------
pop_size: int
number of genes in one generation
elite_num: int
number of elite to keep
mutation_prob: float
probability of mutation of a knob in a gene
"""
def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
super(GATuner, self).__init__(task)
# algorithm configurations
self.pop_size = pop_size
self.elite_num = elite_num
self.mutation_prob = mutation_prob
assert elite_num <= pop_size, "The number of elites must be less than population size"
# space info
self.space = task.config_space
self.dims = [len(x) for x in self.space.space_map.values()]
self.visited = set([])
# current generation
self.genes = []
self.scores = []
self.elites = []
self.elite_scores = []
self.trial_pt = 0
# random initialization
self.pop_size = min(self.pop_size, len(self.space))
for _ in range(self.pop_size):
tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
while knob2point(tmp_gene, self.dims) in self.visited:
tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
self.genes.append(tmp_gene)
self.visited.add(knob2point(tmp_gene, self.dims))
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
gene = self.genes[self.trial_pt % self.pop_size]
self.trial_pt += 1
ret.append(self.space.get(knob2point(gene, self.dims)))
return ret
def update(self, inputs, results):
for inp, res in zip(inputs, results):
if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
self.scores.append(y)
else:
self.scores.append(0)
if len(self.scores) >= len(self.genes):
genes = self.genes + self.elites
scores = np.array(self.scores[:len(self.genes)] + self.elite_scores)
# reserve elite
self.elites, self.elite_scores = [], []
elite_indexes = np.argpartition(scores, -self.elite_num)[-self.elite_num:]
for ind in elite_indexes:
self.elites.append(genes[ind])
self.elite_scores.append(scores[ind])
# cross over
indices = np.arange(len(genes))
scores /= np.max(scores)
probs = scores / np.sum(scores)
tmp_genes = []
for _ in range(self.pop_size):
p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
p1, p2 = genes[p1], genes[p2]
point = np.random.randint(len(self.dims))
tmp_gene = p1[:point] + p2[point:]
tmp_genes.append(tmp_gene)
# mutation
next_genes = []
for tmp_gene in tmp_genes:
for j, dim in enumerate(self.dims):
if np.random.random() < self.mutation_prob:
tmp_gene[j] = np.random.randint(dim)
if len(self.visited) < len(self.space):
while knob2point(tmp_gene, self.dims) in self.visited:
j = np.random.randint(len(self.dims))
tmp_gene[j] = np.random.randint(self.dims[j])
next_genes.append(tmp_gene)
self.visited.add(knob2point(tmp_gene, self.dims))
else:
break
self.genes = next_genes
self.trial_pt = 0
self.scores = []
def has_next(self):
return len(self.visited) - (len(self.genes) - self.trial_pt) < len(self.space)
# pylint: disable=abstract-method
"""Grid search tuner and random tuner"""
import numpy as np
from .tuner import Tuner
class GridSearchTuner(Tuner):
"""Enumerate the search space in a grid search order"""
def __init__(self, task):
super(GridSearchTuner, self).__init__(task)
self.counter = 0
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.counter >= len(self.task.config_space):
continue
index = self.counter
ret.append(self.task.config_space.get(index))
self.counter = self.counter + 1
return ret
def has_next(self):
return self.counter < len(self.task.config_space)
def __getstate__(self):
return {"counter": self.counter}
def __setstate__(self, state):
self.counter = state['counter']
class RandomTuner(Tuner):
"""Enumerate the search space in a random order"""
def __init__(self, task):
super(RandomTuner, self).__init__(task)
self.visited = set()
def next_batch(self, batch_size):
ret = []
counter = 0
while counter < batch_size:
if len(self.visited) >= len(self.task.config_space):
break
index = np.random.randint(len(self.task.config_space))
while index in self.visited:
index = np.random.randint(len(self.task.config_space))
ret.append(self.task.config_space.get(index))
self.visited.add(index)
counter += 1
return ret
def has_next(self):
return len(self.visited) < len(self.task.config_space)
def __getstate__(self):
return {"visited": self.counter}
def __setstate__(self, state):
self.counter = state['visited']
# pylint: disable=invalid-name
"""Metrics for evaluating tuning process"""
import numpy as np
from ..util import get_rank
def max_curve(trial_scores):
""" f(n) = max([s[i] fo i < n])
Parameters
----------
trial_scores: Array of float
the score of i th trial
Returns
-------
curve: Array of float
function values
"""
ret = np.empty(len(trial_scores))
keep = -1e9
for i, score in enumerate(trial_scores):
keep = max(keep, score)
ret[i] = keep
return ret
def mean_curve(trial_scores):
""" f(n) = mean([s[i] fo i < n])
Parameters
----------
trial_scores: Array of float
the score of i th trial
Returns
-------
curve: Array of float
function values
"""
ret = np.empty(len(trial_scores))
keep = 0
for i, score in enumerate(trial_scores):
keep += score
ret[i] = keep / (i+1)
return ret
def recall_curve(trial_ranks, top=None):
"""
if top is None, f(n) = sum([I(rank[i] < n) for i < n]) / n
if top is K, f(n) = sum([I(rank[i] < K) for i < n]) / K
Parameters
----------
trial_ranks: Array of int
the rank of i th trial in labels
top: int or None
top-n recall
Returns
-------
curve: Array of float
function values
"""
if not isinstance(trial_ranks, np.ndarray):
trial_ranks = np.array(trial_ranks)
ret = np.zeros(len(trial_ranks))
if top is None:
for i in range(len(trial_ranks)):
ret[i] = np.sum(trial_ranks[:i] <= i) / (i+1)
else:
for i in range(len(trial_ranks)):
ret[i] = 1.0 * np.sum(trial_ranks[:i] < top) / top
return ret
def cover_curve(trial_ranks):
"""
f(n) = max k s.t. {1,2,...,k} is a subset of {ranks[i] for i < n}
Parameters
----------
trial_ranks: Array of int
the rank of i th trial in labels
Returns
-------
curve: Array of float
function values
"""
ret = np.empty(len(trial_ranks))
keep = -1
cover = set()
for i, rank in enumerate(trial_ranks):
cover.add(rank)
while keep+1 in cover:
keep += 1
ret[i] = keep + 1
return ret / len(trial_ranks)
def average_recall(preds, labels, N):
"""evaluate average recall-n for predictions and labels"""
trials = np.argsort(preds)[::-1]
ranks = get_rank(labels[trials])
curve = recall_curve(ranks)
return np.sum(curve[:N]) / N
# pylint: disable=consider-using-enumerate
"""
Cost model optimizer based on simulated annealing
"""
import heapq
import logging
import time
import numpy as np
from ..util import sample_ints
from .model_based_tuner import ModelOptimizer, knob2point, point2knob
class SimulatedAnnealingOptimizer(ModelOptimizer):
"""parallel simulated annealing optimization algorithm
Parameters
----------
task: Task
The tuning task
n_iter: int
The number of iterations of simulated annealing
temp: float or Array of float
If is a single float, then use a constant temperature.
If is an Array, then perform linear cooling from temp[0] to temp[1]
early_stop: int, optional
Stop iteration if the optimal set do not change in `early_stop` rounds
verbose: int, optional
Print log every `verbose` iterations
"""
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=30, verbose=50):
super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task
self.dims = [len(x) for x in self.task.config_space.space_map.values()]
self.n_iter = n_iter
self.temp = temp
self.persistent = persistent
self.parallel_size = parallel_size
self.early_stop = early_stop
self.verbose = verbose
self.points = None
def find_maximums(self, model, num, exclusive):
tic = time.time()
temp, n_iter, early_stop, verbose = self.temp, self.n_iter, self.early_stop, self.verbose
if self.persistent and self.points is not None:
points = self.points
else:
points = np.array(sample_ints(0, len(self.task.config_space), self.parallel_size))
scores = model.predict(points)
# build heap and insert initial points
heap_items = [(float('-inf'), -i) for i in range(num)]
heapq.heapify(heap_items)
in_heap = set(exclusive)
in_heap.update([-i for i in range(num)])
for s, p in zip(scores, points):
if s > heap_items[0][0] and p not in in_heap:
pop = heapq.heapreplace(heap_items, (s, p))
in_heap.remove(pop[1])
in_heap.add(p)
k = 0
k_last_modify = 0
if isinstance(temp, (tuple, list, np.ndarray)):
t = temp[0]
cool = 1.0 * (temp[0] - temp[1]) / (n_iter + 1)
else:
t = temp
cool = 0
while k < n_iter and k < k_last_modify + early_stop:
new_points = np.empty_like(points)
for i, p in enumerate(points):
new_points[i] = random_walk(p, self.dims)
new_scores = model.predict(new_points)
ac_prob = np.exp((new_scores - scores) / t)
ac_index = np.random.random(len(ac_prob)) < ac_prob
points[ac_index] = new_points[ac_index]
scores[ac_index] = new_scores[ac_index]
for s, p in zip(new_scores, new_points):
if s > heap_items[0][0] and p not in in_heap:
pop = heapq.heapreplace(heap_items, (s, p))
in_heap.remove(pop[1])
in_heap.add(p)
k_last_modify = k
k += 1
t -= cool
if verbose >= 1 and k % verbose == 0:
t_str = "%.2f" % t
logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
heap_items.sort(key=lambda item: -item[0])
if verbose:
logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.info("SA Maximums: %s", heap_items)
if self.persistent:
self.points = points
return [x[1] for x in heap_items]
def random_walk(p, dims):
"""random walk as local transition
Parameters
----------
p: int
index of the ConfigEntity
dims: Array of int
sizes of each dimension
Returns
-------
new_p: int
new neighborhood index
"""
# transform to knob form
old = point2knob(p, dims)
new = list(old)
# mutate
while new == old:
from_i = np.random.randint(len(old))
to_v = np.random.randint(dims[from_i])
new[from_i] = to_v
# transform to index form
return knob2point(new, dims)
# pylint: disable=unused-argument, no-self-use, invalid-name
"""Base class of tuner"""
import logging
import numpy as np
from ..measure import MeasureInput
from ..measure import create_measure_batch
class Tuner(object):
"""Base class for tuners
Parameters
----------
task: autotvm.task.Task
Tuning Task
"""
def __init__(self, task, **kwargs):
self.param = kwargs
self.recorder = None
self.task = task
# keep the current best
self.best_config = None
self.best_flops = 0
self.best_measure_pair = None
def has_next(self):
"""Whether has next untried config in the space
Returns
-------
has_next: bool
"""
raise NotImplementedError()
def next_batch(self, batch_size):
"""get the next batch of configs to be measure on real hardware
Parameters
----------
batch_size: int
The size of the batch
Returns
-------
a batch of configs
"""
raise NotImplementedError()
def update(self, inputs, results):
"""Update parameters of the tuner according to measurement results
Parameters
----------
inputs: Array of autotvm.measure.MeasureInput
The input for measurement
results: Array of autotvm.measure.MeasureResult
result for measurement
"""
pass
def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
"""Begin tuning
Parameters
----------
n_trial: int
Maximum number of configs to try (measure on real hardware)
measure_option: dict
The options for how to measure generated code.
You should use the return value ot autotvm.measure_option for this argument.
verbose: int
0: silent mode, no output
1: print every measurement result
callbacks: List of callable
A list of callback functions. The signature of callback function is
(Tuner, List of MeasureInput, List of MeasureResult)
with no return value. These callback functions will be called on
every measurement pair. See autotvm/tuner/callback.py for some examples.
"""
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
i = 0
while i < n_trial:
if not self.has_next():
break
configs = self.next_batch(min(parallel_num, n_trial - i))
inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
results = measure_batch(inputs)
# print info
if verbose >= 1:
for k, (inp, res) in enumerate(zip(inputs, results)):
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
else:
flops = 0
if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
i += len(results)
self.update(inputs, results)
for callback in callbacks:
callback(self, inputs, results)
del measure_batch
def reset(self):
"""reset the status of tuner"""
self.best_config = None
self.best_flops = 0
self.best_measure_pair = None
def load_history(self, data_set):
"""load history data for transfer learning
Parameters
----------
data_set: Array of (MeasureInput, MeasureResult) pair
Previous tuning records
"""
raise NotImplementedError()
"""Tuner that uses xgboost as cost model"""
from .model_based_tuner import ModelBasedTuner, ModelOptimizer
from .xgboost_cost_model import XGBoostCostModel
from .sa_model_optimizer import SimulatedAnnealingOptimizer
class XGBTuner(ModelBasedTuner):
"""Tuner that uses xgboost as cost model
Parameters
----------
task: Task
The tuning task
plan_size: int
The size of a plan. After `plan_size` trials, the tuner will refit a new cost model
and do planing for the next `plan_size` trials.
feature_type: str, optional
If is 'itervar', use features extracted from IterVar (loop variable).
If is 'knob', use flatten ConfigEntity directly.
If is 'curve', use sampled curve feature (relation feature).
Note on choosing feature type:
For single task tuning, 'itervar' and 'knob' is good.
'itervar' is more accurate but 'knob' is much faster.
For cross-shape tuning (e.g. many convolutions with different shapes),
'itervar' and 'curve' has better transferability,
'knob' is faster.
For cross-device or cross-operator tuning, you can use 'curve' only.
loss_type: str
If is 'reg', use regression loss to train cost model.
The cost model predicts the normalized flops.
If is 'rank', use pairwise rank loss to train cost model.
The cost model predicts relative rank score.
num_threads: int, optional
The number of threads.
optimizer: str or ModelOptimizer, optional
If is 'sa', use a default simulated annealing optimizer.
Otherwise it should be a ModelOptimizer object.
diversity_filter_ratio: int or float, optional
If is not None, the tuner will first select
top-(plan_size * diversity_filter_ratio) candidates according to the cost model
and then pick batch_size of them according to the diversity metric.
"""
def __init__(self, task, plan_size=32,
feature_type='itervar', loss_type='rank', num_threads=None,
optimizer='sa', diversity_filter_ratio=None):
cost_model = XGBoostCostModel(task,
feature_type=feature_type,
loss_type=loss_type,
num_threads=num_threads)
if optimizer == 'sa':
optimizer = SimulatedAnnealingOptimizer(task)
else:
assert isinstance(optimizer, ModelOptimizer), "Optimizer must be " \
"a supported name string" \
"or a ModelOptimizer object."
super(XGBTuner, self).__init__(task, cost_model, optimizer,
plan_size, diversity_filter_ratio)
# pylint: disable=invalid-name
"""Utilities"""
import logging
import multiprocessing
import time
import numpy as np
from .. import expr, ir_pass
def get_rank(values):
"""get rank of items
Parameters
----------
values: Array
Returns
-------
ranks: Array of int
the rank of this item in the input (the largest value ranks first)
"""
tmp = np.argsort(-values)
ranks = np.empty_like(tmp)
ranks[tmp] = np.arange(len(tmp))
return ranks
def sample_ints(low, high, m):
"""
Sample m different integer numbers from [low, high) without replacement
This function is an alternative of `np.random.choice` when (high - low) > 2 ^ 32, in
which case numpy does not work.
Parameters
----------
low: int
low point of sample range
high: int
high point of sample range
m: int
The number of sampled int
Returns
-------
ints: an array of size m
"""
vis = set()
assert m <= high - low
while len(vis) < m:
new = np.random.randint(low, high)
while new in vis:
new = np.random.randint(low, high)
vis.add(new)
return list(vis)
def pool_map(func, args, batch_size, verbose=False, pool=None):
"""A wrapper of multiprocessing.pool.Pool.map to support small-batch mapping
for large argument list. This can reduce memory usage
Parameters
----------
func: Func(arg) -> np.ndarray
mapping function
args: List
list of arguments
batch_size: int
batch size in mapping
verbose: bool, optional
whether print progress
pool: multiprocessing.Pool, optional
pool objection
Returns
-------
converted numpy array
"""
ret = None
tic = time.time()
local_pool = pool or multiprocessing.Pool()
if verbose:
logging.info("mapping begin")
for i in range(0, len(args), batch_size):
if verbose:
logging.info("mapping %d/%d elapsed %.2f", i, len(args),
time.time() - tic)
tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
ret = tmp if ret is None else np.concatenate((ret, tmp))
if verbose:
logging.info("mapping done")
if not pool:
local_pool.close()
return ret
def get_func_name(func):
"""Get name of a function
Parameters
----------
func: Function
The function
Returns
-------
name: str
The name
"""
return func.func_name if hasattr(func, 'func_name') else func.__name__
def get_const_int(exp):
"""Verifies expr is integer and get the constant value.
Parameters
----------
exp : tvm.Expr or int
The input expression.
Returns
-------
out_value : int
The output.
"""
if isinstance(exp, int):
return exp
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
exp = ir_pass.Simplify(expr)
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
raise ValueError("Expect value to be constant int")
return exp.value
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
Parameters
----------
in_tuple : tuple of Expr
The input.
Returns
-------
out_tuple : tuple of int
The output.
"""
return tuple(get_const_int(x) for x in in_tuple)
...@@ -229,8 +229,14 @@ class TrackerSession(object): ...@@ -229,8 +229,14 @@ class TrackerSession(object):
res += "----------------------------\n" res += "----------------------------\n"
res += "key\tfree\tpending\n" res += "key\tfree\tpending\n"
res += "----------------------------\n" res += "----------------------------\n"
for k, v in data["queue_info"].items(): queue_info = data['queue_info']
res += "%s\t%d\t%g\n" % (k, v["free"], v["pending"]) keys = list(queue_info.keys())
if keys:
keys.sort()
max_key_len = max([len(k) for k in keys])
for k in keys:
res += ("%%-%d" % max_key_len + "s\t%d\t%g\n") % \
(k, queue_info[k]["free"], queue_info[k]["pending"])
res += "----------------------------\n" res += "----------------------------\n"
return res return res
......
/*!
* Copyright (c) 2018 by Contributors
* \file feature_visitor.cc
* \brief Base class for feature extractor.
* These features are used for machine learning cost model
*/
#include "feature_visitor.h"
namespace tvm {
namespace autotvm {
// for loop
void FeatureVisitor::Visit_(const For *op) {
const auto *extent = op->extent.as<IntImm>();
int64_t loop_extent = -1;
if (extent != nullptr)
loop_extent = extent->value;
AnnotationType ann = kSerial;
switch (op->for_type) {
case ForType ::Parallel:
ann = kParallel;
break;
case ForType::Unrolled:
ann = kUnrolled;
break;
case ForType::Vectorized:
ann = kVectorized;
break;
case ForType::Serial:
ann = kSerial;
break;
}
if (EnterItervar_(op->loop_var, loop_extent, ann)) {
IRVisitor::Visit_(op);
ExitItervar_();
}
}
// parallel axis, virtual thread
void FeatureVisitor::Visit_(const AttrStmt *op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImm>();
CHECK(extent);
std::string name = var.get()->name_hint;
AnnotationType ann = kParallel;
if (op->attr_key == attr::thread_extent) {
if (name == "blockIdx.x")
ann = kBlockX;
else if (name == "blockIdx.y")
ann = kBlockY;
else if (name == "blockIdx.z")
ann = kBlockZ;
else if (name == "threadIdx.x")
ann = kThreadX;
else if (name == "threadIdx.y")
ann = kThreadY;
else if (name == "threadIdx.z")
ann = kThreadZ;
else
LOG(FATAL) << "invalid thread itervar " + name;
} else {
ann = kVirtualThread;
}
if (EnterItervar_(var, extent->value, ann)) {
IRVisitor::Visit_(op);
ExitItervar_();
}
} else {
IRVisitor::Visit_(op);
}
}
// memory access
void FeatureVisitor::Visit_(const Load *op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
ExitMem_();
}
void FeatureVisitor::Visit_(const Store *op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
ExitMem_();
}
} // namespace autotvm
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file feature_visitor.h
* \brief Base class for feature extractor.
* These features are used for machine learning cost model
*/
#ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_
#define TVM_AUTOTVM_FEATURE_VISITOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <string>
namespace tvm {
namespace autotvm {
using namespace tvm::ir;
/*!
* \brief Type of for loop, used as one-hot encoding in features
*/
enum AnnotationType {
kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ,
kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread,
kNum,
};
/*!
* \brief A base class for feature extractor, used for processing
* for loop and memory access in the IR
*/
class FeatureVisitor : public IRVisitor {
public:
// for loop
void Visit_(const For *op);
void Visit_(const AttrStmt *op);
// memory access
void Visit_(const Load *op);
void Visit_(const Store *op);
protected:
/*!
* \brief Enter a for loop node
* \param var The expression to be printed.
* \param length The output stream
* \param ann_type The type for the for loop
* \return skip Whether skip this node
*/
virtual bool EnterItervar_(tvm::VarExpr var, int64_t length, AnnotationType ann_type) = 0;
/*! \brief Exit a for loop subtree */
virtual void ExitItervar_() = 0;
/*!
* \brief Enter a memory access node
* \param buffer_var The buffer to access.
* \param index Index expression
*/
virtual void EnterMem_(tvm::VarExpr buffer_var, tvm::Expr index) = 0;
/*! \brief Exit a memory access node */
virtual void ExitMem_() = 0;
};
} // namespace autotvm
} // namespace tvm
#endif // TVM_AUTOTVM_FEATURE_VISITOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file touch_extractor.h
* \brief Extract feature of touch pattern of axes in lowered IR
*/
#ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <stack>
#include <vector>
#include <map>
#include <string>
#include <deque>
#include "feature_visitor.h"
namespace tvm {
namespace autotvm {
using TouchedBuffer = std::string;
// touch pattern buf[(stride * var) % mod) + other]
struct TouchPattern {
int64_t stride{0};
int64_t mod{-1}; // -1 for +inf
int64_t count{1};
int64_t reuse{1};
int64_t thread_count{0}; // count when move thread axis into innermost
int64_t thread_reuse{0}; // reuse ratio move thread axis into innermost
};
// all the feature of an iter var
struct ItervarFeature {
ItervarFeature(VarExpr var,
int64_t extent,
int nest,
AnnotationType ann_type,
int64_t topdown,
int counter)
: length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {}
ItervarFeature() {}
// Axis Attributes
int64_t length;
int nest_level;
AnnotationType ann; // one-hot axis type
int64_t topdown_product; // accumulative product of axis length, in top-down order
int64_t bottomup_product; // accumulative product of axis length, in bottom-up order
// bottomup_product = reuse * count for any touched buffer
int order; // used for soring axis
// Arithmetic feature
int add_ct{0};
int mul_ct{0};
int div_ct{0};
// Memory Touch Feature
std::unordered_map<TouchedBuffer, TouchPattern> touch_feature;
};
// extract iter vars and their touch pattern from ir
class TouchExtractor : public FeatureVisitor {
public:
void Analyze(Stmt stmt) {
this->Visit(stmt);
}
// arithmetic stats
void Visit_(const Add *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Sub *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Mul *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Div *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Mod *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
}
std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
private:
bool EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type);
void ExitItervar_();
void EnterMem_(VarExpr buffer_var, Expr index);
void ExitMem_();
int64_t topdown_product_{1};
std::map<std::string, size_t> buffer_counter_;
size_t itervar_counter_{0};
std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
std::deque<size_t> skip_stack_size_;
using IRVisitor::Visit_;
};
} // namespace autotvm
} // namespace tvm
#endif // TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
...@@ -73,7 +73,7 @@ Target CreateTarget(const std::string& target_name, ...@@ -73,7 +73,7 @@ Target CreateTarget(const std::string& target_name,
} else { } else {
t->device_type = kDLROCM; t->device_type = kDLROCM;
} }
t->keys_array.push_back(ir::StringImm::make("rocm")); t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
if (t->device_name == "intel_graphics") { if (t->device_name == "intel_graphics") {
...@@ -195,11 +195,7 @@ Target Target::create(const std::string& target_str) { ...@@ -195,11 +195,7 @@ Target Target::create(const std::string& target_str) {
options.push_back(item); options.push_back(item);
} }
if (device_name == "rasp") { return CreateTarget(target_name, options);
return target::rasp(options);
} else {
return CreateTarget(target_name, options);
}
} }
/*! \brief Entry to hold the Target context stack. */ /*! \brief Entry to hold the Target context stack. */
......
...@@ -18,13 +18,13 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -18,13 +18,13 @@ class GPUCodeVerifier : public IRVisitor {
bool Verify(tvm::Stmt stmt, bool Verify(tvm::Stmt stmt,
int64_t max_local_memory_per_block, int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_thread_per_block, int64_t max_threads_per_block,
int64_t max_thread_x, int64_t max_thread_x,
int64_t max_thread_y, int64_t max_thread_y,
int64_t max_thread_z) { int64_t max_thread_z) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block); max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block); max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_thread_per_block_ = static_cast<size_t>(max_thread_per_block); max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x); max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y); max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z); max_thread_z_ = static_cast<size_t>(max_thread_z);
...@@ -52,7 +52,7 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -52,7 +52,7 @@ class GPUCodeVerifier : public IRVisitor {
if (nest_level_ == 0) { if (nest_level_ == 0) {
// exit a kernel, check the validity // exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_thread_per_block_; valid_ &= thread_per_block_ <= max_threads_per_block_;
valid_ &= local_memory_per_block_ <= max_local_memory_per_block_; valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_; valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
...@@ -117,7 +117,7 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -117,7 +117,7 @@ class GPUCodeVerifier : public IRVisitor {
size_t max_local_memory_per_block_; size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_; size_t max_shared_memory_per_block_;
size_t max_thread_per_block_; size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_; size_t max_thread_x_, max_thread_y_, max_thread_z_;
bool valid_{true}; bool valid_{true};
...@@ -137,26 +137,34 @@ bool VerifyGPUCode(Stmt stmt, ...@@ -137,26 +137,34 @@ bool VerifyGPUCode(Stmt stmt,
Map<std::string, Expr> constraints) { Map<std::string, Expr> constraints) {
GPUCodeVerifier verifier; GPUCodeVerifier verifier;
auto get_int = [&constraints](std::string key, int64_t def) { int64_t max_local_memory_per_block = INT64_MAX;
auto iter = constraints.find(key); int64_t max_shared_memory_per_block = INT64_MAX;
if (iter != constraints.end()) { int64_t max_threads_per_block = INT64_MAX;
return ((*iter).second).as<IntImm>()->value; int64_t max_thread_x = INT64_MAX;
} else { int64_t max_thread_y = INT64_MAX;
return def; int64_t max_thread_z = INT64_MAX;
}
}; for (auto iter : constraints) {
if (iter.first == "max_local_memory_per_block")
int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX); max_local_memory_per_block = (iter.second).as<IntImm>()->value;
int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX); else if (iter.first == "max_shared_memory_per_block")
int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX); max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
int64_t max_thread_x = get_int("max_thread_x", INT64_MAX); else if (iter.first == "max_threads_per_block")
int64_t max_thread_y = get_int("max_thread_y", INT64_MAX); max_threads_per_block = (iter.second).as<IntImm>()->value;
int64_t max_thread_z = get_int("max_thread_z", INT64_MAX); else if (iter.first == "max_thread_x")
max_thread_x = (iter.second).as<IntImm>()->value;
else if (iter.first == "max_thread_y")
max_thread_y = (iter.second).as<IntImm>()->value;
else if (iter.first == "max_thread_z")
max_thread_z = (iter.second).as<IntImm>()->value;
else
LOG(FATAL) << "Invalid check item: " << iter.first;
}
return verifier.Verify(stmt, return verifier.Verify(stmt,
max_local_memory_per_block, max_local_memory_per_block,
max_shared_memory_per_block, max_shared_memory_per_block,
max_thread_per_block, max_threads_per_block,
max_thread_x, max_thread_x,
max_thread_y, max_thread_y,
max_thread_z); max_thread_z);
......
"""
Test the tuner
"""
import logging
import time
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner
@autotvm.template
def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
"""An example template for testing"""
assert N == 1, "Only consider batch_size = 1 in this template"
data = tvm.placeholder((N, CI, H, W), name='data')
kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
rc = tvm.reduce_axis((0, CI), name='rc')
ry = tvm.reduce_axis((0, KH), name='ry')
rx = tvm.reduce_axis((0, KW), name='rx')
conv = tvm.compute(
(N, CO, H - KH + 1, W - KW + 1),
lambda nn, ff, yy, xx: tvm.sum(
data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx],
axis=[rc, ry, rx]), tag="conv2d_nchw")
s = tvm.create_schedule([conv.op])
output = conv
OL = s.cache_write(conv, 'local')
# create cache stage
AA = s.cache_read(data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
AL = s.cache_read(AA, 'local', [OL])
WL = s.cache_read(WW, 'local', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
kernel_scope = n # this is the scope to attach global config inside this kernel
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# tile and bind reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
s[AL].compute_at(s[OL], rxm)
s[WL].compute_at(s[OL], rxm)
# cooperative fetching
for load in [AA, WW]:
n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# tune unroll
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
cfg.define_knob("unroll_explicit", [0, 1])
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
return s, [data, kernel, conv]
def get_sample_task(target=tvm.target.cuda(), target_host=None):
"""return a sample task for testing"""
task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3),
target=target, target_host=target_host)
return task, target
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
def measure_batch(inputs):
from tvm.autotvm import MeasureResult
results = []
for inp in inputs:
tic = time.time()
# do nothing
time.sleep(0.001)
results.append(MeasureResult([time.time() - tic], 0,
time.time() - tic, time.time()))
return results
measure_option = autotvm.measure_option(mode='custom',
custom_measure_batch=measure_batch)
logging.info("%s", task.config_space)
# new tuner and recorder
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
def test_tuning_with_measure():
def check(target, target_host):
ctx = tvm.context(target, 0)
if not ctx.exist:
logging.info("Skip test because %s is not available" % target)
return
# init task
task, target = get_sample_task(target, target_host)
logging.info("%s", task.config_space)
measure_option = autotvm.measure_option(mode='local',
timeout=4,
number=2)
tuner = RandomTuner(task)
tuner.tune(n_trial=10, measure_option=measure_option)
check("cuda", None)
check("opencl", None)
if __name__ == "__main__":
# only print log when invoked from main
logging.basicConfig(level=logging.INFO)
test_task_tuner_without_measurement()
test_tuning_with_measure()
"""Common utilities for testing autotvm"""
import time
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
@autotvm.template
def matmul(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
##### define space begin #####
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
##### define space end #####
# schedule according to config
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
def get_sample_task(n=128):
"""return a sample task for testing"""
target = tvm.target.create("llvm")
task = autotvm.task.create(matmul, args=(n, n, n, 'float32'), target=target)
return task, target
def get_sample_records(n):
"""get sample records for testing"""
tsk, target = get_sample_task()
inps, ress = [], []
for i in range(n):
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
ress.append(MeasureResult((i+1,), 0, i, time.time()))
return list(zip(inps, ress))
"""Test database"""
import copy
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records
def test_save_load():
logging.info("test basic db load/save ...")
records = get_sample_records(3)
inp1, res1 = records[0]
inp2, res2 = records[1]
inp3, _ = records[2]
_db = database.DummyDatabase()
_db.flush()
_db.save(inp1, res1)
_db.save(inp2, res2)
load1 = _db.load(inp1)
load2 = _db.load(inp2)
load3 = _db.load(inp3)
assert load1 == res1
assert load2 == res2
assert load3 is None
assert load1 != load2
TRIAL_LIMIT = 2
def test_db_filter():
logging.info("test db filter ...")
# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
batch_size = 2
measure_option = autotvm.measure_option(mode='local-nofork', timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
del measure_batch
db = database.DummyDatabase()
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option(mode='local-nofork', timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])
for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"
del measure_batch
def test_db_hash():
logging.info("test db hash check ...")
inp1, res1 = get_sample_records(1)[0]
inp2 = copy.deepcopy(inp1)
inp1.config.code_hash = 'cafecafe'
inp2.config.code_hash = 'dbffdbff'
res2l = list(tuple(res1))
# set timestamp
res2l[-1] = -1
res2 = MeasureResult(*res2l)
_db = database.DummyDatabase()
_db.flush()
_db.save(inp1, res1, extend=True)
_db.save(inp2, res2, extend=True)
load1 = _db.load(inp1)
load2 = _db.load(inp2)
assert load1 != load2
assert load1.timestamp != -1
assert load2.timestamp == -1
def test_db_latest_all():
logging.info("test db load w/ multiple results ...")
inp1, res1 = get_sample_records(1)[0]
lis1 = list(tuple(res1))
lis2 = list(tuple(res1))
lis3 = list(tuple(res1))
# set timestamp
lis1[-1] = 0.0
lis2[-1] = 1.1
lis3[-1] = 9999.9999
res1 = MeasureResult(*lis1)
res2 = MeasureResult(*lis2)
res3 = MeasureResult(*lis3)
_db = database.DummyDatabase()
_db.flush()
_db.save(inp1, res1, extend=True)
load1 = _db.load(inp1)
assert load1.timestamp == 0.0
_db.save(inp1, res2, extend=True)
load2 = _db.load(inp1)
assert load2.timestamp == 1.1
_db.save(inp1, res3, extend=True)
load3 = _db.load(inp1)
assert load3.timestamp == 9999.9999
load4 = _db.load(inp1, get_all=True)
assert encode(inp1, load4[0]) == encode(inp1, res1)
assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option(mode='local-nofork',
timeout=2,
replay_db=_db, save_to_replay_db=True)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)
for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
del measure_batch
def test_check_hashmismatch():
logging.info("test hash mismatch check")
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option(mode='local-nofork')
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))
try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass
del measure_batch
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_save_load()
test_db_filter()
test_db_hash()
test_db_latest_all()
test_db_save_replay()
test_check_hashmismatch()
"""Test dispatcher.
The dispatcher can choose which template to use according
to the parameters of workload"""
from collections import namedtuple
from tvm.autotvm.task import dispatcher, DispatchContext
SimpleWorkload = namedtuple("SimpleWorkload", ["key"])
SimpleConfig = namedtuple("SimpleConfig", ["template_key"])
def test_dispatch():
@dispatcher
def my_dispatcher(a, b):
return SimpleWorkload(key=a + b)
@my_dispatcher.register("spatial_pack")
def _sp_pack_add(cfg, a, b):
return b + 100
@my_dispatcher.register("im2col")
def _im2col_add(cfg, a, b):
return a + 1
class SimpleDispatcher(DispatchContext):
def query(self, target, workload):
tkey = "spatial_pack" if workload.key > 2 else "im2col"
return SimpleConfig(tkey)
with SimpleDispatcher():
# im2col
assert my_dispatcher(1, 0) == 2
# spack
assert my_dispatcher(1, 100) == 200
if __name__ == "__main__":
test_dispatch()
"""Test local executor"""
import time
from tvm.autotvm.measure import LocalExecutor, executor
def slow(n):
r = 0
for i in range(0, n+1):
r += i
return r
def fast(n):
return n*(n+1)//2
def test_local_measure_async():
ex = LocalExecutor()
f1 = ex.submit(slow, 9999999)
f2 = ex.submit(fast, 9999999)
t1 = 0
t2 = 0
while True:
if t1 == 0 and f1.done():
t1 = time.time()
if t2 == 0 and f2.done():
t2 = time.time()
if t1 != 0 and t2 != 0:
break
assert t2 < t1, "Expected fast async job to finish first!"
assert f1.get() == f2.get()
def timeout_job(n):
time.sleep(n * 1.5)
def test_timeout():
timeout = 0.5
ex = LocalExecutor(timeout=timeout)
f1 = ex.submit(timeout_job, timeout)
while not f1.done():
pass
res = f1.get()
assert isinstance(res, executor.TimeoutError)
if __name__ == "__main__":
test_local_measure_async()
test_timeout()
"""Test feature extraction"""
import numpy as np
import tvm
from tvm.autotvm import feature
def test_iter_feature_gemm():
N = 128
k = tvm.reduce_axis((0, N), 'k')
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
C = tvm.compute(
A.shape,
lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
name='C')
s = tvm.create_schedule(C.op)
feas = feature.get_itervar_feature(s, [A, B, C], take_log=False)
expected = [
{
'_attr_': [128, 1, 128, 2097152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
'A_0': [128, -1, 16384, 128, 0, 0], 'B_0': [0, -1, 16384, 128, 0, 0],
'C_0': [128, -1, 16384, 128, 0, 0], 'C_1': [128, -1, 16384, 128, 0, 0],
},
{
'_attr_': [128, 2, 16384, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
'A_0': [0, -1, 128, 128, 0, 0], 'B_0': [1, -1, 16384, 1, 0, 0],
'C_0': [1, -1, 128, 128, 0, 0], 'C_1': [1, -1, 128, 128, 0, 0],
},
{
'_attr_': [128, 3, 2097152, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
'A_0': [1, -1, 128, 1, 0, 0], 'B_0': [128, -1, 128, 1, 0, 0],
'C_1': [0, -1, 1, 128, 0, 0], 'C_2': [0, -1, 1, 128, 0, 0],
}
]
for ans, row in zip(expected, feas):
for pair in row:
if pair[0] not in ans:
continue
assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:])
def test_feature_shape():
"""test the dimensions of flatten feature are the same"""
N = 1024
n_sample = 100
def get_gemm_feature(target):
k = tvm.reduce_axis((0, N), 'k')
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
C = tvm.compute(A.shape, lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
name='C')
s = tvm.create_schedule(C.op)
y, x = s[C].op.axis
axes = list(s[C].tile(y, x, 8, 8)) + [k]
perm = np.random.permutation(5)
axes = [axes[x] for x in perm]
s[C].reorder(*axes)
if "gpu" in target.keys:
pick = []
# filter out reduction axis
for i in range(len(perm)):
if perm[i] != 4:
pick.append(axes[i])
s[C].bind(pick[0], tvm.thread_axis("blockIdx.x"))
s[C].bind(pick[1], tvm.thread_axis("vthread"))
s[C].bind(pick[2], tvm.thread_axis("threadIdx.y"))
with target:
feas = feature.get_itervar_feature(s, [A, B, C])
feas = feature.flatten_itervar_feature(feas)
return feas
targets = [
tvm.target.cuda(),
tvm.target.mali(),
tvm.target.rasp(),
]
for target in targets:
dim = len(get_gemm_feature(target))
for i in range(n_sample):
assert dim == len(get_gemm_feature(target)), "dimensions of feature do not match" \
" for different configurations"
if __name__ == "__main__":
test_iter_feature_gemm()
test_feature_shape()
"""Test flop calculation"""
import tvm
import numpy as np
from tvm.autotvm.task.task import compute_flop
def test_conv():
for i in range(5):
N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
D = tvm.placeholder((N, CI, H, W))
K = tvm.placeholder((CO, CI, KH, KW))
KH = min(H, KH)
KW = min(W, KW)
ci = tvm.reduce_axis((0, CI))
kh = tvm.reduce_axis((0, KH))
kw = tvm.reduce_axis((0, KW))
OH = (H - KH) + 1
OW = (W - KW) + 1
C = tvm.compute((N, CO, OH, OW), lambda n, co, h, w:
tvm.sum(D[n][ci][h][w] * K[co][ci][h][w], axis=[ci, kh, kw]))
s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * CO * OH * OW * CI * KH * KW
def test_pack_gemm():
for i in range(5):
N, L, M = [np.random.randint(10, 128) * 4 for _ in range(3)]
A = tvm.placeholder((N, L))
B = tvm.placeholder((M, L))
k = tvm.reduce_axis((0, L))
bn = 4
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii] * B_pack[j, k, jj], axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])
s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M
def test_outer_dot():
for i in range(5):
N, M = [np.random.randint(10, 128) * 4 for _ in range(2)]
A = tvm.placeholder((N,))
B = tvm.placeholder((M,))
C = tvm.compute((N, M), lambda i, j: A[i] * B[j])
s = tvm.create_schedule([C.op])
assert compute_flop(s) == N * M
def test_move():
"""No float number operation in simple move. So the estimator should raise an error """
N = 1024
A = tvm.placeholder((N,))
C = tvm.compute((N,), lambda i: A[i])
s = tvm.create_schedule([C.op])
try:
compute_flop(s)
assert False
except RuntimeError:
pass
if __name__ == '__main__':
test_conv()
test_pack_gemm()
test_outer_dot()
test_move()
"""test the correctness of dump and load of data log"""
import time
import tvm
from tvm.contrib import util
from tvm import autotvm
from tvm.autotvm.measure import MeasureInput, MeasureResult, MeasureErrorNo
from tvm.autotvm.record import encode, decode, ApplyHistoryBest, measure_str_key
from test_autotvm_common import get_sample_task
def test_load_dump():
task, target = get_sample_task()
inp = MeasureInput(target, task, task.config_space.get(0))
result = MeasureResult((2.0, 2.23, 0.23, 0.123, 0.234, 0.123), MeasureErrorNo.NO_ERROR,
2.3, time.time())
for protocol in ['json', 'pickle']:
row = encode(inp, result, protocol=protocol)
inp_2, result_2 = decode(row, protocol=protocol)
assert measure_str_key(inp) == measure_str_key(inp_2), \
"%s vs %s" % (measure_str_key(inp), measure_str_key(inp_2))
assert result.costs == result_2.costs
assert result.error_no == result_2.error_no
assert result.timestamp == result_2.timestamp
def test_file_io():
temp = util.tempdir()
file_path = temp.relpath("temp.log")
tsk, target = get_sample_task()
inputs = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(0, 10)]
results = [MeasureResult((i, ), 0, 0, 0) for i in range(0, 10)]
with open(file_path, "w") as fo:
cb = autotvm.callback.log_to_file(fo)
cb(None, inputs, results)
ref = zip(inputs, results)
for x, y in zip(ref, autotvm.record.load_from_file(file_path)):
assert x[1] == y[1]
def test_apply_history_best():
tsk, target = get_sample_task()
records = [
(MeasureInput(target, tsk, tsk.config_space.get(0)), MeasureResult((0.1,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(1)), MeasureResult((0.3,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(2)), MeasureResult((0.01,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(4)), MeasureResult((0.4,), 0, 2.3, 0))
]
hist_best = ApplyHistoryBest(records)
x = hist_best.query(target, tsk.workload)
assert str(x) == str(tsk.config_space.get(2))
if __name__ == "__main__":
test_load_dump()
test_apply_history_best()
test_file_io()
"""Test space definition primitives"""
import tvm
from tvm.autotvm.task.space import ConfigSpace
def gemm_func(cfg, N):
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
k = tvm.reduce_axis((0, N), name='k')
C = tvm.compute((N, N), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=[k]), name='C')
s = tvm.create_schedule([C.op])
y, x = s[C].op.axis
cfg.define_split('tile_y', cfg.axis(y), num_outputs=2)
cfg.define_split('tile_x', cfg.axis(x), num_outputs=2)
return s, [A, B, C]
def test_split():
cfg = ConfigSpace()
gemm_func(cfg, 128)
assert len(cfg) == 64
assert len(cfg.space_map['tile_y']) == 8
if __name__ == '__main__':
test_split()
...@@ -31,14 +31,14 @@ def test_shared_memory(): ...@@ -31,14 +31,14 @@ def test_shared_memory():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=4 * M - 1, max_shared_memory_per_block=4 * M - 1,
max_thread_per_block=M))]}): max_threads_per_block=M))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=4 * M, max_shared_memory_per_block=4 * M,
max_thread_per_block=M))]}): max_threads_per_block=M))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
...@@ -66,14 +66,14 @@ def test_local_memory(): ...@@ -66,14 +66,14 @@ def test_local_memory():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_local_memory_per_block=4 * M - 1, max_local_memory_per_block=4 * M - 1,
max_thread_per_block=1))]}): max_threads_per_block=1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_local_memory_per_block=4 * M, max_local_memory_per_block=4 * M,
max_thread_per_block=1))]}): max_threads_per_block=1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
...@@ -101,21 +101,21 @@ def test_num_thread(): ...@@ -101,21 +101,21 @@ def test_num_thread():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N - 1))]}): max_threads_per_block=N - 1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N))]}): max_threads_per_block=N))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N, max_threads_per_block=N,
max_thread_y=M-1))]}): max_thread_y=M-1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
...@@ -123,7 +123,7 @@ def test_num_thread(): ...@@ -123,7 +123,7 @@ def test_num_thread():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N, max_threads_per_block=N,
max_thread_y=M))]}): max_thread_y=M))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
...@@ -151,14 +151,14 @@ def test_multiple_kernels(): ...@@ -151,14 +151,14 @@ def test_multiple_kernels():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N - 1))]}): max_threads_per_block=N - 1))]}):
tvm.build(s, [A, C], target) tvm.build(s, [A, C], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N))]}): max_threads_per_block=N))]}):
tvm.build(s, [A, C], target) tvm.build(s, [A, C], target)
assert valid[0] assert valid[0]
......
...@@ -8,7 +8,7 @@ make doc ...@@ -8,7 +8,7 @@ make doc
jsdoc web/tvm_runtime.js web/README.md || exit -1 jsdoc web/tvm_runtime.js web/README.md || exit -1
mv out docs/_build/html/jsdoc || exit -1 mv out docs/_build/html/jsdoc || exit -1
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
cd docs cd docs
PYTHONPATH=`pwd`/../python make html || exit -1 PYTHONPATH=`pwd`/../python make html || exit -1
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
export PYTHONPATH=python:apps/extension/python export PYTHONPATH=python:apps/extension/python
export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH} export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH}
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
# Test TVM # Test TVM
make cython || exit -1 make cython || exit -1
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
export PYTHONPATH=python:topi/python export PYTHONPATH=python:topi/python
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1 TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1
......
Auto tuning
-------------
"""
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
=========================================================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
This is an advanced tutorial for writing high performance tunable template for
NVIDIA GPU. By running auto-tuner on this template, we can outperform the
vendor provided library CuDNN in many cases.
"""
import logging
import sys
import tvm
import topi
from tvm import autotvm
######################################################################
# Step 1: Define the search space
# ---------------------------------
# There are plenty of useful schedule primitives in tvm. You can also find
# some tutorials that describe them in more details, such as
# (1). :doc:``Optimizing Conv2d on NVIDIA GPU <../optimize/opt_conv_cuda>`
# (2). `Optimizing DepthwiseConv on NVIDIA GPU <https://tvm.ai/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example.html>`_
#
# However, their implementations are manually tuned for some special input
# shapes. In this section, we build a large enough space to cover
# the techniques used in these tutorials. Then we rely on the efficient auto-tuner
# to search through this space and pick some good configurations.
#
# If you are familiar with writing cuda schedule, you can find the following
# template is very general. Actually this template can be easily modified
# to tune other operators such as depthwise convolution and gemm.
# In order to fully understand this template, you should be familiar with
# the schedule primitives and auto tuning API. You can refer to the above
# tutorials and :doc:`autotvm tutorial <tune_simple_template>`
#
# It is worth noting that the search space for a conv2d operator
# can be very large (at the level of 10^9 for some input shapes)
#
@autotvm.template
def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
assert N == 1, "Only consider batch_size = 1 in this template"
data = tvm.placeholder((N, CI, H, W), name='data')
kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, 'float32')
s = tvm.create_schedule([conv.op])
# inline padding
pad_data = s[conv].op.input_tensors[0]
s[pad_data].compute_inline()
data, raw_data = pad_data, data
output = conv
OL = s.cache_write(conv, 'local')
# create cache stage
AA = s.cache_read(data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
AL = s.cache_read(AA, 'local', [OL])
WL = s.cache_read(WW, 'local', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
kernel_scope = n # this is the scope to attach global config inside this kernel
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# tile and bind reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
s[AL].compute_at(s[OL], rxm)
s[WL].compute_at(s[OL], rxm)
# cooperative fetching
for load in [AA, WW]:
n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# tune unroll
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
cfg.define_knob("unroll_explicit", [0, 1])
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
return s, [raw_data, kernel, conv]
######################################################################
# Step 2: Search through the space
# ---------------------------------
# We pick the last layer on resnet as test case.
# Since our space is very large, :code:`XGBoostTuner` is most suitable
# for our case. Here we only do 20 trials for demonstration.
# In practice, making 1000 trials usually can find some good kernels
# for this template
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# the last layer in resnet
task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
target='cuda')
print(task.config_space)
# use local gpu, measure 5 times for every config to reduce variance
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option(mode='local',
number=10,
parallel_num=8,
timeout=20)
# begin tuning, log records to file `cache.tsv`
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
# get best config from cache file
dispatch_context = autotvm.apply_history_best("cache.tsv")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment