Commit 6ea74d41 by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Core part of auto-tuning module (#1312)

parent 7e7154f1
...@@ -96,6 +96,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) ...@@ -96,6 +96,7 @@ assign_source_group("Include" ${GROUP_INCLUDE})
file(GLOB COMPILER_SRCS file(GLOB COMPILER_SRCS
src/api/*.cc src/api/*.cc
src/arithmetic/*.cc src/arithmetic/*.cc
src/autotvm/*.cc
src/codegen/*.cc src/codegen/*.cc
src/codegen/stack_vm/*.cc src/codegen/stack_vm/*.cc
src/lang/*.cc src/lang/*.cc
......
tvm.autotvm
-----------
.. automodule:: tvm.autotvm
tvm.autotvm.measure
~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.measure.measure
.. autoclass:: tvm.autotvm.measure.MeasureInput
:members:
.. autoclass:: tvm.autotvm.measure.MeasureResult
:members:
.. autofunction:: tvm.autotvm.measure.measure_option
.. autofunction:: tvm.autotvm.measure.create_measure_batch
tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.tuner
:members:
.. autoclass:: tvm.autotvm.tuner.Tuner
:members:
.. autoclass:: tvm.autotvm.tuner.RandomTuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.GridSearchTuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.GATuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.XGBTuner
:members:
:inherited-members:
.. automodule:: tvm.autotvm.tuner.callback
:members:
tvm.autotvm.task
~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.task
:members:
.. automodule:: tvm.autotvm.task.task
:members:
.. automodule:: tvm.autotvm.task.space
:members:
tvm.autotvm.record
~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.record
:members:
...@@ -14,6 +14,7 @@ Python API ...@@ -14,6 +14,7 @@ Python API
ndarray ndarray
container container
function function
autotvm
graph_runtime graph_runtime
rpc rpc
bridge bridge
......
...@@ -191,6 +191,7 @@ gallery_dirs = ["tutorials", "vta/tutorials"] ...@@ -191,6 +191,7 @@ gallery_dirs = ["tutorials", "vta/tutorials"]
subsection_order = ExplicitOrder( subsection_order = ExplicitOrder(
['../tutorials/language', ['../tutorials/language',
'../tutorials/optimize', '../tutorials/optimize',
'../tutorials/autotvm',
'../tutorials/vta', '../tutorials/vta',
'../tutorials/topi', '../tutorials/topi',
'../tutorials/deployment', '../tutorials/deployment',
......
...@@ -488,7 +488,7 @@ bool VerifyMemory(LoweredFunc func, int device_type); ...@@ -488,7 +488,7 @@ bool VerifyMemory(LoweredFunc func, int device_type);
* *
* "max_local_memory_per_block": Total amount of local memory per block (in bytes). * "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_thread_per_block": Maximum number of threads per block. * "max_threads_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x. * "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y. * "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z. * "max_thread_z": Maximum length of threadIdx.z.
......
"""The auto-tuning module of tvm
This module includes:
* Tuning space definition API
* Efficient auto-tuners
* Tuning result and database support
* Distributed measurement to scale up tuning
"""
from . import database
from . import feature
from . import measure
from . import record
from . import task
from . import tuner
from . import util
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity
from .record import ApplyHistoryBest as apply_history_best
# pylint: disable=consider-using-enumerate,invalid-name
"""
Database of MeasureInput/MeasureResult pair.
This can be used for replaying measurement.
"""
import os
from .record import encode, decode, measure_str_key
class Database(object):
"""
Base class for a record database object.
"""
def load(self, inp, get_all=False):
"""
Load a result based on an input's string key
Parameters
----------
inp: MeasureInput
to be translated into key for RedisDB
get_all: bool, optional
Whether the latest result (or all matching results) should be returned
Returns
-------
rec: MeasureResult if previously saved, otherwise None
"""
raise NotImplementedError()
def save(self, inp, res, extend=False):
"""
Save a result based on an input's string key
Parameters
----------
inp: MeasureInput
to be translated into key for RedisDB
res: MeasureResult
to associate with key
extend:
Whether to extend existing MeasureResults if they exist
"""
raise NotImplementedError()
def filter_inputs(db, measure_inputs, retry=False):
"""
Filter a measure_inputs batch based on saved db results
Parameters
----------
db: Database
database object
measure_inputs: Array of MeasureInput
measure_inputs as expected in measure_batch
retry: bool
whether to retry if the saved result is a failure
Returns
-------
partial_results: Array of MeasureResult
a full list of result, where None denotes no corresponding saved result
unsaved: Array of MeasureInput
a list that only contains unsaved inputs
"""
partial_results = list()
unsaved = list()
for inp in measure_inputs:
res = db.load(inp)
if res is None or (retry and res.error_no != 0):
unsaved.append(inp)
partial_results.append(None)
else:
partial_results.append(res)
return partial_results, unsaved
class RedisDatabase(Database):
"""
Redis version of record database
"""
REDIS_PROD = 15
REDIS_LOCA = 14
REDIS_TEST = 13 # for unit test
REDIS_NIGHT_TEMP = 12 # for nightly report (will be flushed after every workload)
MAGIC_SPLIT = "$"
def __init__(self, db_index=REDIS_PROD):
import redis
if db_index == RedisDatabase.REDIS_TEST:
host = 'localhost'
else:
host = os.environ.get('TVM_FLEET_HOST')
self.db = redis.StrictRedis(host=host, port=6379, db=db_index)
self.db_index = db_index
def set(self, key, value):
self.db.set(key, value)
def get(self, key):
return self.db.get(key)
def load(self, inp, get_all=False):
current = self.get(measure_str_key(inp))
if current is not None:
current = str(current)
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
results = [rec[1] for rec in records]
if get_all:
return results
return max(results, key=lambda result: result.timestamp)
return current
def save(self, inp, res, extend=False):
current = self.get(measure_str_key(inp))
if not extend or current is None:
self.set(measure_str_key(inp),
RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)]))
else:
current = current.split(RedisDatabase.MAGIC_SPLIT)
self.set(measure_str_key(inp),
RedisDatabase.MAGIC_SPLIT.join(current + [encode(inp, res)]))
def filter(self, func):
"""
Dump all of the records for a particular target
Parameters
----------
func: callable
The signature of the function is bool (MeasureInput, Array of MeasureResult)
Returns
-------
list of records (inp, result) matching the target
Examples
--------
get records for a target
>>> db.filter(lambda inp, resulst: "cuda" in inp.target.keys)
"""
matched_records = list()
# may consider filtering in iterator in the future
for key in self.db:
current = self.get(key)
try:
records = [decode(x) for x in current.spilt(RedisDatabase.MAGIC_SPLIT)]
except TypeError: # got a badly formatted/old format record
continue
inps, results = zip(*records)
inp = inps[0]
if not func(inp, results):
continue
result = max(results, key=lambda res: res.timestamp)
matched_records.append((inp, result))
return matched_records
def flush(self):
self.db.flushdb()
class DummyDatabase(RedisDatabase):
"""
A database based on python dictionary for testing.
"""
def __init__(self):
# pylint: disable=super-init-not-called
self.db = {}
def set(self, key, value):
self.db[key] = value
def get(self, key):
return self.db.get(key)
def flush(self):
self.db = {}
"""Global configuration/variable scope for autotvm"""
class AutotvmGlobalScope(object):
current = None
def __init__(self):
self._old = AutotvmGlobalScope.current
AutotvmGlobalScope.current = self
self.cuda_target_arch = None
GLOBAL_SCOPE = AutotvmGlobalScope()
# pylint: disable=invalid-name
"""Extract feature of iter vars
There are two types of feature
1) Itervar feature
This feature is extracted based on loop variables.
Different loop structures will result in different shapes of feature
2) Curve sample feature (relation feature)
This feature is extracted by sampling relation curve.
This feature is invariant of loop structure.
"""
import struct
import numpy as np
from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
def ana_lower(sch, args,
binds=None,
simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
"""
binds, _ = build_module.get_binds(args, binds)
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt)
assert simple_mode
return stmt
try:
_get_buffer_curve_sample_flatten = get_global_func(
"autotvm.feature.GetCurveSampleFeatureFlatten")
_get_itervar_feature = get_global_func("autotvm.feature.GetItervarFeature")
_get_itervar_feature_flatten = get_global_func("autotvm.feature.GetItervarFeatureFlatten")
except ValueError as e:
def raise_error(*args, **kwargs): # pylint: disable=unused-argument
raise RuntimeError("Cannot load autotvm c++ API")
_get_buffer_curve_sample_flatten = _get_itervar_feature = _get_itervar_feature_flatten = \
raise_error
def get_itervar_feature(sch, args, take_log=False):
"""get features of iter vars
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
take_log: bool
whether take log of numerical statics
Returns
-------
features of every axis in the IR, see doc/features.md for detail
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_itervar_feature(stmt, take_log)
# convert tvm node to python type
ret = []
for row in feas:
tmp = []
tmp.append([row[0][0].value, row[0][1]])
for item in row[1:]:
tmp.append([item[0].value] + [x.value for x in item[1:]])
ret.append(tmp)
return ret
def flatten_itervar_feature(fea):
"""flatten features into one-dimensional feature vectors
Parameters
----------
fea: list
return value of get_itervar_feature
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
flatten = []
for axis in fea:
for pair in axis[1:]:
flatten.append(pair[1:])
return np.concatenate(flatten)
def get_itervar_feature_flatten(sch, args, take_log=True):
"""get flatten features of iter vars
this is equivalent to get_itervar_feature + flatten_itervar_feature, but much faster.
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
take_log: bool
whether take log of numerical statics
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_itervar_feature_flatten(stmt, take_log)
feas = struct.unpack('%df' % (len(feas)//4), feas)
return feas
def get_flatten_name(fea):
""" Get names of feature after flatten.
Parameters
----------
fea: list or str
return value of get_itervar_feature or a line of logfile
Returns
-------
feature_names: Array of str
"""
feature_name = {
"_attr_": ["length", "nest_level", "topdown", "bottomup"] +
["ann_%d" % i for i in range(20)],
"_arith_": ["add", "mul", "div"],
"buf_touch": ["stride", "mod", "count", "reuse", "T_count", "T_reuse"],
}
if isinstance(fea, str):
from .record import decode
# flatten line to feature
line = fea
inp, _ = decode(line)
target = _target.create(inp.target)
with target:
s, args = inp.template.instantiate(inp.config)
fea = get_itervar_feature(s, args)
names = []
ct = 0
for row in fea:
var_name = str(row[0][1])
for pair in row[1:]:
key = pair[0]
if key in feature_name:
name_list = feature_name[key]
else:
name_list = feature_name["buf_touch"]
for i in range(len((pair[1:]))):
names.append(".".join(["f%d" % ct, var_name, key, name_list[i]]))
ct += 1
return names
def get_buffer_curve_sample_flatten(sch, args, sample_n=30):
"""
Get flatten curve sample feature (relation feature)
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
sample_n: int
number of sample points along one dimension
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_buffer_curve_sample_flatten(stmt, sample_n, False)
feas = struct.unpack('%df' % (len(feas)//4), feas)
return feas
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo
from .measure import create_measure_batch, measure_option
from .measure_methods import request_remote
from .local_executor import LocalExecutor
from .executor import Future, Executor
""" Abstraction for asynchronous job execution """
class Executor(object):
"""
Base abstract executor interface for asynchronous job submission.
Allows submit asynchronous jobs and returns the Future object.
"""
# timeout for jobs that may hang
DEFAULT_TIMEOUT = 60
def submit(self, func, *args, **kwargs):
"""
Pass task (function, arguments) to the Executor.
Parameters
----------
func : callable
function to be run by a worker
args : list or tuple, optional
arguments passed to the function
kwargs : dict, optional
The keyword arguments
Returns
-------
future : Future
Future object wrapping the task which can be used to
collect the task's result.
"""
raise NotImplementedError()
class Future(object):
"""
Base class of the future object.
The implementations can return object of subclass of this.
This objects encapsulates the asynchronous execution of task
submitted to another thread, or another worker for execution.
Future objects store the state of tasks--can be polled for
result or a blocking call to retrieve the result can be used.
"""
def done(self):
"""
Return True if job was successfully cancelled or finished running.
"""
raise NotImplementedError()
def get(self, timeout=None):
"""
Get the result. This will block until the result is available.
Parameters
----------
timeout : int or float, optional
Maximum number of seconds to wait before it timeouts.
If not specified, it means we block until the result is available.
Returns
-------
result : Any
The result returned by the submitted function.
Raises
------
TimeoutError : if the result call timeouts.
"""
raise NotImplementedError()
class FutureError(RuntimeError):
"""Base error class of all future events"""
pass
# pylint:disable=redefined-builtin
class TimeoutError(FutureError):
"""Error raised when a task is timeout."""
pass
class ExecutionError(FutureError):
"""
Error raised when future execution crashes or failed.
"""
pass
"""Local based implementation of the executor using multiprocessing"""
import signal
from multiprocessing import Process, Queue
try:
from queue import Empty
except ImportError:
from Queue import Empty
import psutil
from . import executor
def kill_child_processes(parent_pid, sig=signal.SIGTERM):
"""kill all child processes recursively"""
try:
parent = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
for process in children:
try:
process.send_signal(sig)
except psutil.NoSuchProcess:
return
def _execute_func(func, queue, args, kwargs):
"""execute function and return the result or exception to a queue"""
try:
res = func(*args, **kwargs)
except Exception as exc: # pylint: disable=broad-except
res = exc
queue.put(res)
def timeout_monitor(queue, timeout, func, args, kwargs):
"""A wrapper to support timeout of a function call"""
# start a new process for timeout (cannot use thread because we have c function)
p = Process(target=_execute_func, args=(func, queue, args, kwargs))
p.start()
p.join(timeout=timeout)
alive = p.is_alive()
kill_child_processes(p.pid)
p.terminate()
p.join()
if alive:
queue.put(executor.TimeoutError())
else:
if queue.empty():
queue.put(executor.ExecutionError("Fatal error in local executor"))
class LocalFuture(executor.Future):
"""Local wrapper for the future
Parameters
----------
process: multiprocessing.Process
process for running this task
queue: multiprocessing.Queue
queue for receiving the result of this task
"""
def __init__(self, process, queue):
self._done = False
self._process = process
self._queue = queue
def done(self):
self._done = self._done or not self._queue.empty()
return self._done
def get(self, timeout=None):
try:
res = self._queue.get(block=True, timeout=timeout)
except Empty:
raise executor.TimeoutError()
if self._process.is_alive():
kill_child_processes(self._process.pid)
self._process.terminate()
self._process.join()
self._queue.close()
self._queue.join_thread()
self._done = True
del self._queue
del self._process
return res
class LocalFutureNoFork(executor.Future):
"""Local wrapper for the future.
This is a none-fork version of LocalFuture.
Use this for the runtime that does not support fork (like cudnn)
"""
def __init__(self, result):
self._result = result
def done(self):
return True
def get(self, timeout=None):
return self._result
class LocalExecutor(executor.Executor):
"""Local executor that runs workers on the same machine with multiprocessing."""
def __init__(self, timeout=None):
self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
def submit(self, func, *args, **kwargs):
"""
Note
----------
By default, the executor will fork a new process for a new job
But some runtime does not support fork (e.g. cuda runtime, cudnn).
In this circumstance, you should set 'fork_new_process' to False in kwargs
"""
fork_new_process = kwargs.pop('fork_new_process', True)
if not fork_new_process:
return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(1)
process = Process(target=timeout_monitor,
args=(queue, self.timeout, func, args, kwargs))
process.start()
return LocalFuture(process, queue)
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import time
from collections import namedtuple
import numpy as np
from ... import build, nd, target as _target
from ...contrib.util import tempdir
from ...rpc.tracker import Tracker
from ...rpc.server import Server
from ..util import get_const_tuple
from .local_executor import LocalExecutor
class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
"""
Stores all the necessary inputs for a measurement.
Parameters
----------
target : tvm.target.Target
The target device
task : task.Task
Task function
config : ConfigEntity
Specific configuration.
"""
class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
"""
Stores all the results of a measurement
Parameters
----------
costs: Array of float or Array of Exception
If no error occurs for this measurement, it is an array of measured running times.
If some error occurs during the measurement, it is an array of the exception objections.
error_no: int
Denote error type, defined by MeasureErrorNo
all_cost: float
All cost of this measure, including rpc, compilation, test runs
timestamp: float
The absolute time stamp when we finish measurement.
"""
class MeasureErrorNo(object):
"""Error type for MeasureResult"""
NO_ERROR = 0 # no error
INSTANTIATION_ERROR = 1 # error when calling template function
COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build)
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device)
RUNTIME_DEVICE = 4 # error when run program on device
WRONG_ANSWER = 5 # answer is wrong when compared to a golden output
FLEET_ERROR = 6 # error of measure infrastructure
def measure_option(mode,
number=1,
repeat=1,
timeout=60,
parallel_num=1,
pack_size=1,
check_correctness=False,
build_option=None,
replay_db=None,
save_to_replay_db=True,
rpc_device_key=None,
rpc_priority=1,
rpc_timeout=60,
rpc_tracker_addr=None,
use_ndk=False,
custom_measure_batch=None):
"""Configure how to do measurement
Parameters
----------
mode: str
'local': use the local device for measurement. In this mode,
the tuner starts a tracker and a RPC server silently for the user.
'rpc': request devices for measurement from rpc tracker. In this mode,
you should start a rpc tracker in a separate processing.
'custom': use custom measure function
'local-nofork': use local device for measure but does not use multiprocessing.
This mode is suitable for debug, but does not support timeout and parallel.
number : int, optional
Number of times to do the measurement for average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
timeout: int, optional
Timeout for a whole batch. TimeoutError will be returned as the result if a
task timeouts.
parallel_num: int, optional
The number of measurement task that can run in parallel.
Set this according to the number of cpu cores (for compilation) and
the number of devices you have (for measuring generate code).
pack_size : int, optional
Number of configs to measure in one RPC call.
Usually this can be set to 1. If your device has high cost to establish a rpc connection,
set this higher.
check_correctness: bool
Whether check correctness after measurement.
build_option: Dict, optional
Build options for tvm.build_config
replay_db : Database, optional
The database that we retrieve saved MeasureResults from
save_to_replay_db: bool, optional
Whether save measure result to database. This is useless when replay_db is None
rpc_priority: int, optional
Priority of this task, used by scheduler in tracker
rpc_device_key: str, optional
The device key of registered devices in tracker
rpc_timeout: int, optional
Timeout of rpc session
rpc_tracker_addr: Tuple(str, int), optional
The address of rpc tracker in Tuple(host, port) format.
If is set, will use this address.
If is not set, will use environment variable "TVM_TRACKER_HOST" and "TVM_TRACKER_PORT"
use_ndk: bool, option
Whether export requires ndk
custom_measure_batch: callable, optional
custom measure function
Returns
-------
options: dict
A dict to store all options
"""
return {
'mode': mode,
'number': number,
'repeat': repeat,
'timeout': timeout,
'parallel_num': parallel_num,
'pack_size': pack_size,
'check_correctness': check_correctness,
'build_option': build_option,
'replay_db': replay_db,
'save_to_replay_db': save_to_replay_db,
'rpc_device_key': rpc_device_key,
'rpc_priority': rpc_priority,
'rpc_timeout': rpc_timeout,
'rpc_tracker_addr': rpc_tracker_addr,
'use_ndk': use_ndk,
'custom_measure_batch': custom_measure_batch
}
def create_measure_batch(task, options):
"""Get a standard measure_batch function.
Parameters
----------
task: tvm.autotvm.task.Task
The tuning task
options: dict
The option for measuring generated code.
You should use the return value of :any:`autotvm.measure_option` for this argument
Returns
-------
measure_batch: callable
a callback function to measure a batch of configs
"""
from . import measure_methods
from ..database import filter_inputs
mode = options['mode']
number, repeat = options['number'], options['repeat']
timeout, parallel_num = options['timeout'], options['parallel_num']
pack_size = options['pack_size']
check_correctness = options['check_correctness']
build_option = options['build_option']
replay_db = options['replay_db']
save_to_replay_db = options['save_to_replay_db']
rpc_device_key = options['rpc_device_key']
rpc_priority, rpc_timeout = options['rpc_priority'], options['rpc_timeout']
use_ndk = options['use_ndk']
custom_measure_batch = options['custom_measure_batch']
kwargs = {}
executor = LocalExecutor(timeout=timeout)
if mode == 'local':
# start temporary rpc tracker and rpc server for the user
tracker = Tracker('localhost', port=9000, port_end=10000,
silent=True)
rpc_device_key = '$local$device$%d' % tracker.port
server = Server('localhost', port=9000, port_end=10000,
key=rpc_device_key,
use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port))
fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
kwargs['rpc_timeout'] = timeout
kwargs['tmp_dir'] = tempdir()
elif mode == 'rpc':
fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_priority'] = rpc_priority
kwargs['rpc_timeout'] = rpc_timeout
kwargs['use_ndk'] = use_ndk
kwargs['tmp_dir'] = tempdir()
assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
elif mode == "custom":
assert callable(custom_measure_batch), "In custom mode, custom_measure_func " \
"must be a callable object"
elif mode == 'local-nofork':
fmeasure = measure_methods.measure_local
kwargs['fork_new_process'] = False
else:
raise RuntimeError("Invalid mode: " + mode)
if 'cuda' in task.target.keys and 'rpc_device_key' in kwargs: # query cuda device info
add_cuda_device_info(kwargs['rpc_device_key'], kwargs.get('rpc_tracker_addr'), kwargs)
if 'opencl' in task.target.keys and 'rpc_device_key' in kwargs:
add_opencl_device_info(kwargs['rpc_device_key'], kwargs.get('rpc_tracker_addr'), kwargs)
if check_correctness:
# use llvm to generate a reference input/output
# this option works for tuning topi, but might not work for you custom op
with _target.create("llvm"):
s, arg_bufs = task.instantiate(task.config_space.get(0))
ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf]
kwargs['ref_input'], kwargs['ref_outpu'] = ref_input, ref_output
def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines"""
if replay_db is not None:
partial_results, measure_inputs =\
filter_inputs(replay_db, measure_inputs, retry=False)
# pack configs
input_packs = []
for i in range(0, len(measure_inputs), pack_size):
input_packs.append(measure_inputs[i:i + pack_size])
# send to measure
futures = []
for input_pack in input_packs:
future = executor.submit(
fmeasure, input_pack,
number=number,
repeat=repeat,
build_option=build_option,
**kwargs
)
futures.append(future)
# transform results
results = []
for future in futures:
result = future.get()
if isinstance(result, Exception):
if mode == 'local-nofork':
# debug usage, raise exception
raise result
tstamp = time.time()
results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
timeout, tstamp)] * pack_size)
else:
results.extend(result)
if replay_db is not None:
if save_to_replay_db: # save result to database
for measure_input, result in zip(measure_inputs, results):
replay_db.save(measure_input, result)
result_idx = 0
for i in range(len(partial_results)):
if partial_results[i] is None:
partial_results[i] = results[result_idx]
result_idx += 1
return partial_results
return results
if mode == 'custom':
measure_batch = custom_measure_batch
measure_batch.parallel_num = parallel_num
if mode == 'local':
measure_batch.aux_objects = {"server": server, "tracker": tracker}
return measure_batch
def add_cuda_device_info(device_key, rpc_tracker_addr, kwargs):
"""Query cuda device info. This is used to set the flags for nvcc compiler
and check the validity of a generated code."""
from .measure_methods import request_remote
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context('cuda', 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
def add_opencl_device_info(device_key, rpc_tracker_addr, kwargs):
"""Query opencl device info. This is used to check the validity of a generated code."""
from .measure_methods import request_remote
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context('opencl', 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
"""
Functions that run on executor for measurement.
These functions are responsible for building tvm module, uploading it to
remote devices, recording the running time costs and checking the correctness of output
"""
import logging
import os
import time
from random import getrandbits
import numpy as np
from ...contrib import ndk, nvcc
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from .measure import MeasureResult, MeasureErrorNo
from ..task.space import InstantiationError
class HashMismatchError(ValueError):
"""Raised when the code hash of a submitted config doesn't match that on the
measure side """
pass
def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
"""request a remote session
Parameters
----------
device_key: string
device key of registered device in tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format
priority: int, optional
priority of this request, larger is more prior
timeout: float, optional
timeout of this session (units: seconds)
Returns
------
session: RPCSession
"""
# connect to the tracker
if tracker_addr:
host = tracker_addr[0]
port = tracker_addr[1]
else:
host = os.environ['TVM_TRACKER_HOST']
port = int(os.environ['TVM_TRACKER_PORT'])
tracker = rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority,
session_timeout=timeout)
return remote
def _measure_generic(fbuild, input_pack, ref_input, ref_output):
"""Generic measurement function
Parameters
----------
fbuild : function takes MeasureInput returns tuple of (time_func, ctx)
The build function used to build each input.
input_pack : list of MeasureInput
The inputs we need to evaluate
ref_input: Array of np.ndarray
Reference input for checking correctness
ref_output: Array of np.ndarray
Reference output for checking correctness
Returns
-------
res_pack : array of MeasureResult
The list of execution result of measurement.
"""
res_pack = []
for inp in input_pack:
tic = time.time()
try:
time_f, ctx, arg_bufs = fbuild(inp)
except TVMError as exc:
tstamp = time.time()
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
if "InstantiationError" in msg:
try:
msg = msg.split('\n')[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
res_pack.append(MeasureResult((InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
else:
res_pack.append(MeasureResult((RuntimeError(msg),),
MeasureErrorNo.COMPILE_HOST,
tstamp - tic, tstamp))
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((e,),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
# measure time
errno = MeasureErrorNo.NO_ERROR
try:
if ref_input:
args = [nd.array(x, ctx) for x in ref_input]
else:
args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype),
ctx) for x in arg_bufs]
costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
costs.sort()
costs = tuple(costs[1:-1])
if ref_output:
for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
logging.warning("Wrong Answer!")
errno = MeasureErrorNo.WRONG_ANSWER
except TVMError as exc:
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
costs = (RuntimeError(msg),)
errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time()
res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
return res_pack
def _build_func(inp, build_option, kwargs):
"""Build function module. Exception will be raised when error occurs"""
with inp.target:
s, args = inp.task.instantiate(inp.config)
if not inp.config.valid():
raise InstantiationError(inp.config.errors)
code_hash = getattr(s, 'code_hash', None)
if inp.config.code_hash != code_hash:
raise HashMismatchError('got {0:s}, expected {1:s}'
.format(str(inp.config.code_hash), str(code_hash)))
opts = build_option or {}
if "check_gpu" in kwargs:
values = kwargs['check_gpu']
# Add gpu verify pass to filter out invalid configs in advance.
# This can accelerate the tuning process
check_keys = ['max_shared_memory_per_block', 'max_threads_per_block',
'max_thread_x', 'max_thread_y', 'max_thread_z']
opts["add_lower_pass"] = [
(2, gpu_verify_pass(**{key: values[key] for key in check_keys}))]
if 'cuda_arch' in kwargs:
set_cuda_target_arch(kwargs['cuda_arch'])
with build_config(**opts):
func = build(s, args, target_host=inp.task.target_host)
return func, args
def measure_rpc(input_pack,
rpc_device_key,
number,
repeat=1,
build_option=None,
rpc_tracker_addr=None,
rpc_priority=1,
rpc_timeout=60,
tmp_dir=None,
**kwargs):
"""Measure the time cost on a device by rpc
Parameters
----------
input_pack : list of MeasureInput
The inputs we need to evaluate
rpc_device_key: str
The device key of registered devices in tracker
number : int
Number of times to get the running measurement
repeat : int, optional
How many times we want to repeat the measurement.
build_option: Dict
build options for tvm.build_config
rpc_tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format
If is none, will use environment variable
rpc_priority: int, optional
priority of this task, used by scheduler in tracker
rpc_timeout: int, optional
timeout of the rpc session
tmp_dir: tvm.contrib.util.TempDirectory, optional
directory to store temp file
kwargs: dict, optional
Additional key word arguments
Returns
-------
res_pack : Array of MeasureResult
The list of execution results of measurement.
"""
def _fbuild(inp):
""" Local build function."""
func, args = _build_func(inp, build_option, kwargs)
if not kwargs.get('use_ndk', False):
file_name = "tmp_func_%0x.tar" % getrandbits(64)
path = tmp_dir.relpath(file_name)
func.export_library(path)
else:
file_name = "tmp_func_%0x.so" % getrandbits(64)
path = tmp_dir.relpath(file_name)
func.export_library(path, ndk.create_shared)
remote = request_remote(rpc_device_key, rpc_tracker_addr, rpc_priority, rpc_timeout)
remote.upload(path)
func = remote.load_module(file_name)
ctx = remote.context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
return time_f, ctx, args
ret = _measure_generic(_fbuild, input_pack,
kwargs.get("ref_input", None), kwargs.get("ref_output", None))
return ret
def measure_local(input_pack,
number,
repeat=1,
build_option=None,
**kwargs):
"""Measure the time cost on a local machine.
Parameters
----------
input_pack : list of MeasureInput
The inputs we need to evaluate
number : int
Number of times to get the running measurement
repeat : int, optional
How many times we want to repeat the measurement.
build_option: dict, optional
Build options for tvm.build_config
kwargs: dict, optional
Additional key word arguments
Returns
-------
res_pack : Array of MeasureResult
The list of execution results of measurement.
"""
def _fbuild(inp):
""" Local build function """
func, args = _build_func(inp, build_option, kwargs)
ctx = context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
return time_f, ctx, args
ret = _measure_generic(_fbuild, input_pack,
kwargs.get("ref_input", None), kwargs.get("ref_output", None))
return ret
def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel
This pass will check shared memory size and number of threads per block.
"""
def verify_pass(stmt):
valid = ir_pass.VerifyGPUCode(stmt, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt
return verify_pass
@register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate ptx code for better optimization"""
ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch)
return ptx
def set_cuda_target_arch(arch):
"""set target architecture of nvcc compiler"""
AutotvmGlobalScope.current.cuda_target_arch = arch
# pylint: disable=superfluous-parens, redefined-outer-name, redefined-outer-name,pointless-string-statement
# pylint: disable=consider-using-enumerate,invalid-name
"""Tuning record and serialization format"""
import argparse
import base64
import logging
import multiprocessing
import pickle
import json
import time
from collections import OrderedDict
import numpy as np
from .. import target, build, lower
from . import task
from .task import DispatchContext, ConfigEntity
from .measure import MeasureInput, MeasureResult
AUTOTVM_LOG_VERSION = 0.1
try: # convert unicode to str for python2
_unicode = unicode
except NameError:
_unicode = ()
def measure_str_key(inp, include_config=True):
""" get unique str key for MeasureInput
Parameters
----------
inp: MeasureInput
input for the measure
include_config: bool, optional
whether includes config in the str key
Returns
-------
key: str
The str representation of key
"""
config_str = str(inp.config) if include_config else ""
return "".join([str(inp.target), inp.task.name, str(inp.task.args),
str(inp.task.kwargs), config_str])
def encode(inp, result, protocol='json'):
"""encode (MeasureInput, MeasureResult) pair to a string
Parameters
----------
inp: autotvm.tuner.MeasureInput
result: autotvm.tuner.MeasureResult
pair of input/result
protocol: str
log protocol, json or pickle
Returns
-------
row: str
a row in the logger file
"""
if protocol == 'json':
json_dict = {
"i": (str(inp.target),
inp.task.name, inp.task.args, inp.task.kwargs,
inp.task.workload,
inp.config.to_json_dict()),
"r": (result.costs if result.error_no == 0 else (1e9,),
result.error_no,
result.all_cost,
result.timestamp),
"v": AUTOTVM_LOG_VERSION
}
return json.dumps(json_dict)
elif protocol == 'pickle':
row = (str(inp.target),
str(base64.b64encode(pickle.dumps([inp.task.name,
inp.task.args,
inp.task.kwargs,
inp.task.workload])).decode()),
str(base64.b64encode(pickle.dumps(inp.config)).decode()),
str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
return '\t'.join(row)
else:
raise RuntimeError("Invalid log protocol: " + protocol)
def decode(row, protocol='json'):
"""Decode encoded record string to python object
Parameters
----------
row: str
a row in the logger file
protocol: str
log protocol, json or pickle
Returns
-------
input: autotvm.tuner.MeasureInput
result: autotvm.tuner.MeasureResult
"""
# pylint: disable=unused-variable
if protocol == 'json':
row = json.loads(row)
tgt, task_name, task_args, task_kwargs, workload, config = row['i']
tgt = target.create(str(tgt))
def clean_json_to_python(x):
"""1. convert all list in x to tuple (hashable)
2. convert unicode to str for python2
"""
if isinstance(x, list):
return tuple([clean_json_to_python(a) for a in x])
if isinstance(x, _unicode):
return str(x)
return x
tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
tsk.workload = clean_json_to_python(workload)
config = ConfigEntity.from_json_dict(config)
inp = MeasureInput(tgt, tsk, config)
result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
return inp, result
elif protocol == 'pickle':
items = row.split("\t")
tgt = target.create(items[0])
task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
config = pickle.loads(base64.b64decode(items[2].encode()))
result = pickle.loads(base64.b64decode(items[3].encode()))
tsk = task.Task(task_tuple[0], task_tuple[1])
tsk.workload = task_tuple[3]
return MeasureInput(tgt, tsk, config), MeasureResult(*result)
else:
raise RuntimeError("Invalid log protocol: " + protocol)
def load_from_file(filename):
"""Generator: load records from file.
This is a generator that yields the records.
Parameters
----------
filename: str
Yields
------
input: autotvm.tuner.MeasureInput
result: autotvm.tuner.MeasureResult
"""
for row in open(filename):
yield decode(row)
class ApplyHistoryBest(DispatchContext):
"""
Apply the history best config
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
if is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
otherwise, it is an iterator
default: ConfigEntity, optional
default config to return when no history records
"""
def __init__(self, records, default=None):
super(ApplyHistoryBest, self).__init__()
if isinstance(records, str):
records = load_from_file(records)
counter = 0
best_map = {}
for inp, res in records:
counter += 1
if res.error_no != 0:
continue
for k in inp.target.keys:
key = (k, inp.task.workload)
if key not in best_map:
best_map[key] = (inp, res)
else:
_, other_res = best_map[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_map[key] = (inp, res)
logging.info(
"Finish load %d records, %d entries selected", counter, len(best_map))
self._best_map = best_map
self._default = default
def query(self, target, workload):
if target is None:
raise RuntimeError("Need a target context to find the history best. "
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
" above the dispatcher call. So does other target. ")
for k in target.keys:
key = (k, workload)
if key in self._best_map:
return self._best_map[key][0].config
if self._default:
return self._default
raise RuntimeError(
"Cannot find config for target=%s, workload=%s" % (target, workload))
def dump_best(self, out_file):
"""Dump the best records for each workload to a file
Parameters
----------
out_file: str
filename
"""
fout = open(out_file, 'a')
for val in self._best_map.values():
inp, res = val
fout.write(encode(inp, res) + '\n')
def split_workload(in_file, clean=True):
"""Split a log file into separate files, each of which contains only a single workload
This function can also delete duplicated records in log file
Parameters
----------
in_file: str
input filename
clean: bool
whether delete duplicated items
"""
tic = time.time()
lines = list(open(in_file).readlines())
logging.info("start convert...")
pool = multiprocessing.Pool()
lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic)
wkl_dict = OrderedDict()
for inp, res in lines:
wkl = measure_str_key(inp, False)
if wkl not in wkl_dict:
wkl_dict[wkl] = []
wkl_dict[wkl].append([inp, res])
if clean:
for i, (k, v) in enumerate(wkl_dict.items()):
# clean duplicated items
added = set()
cleaned = []
for inp, res in v:
str_key = measure_str_key(inp)
if str_key in added:
continue
added.add(str_key)
cleaned.append([inp, res])
# write to file
logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in cleaned:
fout.write(encode(inp, res) + '\n')
else:
for i, (k, v) in enumerate(wkl_dict.items()):
logging.info("Key: %s\tNum: %d", k, len(v))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in v:
fout.write(encode(inp, res) + '\n')
"""
Usage:
This record executable module has three modes.
* Print log file in readable format
e.g. python -m autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code
* Extract history best from a large log file
e.g. python -m autotvm.record --mode best --i collect.tsv
* Split a log file into separate files, each of which contains only a single wkl
e.g. python -m autotvm.record --mode split --i collect.tsv
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read')
parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help='output file')
parser.add_argument("--begin", type=int, default=0)
parser.add_argument("--end", type=int, default=5)
parser.add_argument("--ir", action='store_true')
parser.add_argument("--code", action='store_true')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.mode == 'best':
args.o = args.o or args.i + ".best"
hist_best = ApplyHistoryBest(load_from_file(args.i))
hist_best.dump_best(args.o)
elif args.mode == 'read':
for i, (inp, result) in enumerate(load_from_file(args.i)):
if args.begin <= i < args.end:
with inp.target:
s, arg_bufs = inp.task.instantiate(inp.config)
print("")
print(inp.target, inp.task, inp.config)
print(result)
if args.ir:
with inp.target:
print(lower(s, arg_bufs, simple_mode=True))
if args.code:
with inp.target:
func = build(s, arg_bufs)
print(func.imported_modules[0].get_source())
elif args.mode == 'split':
split_workload(args.i)
"""Task is a tunable composition of template functions.
Tuner takes a tunable task and optimizes the joint configuration
space of all the template functions in the task.
This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest.
"""
from .task import Task, create, register, template, get_config
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
"""
Decorator functions for hashing schedule code
code hashing is used to check the consistence of schedule code and the parameters loaded from log
"""
import inspect
import zlib
from tvm import schedule
def attach_code_hash(s):
"""Decorator for attaching a code hash to a schedule
Parameters
----------
s: Schedule
tvm.schedule.Schedule to attach the hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
func(*args, **kwargs)
raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
s.code_hash = hex(raw_hash)[2:]
return wrapper
return decorator
def attach_code_hash_to_arg(arg_idx=1):
"""Decorator for attaching a code hash to a schedule
Parameters
----------
arg_idx: int
index of the argument (expected to be a Schedule) to attach the code
hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
func(*args, **kwargs)
assert isinstance(args[arg_idx], schedule.Schedule)
raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
args[arg_idx].code_hash = hex(raw_hash)[2:]
return wrapper
return decorator
"""
Template dispatcher module.
A dispatcher is a function that can contains multiple behaviors.
Its specific behavior is can be controlled by DispatchContext.
DispatchContext is used in two ways, usually via different implementation
of the DispatchContext base class.
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to set pick the best policy.
"""
from __future__ import absolute_import as _abs
from decorator import decorate
from tvm import target as _target
class DispatchContext(object):
"""
Base class of dispatch context.
DispatchContext enables the target and workload
specific dispatch mechanism for templates.
"""
current = None
def query(self, target, workload):
"""
Query the context to get the specific implementation.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
raise NotImplementedError()
def __enter__(self):
self._old_ctx = DispatchContext.current
DispatchContext.current = self
return self
def __exit__(self, ptype, value, trace):
DispatchContext.current = self._old_ctx
class ApplyConfig(DispatchContext):
"""Apply a specific config entity during query.
Parameters
----------
config : ConfigSpace or ConfigEntity
The specific configuration we care about.
"""
def __init__(self, config):
super(ApplyConfig, self).__init__()
self._config = config
self.workload = None
def query(self, target, workload):
"""Override query"""
self.workload = workload
return self._config
def dispatcher(fworkload):
"""Wrap a workload dispatcher function.
Parameters
----------
fworkload : function
The workload extraction function from arguments.
Returns
-------
fdispatcher : function
A wrapped dispatcher function, which will
dispatch based on DispatchContext and
the current workload.
"""
dispatch_dict = {}
func_name = fworkload.__name__
def register(key, func=None, override=False):
"""Register template function.
Parameters
----------
key : str or List of str
The template key to identify the template
under this dispatcher.
func : function
The function to be registered.
The first argument of the function is always
cfg returned by DispatchContext,
the rest arguments are the same as the fworkload.
override : bool
Whether override existing registration.
Returns
-------
The register function if necessary.
"""
if isinstance(key, str):
key = [key]
def _do_reg(myf):
for x in key:
if x in dispatch_dict and not override:
raise ValueError(
"Key %s is already registered for %s" % (x, func_name))
dispatch_dict[x] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispatch function"""
tgt = _target.current_target()
context = DispatchContext.current
if context is None:
raise RuntimeError("DispatchContext is not initialized")
workload = func(*args, **kwargs)
cfg = context.query(tgt, workload)
return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
fdecorate = decorate(fworkload, dispatch_func)
fdecorate.register = register
return fdecorate
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate
"""
Template configuration space.
Each template function can be parametrized by a ConfigSpace.
The space is declared when we invoke the template function with ConfigSpace.
During evaluation, we pass in a ConfigEntity, which contains a specific
entity in the space. This entity contains deterministic parameters.
"""
from __future__ import absolute_import as _abs
import itertools
import functools
import math
from collections import namedtuple, OrderedDict
import numpy as np
from tvm import schedule, thread_axis
from tvm.autotvm.util import get_const_int
Axis = namedtuple('Axis', ['space', 'index'])
class InstantiationError(ValueError):
"""Actively detected error in instantiating a template with a config,
raised by cfg.raise_error
e.g. too many unrolling, too many threads in a block
"""
pass
class TransformSpace(object):
"""Base class for transform space
TransformSpace is the node in the computation graph of axes
Note
----
We can regard our schedule code as a transformation graph of axes.
Starting from raw axes in the definition of tvm.compute, we can transform these axes
by some operators. The operator includes 'split', 'reorder' and 'annotate'.
Each operator has some tunable parameters (e.g. the split factor).
Then the tuning process is just to find good parameters of these op.
So the all the combinations of the parameters of these op forms our search space.
Naming convention:
We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...)
We call a specific entity in a space as XXXEntity.
"""
def __init__(self):
self.ins = []
self.num_output = 0
self.entities = []
def __len__(self):
return len(self.entities)
def __getitem__(self, index):
"""Get an entity of the space by index
Parameters
----------
index: int
Returns
-------
transform entity
"""
return self.entities[index]
@staticmethod
def get_num_output():
"""get number of output axes after this transform
Returns
-------
n: int
number of output axes
"""
return 0
class VirtualAxis(TransformSpace):
"""Axis placeholder in template
Parameters
----------
var: int or tvm.schedule.IterVar
If is int, return a virtual axis whose length is the provided argument.
If is IterVar, return a virtual axis whose length is extracted from
the IterVar's extent domain.
name: str
"""
name_ct = 0
def __init__(self, var, name=None):
super(VirtualAxis, self).__init__()
self.num_output = 1
if name is None:
name = 'axis_%d' % VirtualAxis.name_ct
VirtualAxis.name_ct += 1
self.name = name
if isinstance(var, int):
self.length = var
elif isinstance(var, schedule.IterVar):
self.name = var.var.name
if var.dom is None:
self.length = -1
else:
self.length = get_const_int(var.dom.extent)
elif isinstance(var, VirtualAxis):
self.length = var.length
else:
raise RuntimeError("Invalid type of axis")
@staticmethod
def get_num_output(var, name=None):
return 1
def __repr__(self):
return "vaxis(%s)" % self.name
def get_factors(n):
"""return all factors of an integer
Parameters
----------
n: int
integer to factorize
Returns
-------
factors: list
List of all factors
"""
step = 2 if n % 2 else 1
ret = list(set(
functools.reduce(
list.__add__, ([i, n//i] for i in range(1, int(math.sqrt(n)) + 1, step)
if n % i == 0))))
ret.sort()
return ret
class SplitSpace(TransformSpace):
"""Split an axis for several times"""
def __init__(self, axes, policy, **kwargs):
super(SplitSpace, self).__init__()
axis = axes[0]
self.policy = policy
self.entities = []
if policy == 'all':
num_outputs = kwargs["num_outputs"]
max_factor = kwargs.get("max_factor", 1 << 31)
fil = kwargs.get("filter", lambda x: True)
length = axis.length
factors = get_factors(length)
factors = [x for x in factors if x <= max_factor]
# copy factors for every level
self.product = length
self.num_outputs = num_outputs
self.factors = [factors] * (num_outputs-1)
self._generate_space(0, [None] * (num_outputs - 1))
self.entities = list(filter(fil, self.entities))
self.num_output = num_outputs
elif policy == 'candidate':
self.product = axis.length
self.num_outputs = kwargs["num_outputs"]
for size in kwargs["candidate"]:
assert len(size) == self.num_outputs
# assert np.prod(size) == self.product
self.entities.append(SplitEntity(size))
self.num_output = self.num_outputs
else:
raise RuntimeError("Invalid policy: " + policy)
def _generate_space(self, now, tmp_stack):
"""Generate space by DFS"""
if now == self.num_outputs - 1:
if self.product % np.prod(tmp_stack) == 0:
first = int(self.product // int(np.prod(tmp_stack)))
self.entities.append(SplitEntity([first] + tmp_stack[::-1]))
else:
for factor in self.factors[now]:
tmp_stack[now] = factor
self._generate_space(now + 1, tmp_stack)
@staticmethod
def get_num_output(axes, policy, **kwargs):
return kwargs["num_outputs"]
def __repr__(self):
return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
(self.policy, self.product, self.num_outputs, len(self)))
class SplitEntity(object):
"""
A split operation with detailed parameters
that can apply to an axis
Parameters
----------
size: Array of int
the size of every axis after split
e.g. an axis of extent 128, we split it into 3 axes, a possible
size is [4, 4, 8] (4x4x8 = 128)
"""
def __init__(self, size):
self.size = size
def apply(self, sch, op, axis):
"""Apply split to an axis
Parameters
----------
sch: tvm.schedule.Schedule
The tvm schedule
op: tvm.tensor.Operation
The stage to be applied
axis: tvm.schedule.IterVar
axis to split
Returns
-------
axes : list of Axis
The transformed axes.
"""
ret = []
for i in range(1, len(self.size)):
ax0, ax1 = sch[op].split(axis, int(np.prod(self.size[i:])))
ret.append(ax0)
axis = ax1
return ret + [axis]
def __repr__(self):
return str(self.size)
class ReorderSpace(TransformSpace):
"""The parameter space for ordering an array of axes"""
def __init__(self, axes, policy, **kwargs):
super(ReorderSpace, self).__init__()
self.ins = axes
self.policy = policy
self.num_output = len(axes)
if policy == 'identity':
self.entities = [ReorderEntity(range(len(axes)))]
elif policy == 'all':
self.entities = [
ReorderEntity(x) for x in itertools.permutations(range(len(axes)))]
elif policy == 'interval_all':
begin, end = kwargs['interval']
sub_space = list(itertools.permutations(range(begin, end)))
prefix, suffix = tuple(range(begin)), tuple(range(end, len(axes)))
self.entities = [ReorderEntity(prefix + x + suffix) for x in sub_space]
elif policy == 'candidate':
candidate = kwargs["candidate"]
for can in candidate:
perm = [axes.index(x) for x in can]
self.entities.append(ReorderEntity(perm))
elif policy == 'interleave':
spatial, reduce = kwargs['spatial'], kwargs['reduce']
spatial = [[axes.index(x) for x in ch] for ch in spatial]
reduce = [[axes.index(x) for x in ch] for ch in reduce]
outer_merged = self._merge_chain([x[:-1] for x in spatial])
inner_merged = self._merge_chain([x[-1:] for x in spatial] + reduce)
for o in outer_merged:
for i in inner_merged:
self.entities.append(ReorderEntity(o + i))
elif policy == 'interleave_cuda':
spatial, reduce = kwargs['spatial'], kwargs['reduce']
spatial = [[axes.index(x) for x in ch] for ch in spatial]
reduce = [[axes.index(x) for x in ch] for ch in reduce]
outer_merged = self._merge_chain([x[:-1] for x in spatial])
reduce_merged = self._merge_chain(reduce)
inner_merged = [x[-1] for x in spatial]
for o in outer_merged:
for r in reduce_merged:
self.entities.append(ReorderEntity(o + r + inner_merged))
else:
raise RuntimeError("Invalid policy: " + policy)
@staticmethod
def get_num_output(axes, policy, **kwargs):
return len(axes)
def __repr__(self):
return "Reorder(policy=%s) len=%d" % (self.policy, len(self))
def _merge_chain(self, chains):
"""generate all combinations of merge some chains"""
merged = []
tmp_pt = [0] * len(chains)
tmp_stack = []
size = np.sum([len(x) for x in chains])
self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
return merged
def _merge_dfs(self, chains, size, tmp_pt, tmp_stack, merged):
if np.sum(tmp_pt) == size:
merged.append(list(tmp_stack))
return
else:
for i in range(len(chains)):
# use i == np.argmax(....) here to take spatial order into consideration
# if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
if (tmp_pt[i] < len(chains[i]) and
(i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
tmp_stack.append(chains[i][tmp_pt[i]])
tmp_pt[i] += 1
self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
tmp_pt[i] -= 1
tmp_stack.pop()
class ReorderEntity(object):
"""A reorder operation with detailed parameters that can apply to axes
Parameters
----------
perm: Array of int
define the permutation
"""
def __init__(self, perm):
self.perm = perm
def apply(self, sch, op, axes):
"""Apply reorder to an array of axes
Parameters
----------
sch: tvm.schedule.Schedule
The tvm schedule
op: tvm.tensor.Operation
The stage to be applied
axis: tvm.schedule.IterVar
axis to split
Returns
-------
axes : list of Axis
The transformed axes.
"""
if len(axes) == len(self.perm):
new_order = [axes[i] for i in self.perm]
else:
new_order = [axes[i] for i in self.perm if i < len(axes)]
sch[op].reorder(*new_order)
return new_order
def __repr__(self):
return str(self.perm)
class AnnotateSpace(TransformSpace):
"""The parameter space for annotating an array of axes"""
def __init__(self, axes, policy, **kwargs):
super(AnnotateSpace, self).__init__()
self.ins = axes
self.policy = policy
self.num_output = len(axes)
if policy == 'bind_gpu':
self.num_axis = len(axes)
if self.num_axis >= 6:
self.entities.append(AnnotateEntity(
['fuse'] * (self.num_axis - 6) +
['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
elif self.num_axis >= 4:
self.entities.append(AnnotateEntity(
['fuse'] * (self.num_axis - 4) +
['blockIdx.y', 'blockIdx.x',
'threadIdx.y', 'threadIdx.x']))
elif self.num_axis >= 2:
self.entities.append(AnnotateEntity(
['fuse'] * (self.num_axis - 2) +
['blockIdx.x', 'threadIdx.x']))
else:
raise RuntimeError("Unhandled case in bind_gpu")
elif policy == 'bind_gpu_virtual':
self.num_axis = len(axes)
if self.num_axis >= 9:
self.entities.append(AnnotateEntity(
['fuse'] * (self.num_axis - 9) +
['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
'vthread', 'vthread', 'vthread',
'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
elif self.num_axis >= 6:
self.entities.append(AnnotateEntity(
['fuse'] * (self.num_axis - 6) +
['blockIdx.y', 'blockIdx.x',
'vthread', 'vthread',
'threadIdx.y', 'threadIdx.x']))
elif self.num_axis >= 3:
self.entities.append(AnnotateEntity(
['fuse'] * (self.num_axis - 3) +
['blockIdx.x', 'vthread', 'threadIdx.x']))
else:
raise RuntimeError("Unhandled case in bind_gpu")
elif policy == 'locate_cache':
self.num_axis = len(axes)
num_anchor = kwargs["num_anchor"]
self.anns = list(itertools.combinations(np.arange(self.num_axis), num_anchor))
self.entities = [AnnotateEntity(x) for x in self.anns]
else: # none, vec, unroll, try_vec, try_unroll, try_vec_unroll, ...
anns = policy.replace('try', 'none').split('_')
for ann in anns:
if ann not in ['none', 'unroll', 'vec']:
raise RuntimeError("Invalid policy: " + policy)
self.num_axis = len(axes)
self.anns = [anns] * self.num_axis
self._generate_space(0, [""] * self.num_axis)
def _generate_space(self, now, tmp_stack):
"""Generate space by DFS"""
if now == self.num_axis:
# only vectorize inner most dimension
vec_ct = tmp_stack.count('vec')
if vec_ct == 0 or vec_ct == 1:
self.entities.append(AnnotateEntity(list(tmp_stack)))
else:
for ann in self.anns[now]:
tmp_stack[now] = ann
self._generate_space(now + 1, tmp_stack)
@staticmethod
def get_num_output(axes, policy, **kwargs):
return len(axes)
def __repr__(self):
return "Annotate(policy=%s) len=%d" % (self.policy, len(self))
class AnnotateEntity(object):
"""An annotation operation with detailed parameters that can apply to axes
Parameters
----------
anns: Array of string
The annotations of axes
"""
def __init__(self, anns):
self.anns = anns
def apply(self, sch, op, axes, axis_lens=None,
max_unroll=None, vec_size=None, cfg=None, source=None):
"""Apply annotation to an array of axes
Parameters
----------
sch: tvm.schedule.Schedule
The tvm schedule
op: tvm.tensor.Operation
The stage to be applied
axes: Array of tvm.schedule.IterVar
axis to split
axis_lens: Array of int, optional
the length of axes
max_unroll: int, optional
maximum unroll step
vec_size: Array of int, optional
valid vector lanes for vectorization
cfg: ConfigEntity, optional
cfg for recording error
source: Array of Array tensor, optional
source tensor for attaching cache
Returns
-------
axes : list of tvm.schedule.IterVar
The transformed axes
"""
if source is not None: # special case : attach cache_read/cache_write
for src, to in zip(source, self.anns):
for t in src:
sch[t].compute_at(sch[op], axes[to])
else: # other cases
for i, ann in enumerate(self.anns):
if ann == 'none':
pass
elif ann == 'unroll':
if max_unroll and axis_lens[i] > max_unroll:
cfg.raise_error("Too large factor for unrolling")
sch[op].unroll(axes[i])
elif ann == 'vec':
if vec_size and axis_lens[i] not in vec_size:
cfg.raise_error("Wrong size of lanes in vectorization")
sch[op].vectorize(axes[i])
elif ann == 'blockIdx.x':
sch[op].bind(axes[i], thread_axis('blockIdx.x'))
elif ann == 'blockIdx.y':
sch[op].bind(axes[i], thread_axis('blockIdx.y'))
elif ann == 'blockIdx.z':
sch[op].bind(axes[i], thread_axis('blockIdx.z'))
elif ann == 'threadIdx.x':
sch[op].bind(axes[i], thread_axis('threadIdx.x'))
elif ann == 'threadIdx.y':
sch[op].bind(axes[i], thread_axis('threadIdx.y'))
elif ann == 'threadIdx.z':
sch[op].bind(axes[i], thread_axis('threadIdx.z'))
elif ann == 'vthread':
sch[op].bind(axes[i], thread_axis("vthread"))
elif ann == 'fuse':
assert i < len(axes) - 1
axes[i+1] = sch[op].fuse(axes[i], axes[i+1])
else:
raise RuntimeError("Invalid annotation " + ann)
return axes
def __repr__(self):
return str(self.anns)
class OtherOptionSpace(TransformSpace):
"""The parameter space for general option"""
def __init__(self, axes, policy, **kwargs):
super(OtherOptionSpace, self).__init__()
candidate = kwargs["candidate"]
self.entities = [OtherOptionEntity(x) for x in candidate]
@staticmethod
def get_num_output(axes, policy, **kwargs):
return 0
def __repr__(self):
return "OtherOption(%s) len=%d" % (self.entities, len(self))
class OtherOptionEntity(object):
"""The parameter entity for general option, with a detailed value"""
def __init__(self, val):
self.val = val
def __repr__(self):
return str(self.val)
class ConfigSpace(object):
"""The configuration space of a schedule. Pass it as config in template to
collect transformation space and build transform graph of axes
"""
def __init__(self):
# private dict to provide sugar
self.space_map = OrderedDict() # name -> space
self._collect = True
self._length = None
self._entity_map = OrderedDict()
self._constraints = []
self.errors = []
self.template_key = None
self.code_hash = None
self.flop = 0
@staticmethod
def axis(var):
"""get a virtual axis (axis placeholder)
Parameters
----------
var: int or tvm.schedule.IterVar
If is int, return an axis whose length is the provided argument.
If is IterVar, return an axis whose length is extracted from the
IterVar's extent domain.
"""
return VirtualAxis(var)
reduce_axis = axis
def define_split(self, name, axis, policy='all', **kwargs):
"""Define a new tunable knob which splits an axis into a list of axes
Parameters
----------
name: str
name to index the entity of this space
axis: tvm.schedule.IterVar
axis to split
policy: str
name of policy.
If is 'all', the tuner will try all divisible factors.
If is 'candidate', try listed candidate.
kwargs: dict
extra arguments for policy
"""
axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
def define_reorder(self, name, axes, policy, **kwargs):
"""Define a new tunable knob which reorders a list of axes
Parameters
----------
name: str
name to index the entity of this space
axes: Array of tvm.schedule.IterVar
axes to reorder
policy: str
name of policy
If is 'identity', do an identity permutation.
If is 'all', try all permutations.
If is 'interval_all', try all permutations of an interval of axes.
If is 'candidate', try listed candidate.
If is 'interleave', interleave chains of spatial axes and chains of reduction axes.
kwargs: dict
extra arguments for policy
"""
return self._add_new_transform(ReorderSpace, name, axes, policy, **kwargs)
def define_annotate(self, name, axes, policy, **kwargs):
"""Define a new tunable knob which annotates a list of axes
Parameters
----------
name: str
name to index the entity of this space
axes: Array of tvm.schedule.IterVar
axes to annotate
policy: str
name of policy
If is 'unroll', unroll the axes.
If is 'try_unroll', try to unroll the axes.
If is 'try_unroll_vec', try to unroll or vectorize the axes.
If is 'bind_gpu', bind the first few axes to gpu threads.
If is 'locate_cache', choose n axes to attach shared/local cache.
kwargs: dict
extra arguments for policy
"""
return self._add_new_transform(AnnotateSpace, name, axes, policy, **kwargs)
def define_knob(self, name, candidate):
"""Define a tunable knob with a list of candidates
Parameters
----------
name: str
name key of that option
candidate: list
list of candidates
"""
return self._add_new_transform(OtherOptionSpace, name, [], None, candidate=candidate)
def add_flop(self, flop):
"""Add float operation statistics for this tuning task
Parameters
---------
flop: int or float
number of float operations
"""
self.flop += flop
def raise_error(self, msg):
"""register error in config
Using this to actively detect error when scheudling.
Otherwise these error will occur during runtime, which
will cost more time.
Parameters
----------
msg: str
"""
self.errors.append(msg)
def valid(self):
"""Check whether the config meets all the constraints
Note: This check should be called after instantiation of task,
because the ConfigEntity/ConfigSpace collects errors during instantiation
Returns
-------
valid: bool
whether the config meets all the constraints
"""
return not bool(self.errors)
def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
"""Add a new transform space in template"""
if self._collect:
# convert schedule axis to space definition axis
axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes]
# add subspace (knob)
space = space_class(axes, policy, **kwargs)
self.space_map[name] = space
self._entity_map[name] = space[0]
return [Axis(space, i) for i in range(space.num_output)]
return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]
def __len__(self):
if self._length is None:
self._length = int(np.prod([len(x) for x in self.space_map.values()]))
return self._length
def get(self, index):
"""Get a config entity with detailed parameters from this space
Parameters
----------
index: int
index in the space
"""
entities = OrderedDict()
t = index
for name, space in self.space_map.items():
entities[name] = space[t % len(space)]
t //= len(space)
ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
return ret
def __iter__(self):
return self._entity_map.__iter__()
def __getitem__(self, name):
"""get the transform entity(knob) of this entity by name
do not use this to get a ConfigEntity of this space (should use ConfigSpace.get instead)
Parameters
----------
name: str
name of the transform
"""
return self._entity_map[name]
def __repr__(self):
res = "ConfigSpace (len=%d, space_map=\n" % len(self)
for i, (name, space) in enumerate(self.space_map.items()):
res += " %2d %s: %s\n" % (i, name, space)
return res + ")"
_ann_to_number = {
'none': 0, 'vec': 1, 'unroll': 2,
'blockIdx.x': 3, 'blockIdx.y': 4, 'blockIdx.z': 5,
'threadIdx.x': 6, 'threadIdx.y': 7, 'threadIdx.z': 8,
'vthread': 9, 'fuse': 10
}
class ConfigEntity(ConfigSpace):
"""A configuration with detailed parameters
Parameters
----------
index: int
index of this config in space
code_hash: str
hash of schedule code
template_key : str
The specific template key
entity_map: dict
map name to transform entity
constraints : list
List of constraints
"""
def __init__(self, index, code_hash, template_key, entity_map, constraints):
super(ConfigEntity, self).__init__()
self.index = index
self.template_key = template_key
self._collect = False
self._entity_map = entity_map
self._space_map = None
self._constraints = constraints
self.code_hash = code_hash
def get_flatten_feature(self):
""" flatten entities to a numerical one-dimensional feature vector
Returns
-------
fea: np.array
one dimensional float32 array
"""
fea = []
for _, v in self._entity_map.items():
if isinstance(v, SplitEntity):
fea.extend(v.size)
elif isinstance(v, ReorderEntity):
# use a naive way: directly copy the permutation
fea.extend(v.perm)
elif isinstance(v, AnnotateEntity):
# one-hot encoding
for ann in v.anns:
tmp = [0] * len(_ann_to_number)
tmp[_ann_to_number[ann]] = 1
fea.extend(tmp)
elif isinstance(v, OtherOptionEntity):
fea.append(v.val)
return np.array(fea, dtype=np.float32)
def get_other_option(self):
"""
Returns
-------
other_option: dict
other tunable parameters (tunable parameters defined by `cfg.define_knob`)
"""
return {x: x.val for x in self._entity_map.values() if isinstance(x, OtherOptionEntity)}
def to_json_dict(self):
"""convert to a json serializable dictionary
Return
------
json_dict: dict
a json serializable dictionary
"""
ret = {}
ret['i'] = int(self.index)
ret['t'] = self.template_key
ret['c'] = self.code_hash
entity_map = []
for k, v in self._entity_map.items():
if isinstance(v, SplitEntity):
entity_map.append((k, 'sp', v.size))
elif isinstance(v, ReorderEntity):
entity_map.append((k, 're', v.perm))
elif isinstance(v, AnnotateEntity):
entity_map.append((k, 'an', v.anns))
elif isinstance(v, OtherOptionEntity):
entity_map.append((k, 'ot', v.val))
else:
raise RuntimeError("Invalid entity instance: " + v)
ret['e'] = entity_map
return ret
@staticmethod
def from_json_dict(json_dict):
"""Build a ConfigEntity from json serializable dictionary
Parameters
----------
json_dict: dict
Json serializable dictionary. This should be the return value
of :any:`to_json_dict`.
Returns
-------
config: ConfigEntity
The corresponding config object
"""
index = json_dict["i"]
code_hash = json_dict["c"]
template_key = json_dict["t"]
constraints = []
entity_map = OrderedDict()
for item in json_dict["e"]:
key, knob_type, knob_args = item
if knob_type == 'sp':
entity = SplitEntity(knob_args)
elif knob_type == 're':
entity = ReorderEntity(knob_args)
elif knob_type == 'an':
entity = AnnotateEntity(knob_args)
elif knob_type == 'ot':
entity = OtherOptionEntity(knob_args)
else:
raise RuntimeError("Invalid config knob type: " + knob_type)
entity_map[str(key)] = entity
return ConfigEntity(index, code_hash, template_key, entity_map, constraints)
def __repr__(self):
return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
self.code_hash, self.index)
# pylint: disable=unused-variable
"""Definition of task function.
Task can be constructed from tuple of func, args, and kwargs.
func is a state-less function, or a string that
registers the standard task.
"""
import numpy as np
from ... import tensor, expr, container, target as _target
from ..util import get_const_int, get_const_tuple, get_func_name
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
from .space import ConfigSpace
def _raise_error(*args, **kwargs): # pylint: disable=unused-argument
raise RuntimeError("The function of this task is not found. Possibly the function "
"of this task is registered in another python file "
"which is not imported in this run")
class Task(object):
"""A Tunable Task
Parameters
----------
name: str
The name of the task.
args: Tuple
Positional argument of func
"""
def __init__(self, name, args):
self.name = name
self.args = args
self.kwargs = {} # currently unused
# init null config space
self.config_space = None
self.func = TASK_TABLE.get(name, _raise_error)
# auxiliary info, available after `init_space` is called
self.workload = None
self.flop = None
self.target = None
self.target_host = None
def instantiate(self, config):
"""Instantiate this task function (template) with a config.
Returns corresponding schedule.
Parameters
----------
config: template.ConfigEntity
parameter config for this template
Returns
-------
sch: tvm.schedule.Schedule
The tvm schedule
arg_bufs: Array of tvm.tensor.Tensor
The input/output buffers
"""
config.flop = 0
with ApplyConfig(config):
sch, arg_bufs = self.func(*self.args, **self.kwargs)
if not self.flop:
config.flop = config.flop or compute_flop(sch)
self.flop = config.flop
return sch, arg_bufs
def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
self.name, self.args, self.kwargs, self.workload
)
TASK_TABLE = {
}
def register(name, func=None, override=False):
"""Register a task function.
Parameters
----------
name : str
The name to identify the task.
func : callable
The function to be registered.
override : bool
Whether override existing registration.
Returns
-------
func: callable
The registered function
"""
def _do_reg(myf):
if name in TASK_TABLE and not override:
raise ValueError(
"Key %s is already registered" % name)
TASK_TABLE[name] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def create(func_name, args, target, target_host=None, template_key=None):
"""Create a tuning task and initialize its search space
Parameters
----------
func_name : str or callable
The task function
args : List
Positional arguments
target : Target
The compilation target
target_host: Target, optional
The compilation target for host side
Returns
-------
tsk: Task
a task object
"""
if callable(func_name):
# register this function if it is not registered before
func = func_name
func_name = func.func_name if hasattr(func, 'func_name') else func.__name__
if func_name in TASK_TABLE:
assert func == TASK_TABLE[func_name], "Find name conflict in task registration. " \
"Consider to choose another name for this task"
else:
register(func_name, func=func)
func = TASK_TABLE[func_name]
ret = Task(func_name, args)
if isinstance(target, str):
target = _target.create(target)
# init config space
ret.config_space = ConfigSpace()
ret.config_space.template_key = template_key or ""
ctx = ApplyConfig(ret.config_space)
with ctx:
with target:
sch, _ = func(*args)
ret.config_space.code_hash = getattr(sch, 'code_hash', None)
ret.workload = ctx.workload
ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
ret.target_host = target_host
return ret
def args_to_workload(x):
"""Convert argument list to hashable workload tuple.
This function will convert list to tuple, tvm node to python value and
flatten tvm.tensor.Tensor to a tuple
Parameters
----------
x: primitive hashable types or tensor.Tensor
The original value
Returns
-------
ret: hashable
The hashable value
"""
if isinstance(x, tensor.Tensor):
return get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
return tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
return x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
elif x is None:
return None
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types only' % type(x))
def template(func):
"""
Decorate a function as a tunable schedule template
Parameters
----------
func: callable
A callable template function.
Its argument should be hashable values.
Its return value should be a Tuple(Schedule, Array of Tensor)
Returns
-------
func: callable
The decorated function
Examples
--------
The following code is a tunable template for a blocked matrix multiplication
.. code-block:: python
@autotvm.template
def matmul(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
##### define space begin #####
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
##### define space end #####
# schedule according to config
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
"""
# pylint: disable=unused-variable
fname = get_func_name(func)
@register(fname)
@dispatcher
def config_dispatcher(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
return (fname, ) + args_to_workload(args)
@config_dispatcher.register("")
def template_call(cfg, *args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
with ApplyConfig(cfg):
return func(*args, **kwargs)
config_dispatcher.func_name = fname
return config_dispatcher
def get_config():
"""Get current config object
Returns
-------
cfg: ConfigSpace or ConfigEntity
The current config
"""
return DispatchContext.current.query(None, None)
class FlopCalculationError(RuntimeError):
"""Error happens when estimating FLOP for a compute op"""
pass
def compute_flop(sch):
"""Calculate number of FLOP (floating number operations) of the compute ops in a schedule
Parameters
----------
sch: tvm.schedule.Schedule
schedule
Returns
-------
flop: int
number of FLOP in this schedule
"""
def _prod_length(axes):
"""compute product of the lengths of a list of axes"""
try:
num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
except ValueError:
raise FlopCalculationError("The length of axis is not constant. ")
return num_iter
def _count_flop(exp):
"""compute flop for a single expression"""
if isinstance(exp, expr.Reduce):
num_iter = _prod_length(exp.axis)
combiner = exp.combiner.result
source = exp.source
if len(combiner) != 1:
raise FlopCalculationError("Found multiple output in the combiner of reduce op")
if len(source) != 1:
raise FlopCalculationError("Found multiple output in the source of reduce op")
return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
elif isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
return 0
elif isinstance(exp, expr.Cast):
return _count_flop(exp.value)
elif isinstance(exp, expr.Var):
return 0
elif isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
base = 1 if "float" in exp.a.dtype else 0
if isinstance(exp, expr.Not): # unary
return base + _count_flop(exp.a)
return base + _count_flop(exp.a) + _count_flop(exp.b)
elif isinstance(exp, expr.Select):
return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
_count_flop(exp.false_value))
elif isinstance(exp, expr.Call):
return sum([_count_flop(x) for x in exp.args])
else:
raise FlopCalculationError("Found unsupported operator in the compute expr")
def traverse(ops):
"""accumulate flops"""
ret = 0
for op in ops:
if isinstance(op, tensor.ComputeOp):
num_element = _prod_length(op.axis)
body = op.body
if len(body) != 1:
raise FlopCalculationError("Found multiple output in the compute")
exp = body[0]
ret += num_element * _count_flop(exp)
ret += traverse([sch[t].op for t in op.input_tensors])
elif isinstance(op, tensor.PlaceholderOp):
pass
else:
raise FlopCalculationError("Only support tvm.compute currently. "
"Other ops like tvm.scan is not supported")
return ret
try:
ret = traverse(sch.outputs)
except FlopCalculationError as exc:
raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
+ str(exc) + ". Please use `cfg.add_flop` to manually set "
"FLOP for this operator")
if ret == 0:
raise RuntimeError("Cannot find float number operation in this operator. "
"Please use `cfg.add_flop` to manually set "
"FLOP for this operator")
return ret
"""
A tuner takes a task as input. It proposes some promising :any:`ConfigEntity`
in the :any:`ConfigSpace` and measure them on the real hardware. Then it
proposed the next batch of :any:`ConfigEntity` according to the measure results.
This tuning loop is repeated.
"""
from . import callback
from .tuner import Tuner
from .gridsearch_tuner import GridSearchTuner, RandomTuner
from .ga_tuner import GATuner
from .xgboost_tuner import XGBTuner
# pylint: disable=consider-using-enumerate,invalid-name
"""Namespace of callback utilities of AutoTVM"""
import numpy as np
from .. import record
def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
The rows of the log are stored in the format of autotvm.record.encode.
Parameters
----------
file_out : File or str
The file to log to.
protocol: str, optional
The log protocol. Can be 'json' or 'pickle'
Returns
-------
callback : callable
Callback function to do the logging.
"""
def _callback(_, inputs, results):
"""Callback implementation"""
if isinstance(file_out, str):
with open(file_out, "a") as f:
for inp, result in zip(inputs, results):
f.write(record.encode(inp, result, protocol) + "\n")
else:
for inp, result in zip(inputs, results):
file_out.write(record.encode(inp, result, protocol) + "\n")
return _callback
def save_tuner_state(prefix, save_every_sample=100):
"""Save the state of tuner
Parameters
----------
prefix : srt
prefix of the filename to store state
save_every_sample: int
save the state every x samples
Returns
-------
callback : function
Callback function to do the auto saving.
"""
def _callback(tuner, inputs, results):
for _, __ in zip(inputs, results):
try:
ct = len(tuner.visited)
except AttributeError:
ct = 0
if ct % save_every_sample == 0:
tuner.save_state(prefix + "_%d.state" % ct)
return _callback
def log_to_redis(host="localhost", port=6379, dbn=11):
"""Record the tuning record to a redis DB.
Parameters
----------
host: str, optional
Host address of redis db
port: int, optional
Port of redis db
dbn: int, optional
which redis db to use, default 11
"""
# import here so only depend on redis when necessary
import redis
red = redis.StrictRedis(host=host, port=port, db=dbn)
def _callback(_, inputs, results):
"""Callback implementation"""
for inp, result in zip(inputs, results):
red.set(inp, result)
return _callback
class Monitor(object):
"""A monitor to collect statistic during tuning"""
def __init__(self):
self.scores = []
self.timestamps = []
def __call__(self, tuner, inputs, results):
for inp, res in zip(inputs, results):
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
self.scores.append(flops)
else:
self.scores.append(0)
self.timestamps.append(res.timestamp)
def reset(self):
self.scores = []
self.timestamps = []
def trial_scores(self):
"""get scores (currently is flops) of all trials"""
return np.array(self.scores)
def trial_timestamps(self):
"""get wall clock time stamp of all trials"""
return np.array(self.timestamps)
# pylint: disable=consider-using-enumerate,invalid-name,abstract-method
"""Tuner with genetic algorithm"""
import numpy as np
from .tuner import Tuner
from .model_based_tuner import knob2point, point2knob
class GATuner(Tuner):
"""Tuner with genetic algorithm.
This tuner does not have a cost model so it always run measurement on real machines.
This tuner expands the :code:`ConfigEntity` as gene.
Parameters
----------
pop_size: int
number of genes in one generation
elite_num: int
number of elite to keep
mutation_prob: float
probability of mutation of a knob in a gene
"""
def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
super(GATuner, self).__init__(task)
# algorithm configurations
self.pop_size = pop_size
self.elite_num = elite_num
self.mutation_prob = mutation_prob
assert elite_num <= pop_size, "The number of elites must be less than population size"
# space info
self.space = task.config_space
self.dims = [len(x) for x in self.space.space_map.values()]
self.visited = set([])
# current generation
self.genes = []
self.scores = []
self.elites = []
self.elite_scores = []
self.trial_pt = 0
# random initialization
self.pop_size = min(self.pop_size, len(self.space))
for _ in range(self.pop_size):
tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
while knob2point(tmp_gene, self.dims) in self.visited:
tmp_gene = point2knob(np.random.randint(len(self.space)), self.dims)
self.genes.append(tmp_gene)
self.visited.add(knob2point(tmp_gene, self.dims))
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
gene = self.genes[self.trial_pt % self.pop_size]
self.trial_pt += 1
ret.append(self.space.get(knob2point(gene, self.dims)))
return ret
def update(self, inputs, results):
for inp, res in zip(inputs, results):
if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
self.scores.append(y)
else:
self.scores.append(0)
if len(self.scores) >= len(self.genes):
genes = self.genes + self.elites
scores = np.array(self.scores[:len(self.genes)] + self.elite_scores)
# reserve elite
self.elites, self.elite_scores = [], []
elite_indexes = np.argpartition(scores, -self.elite_num)[-self.elite_num:]
for ind in elite_indexes:
self.elites.append(genes[ind])
self.elite_scores.append(scores[ind])
# cross over
indices = np.arange(len(genes))
scores /= np.max(scores)
probs = scores / np.sum(scores)
tmp_genes = []
for _ in range(self.pop_size):
p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
p1, p2 = genes[p1], genes[p2]
point = np.random.randint(len(self.dims))
tmp_gene = p1[:point] + p2[point:]
tmp_genes.append(tmp_gene)
# mutation
next_genes = []
for tmp_gene in tmp_genes:
for j, dim in enumerate(self.dims):
if np.random.random() < self.mutation_prob:
tmp_gene[j] = np.random.randint(dim)
if len(self.visited) < len(self.space):
while knob2point(tmp_gene, self.dims) in self.visited:
j = np.random.randint(len(self.dims))
tmp_gene[j] = np.random.randint(self.dims[j])
next_genes.append(tmp_gene)
self.visited.add(knob2point(tmp_gene, self.dims))
else:
break
self.genes = next_genes
self.trial_pt = 0
self.scores = []
def has_next(self):
return len(self.visited) - (len(self.genes) - self.trial_pt) < len(self.space)
# pylint: disable=abstract-method
"""Grid search tuner and random tuner"""
import numpy as np
from .tuner import Tuner
class GridSearchTuner(Tuner):
"""Enumerate the search space in a grid search order"""
def __init__(self, task):
super(GridSearchTuner, self).__init__(task)
self.counter = 0
def next_batch(self, batch_size):
ret = []
for _ in range(batch_size):
if self.counter >= len(self.task.config_space):
continue
index = self.counter
ret.append(self.task.config_space.get(index))
self.counter = self.counter + 1
return ret
def has_next(self):
return self.counter < len(self.task.config_space)
def __getstate__(self):
return {"counter": self.counter}
def __setstate__(self, state):
self.counter = state['counter']
class RandomTuner(Tuner):
"""Enumerate the search space in a random order"""
def __init__(self, task):
super(RandomTuner, self).__init__(task)
self.visited = set()
def next_batch(self, batch_size):
ret = []
counter = 0
while counter < batch_size:
if len(self.visited) >= len(self.task.config_space):
break
index = np.random.randint(len(self.task.config_space))
while index in self.visited:
index = np.random.randint(len(self.task.config_space))
ret.append(self.task.config_space.get(index))
self.visited.add(index)
counter += 1
return ret
def has_next(self):
return len(self.visited) < len(self.task.config_space)
def __getstate__(self):
return {"visited": self.counter}
def __setstate__(self, state):
self.counter = state['visited']
# pylint: disable=invalid-name
"""Metrics for evaluating tuning process"""
import numpy as np
from ..util import get_rank
def max_curve(trial_scores):
""" f(n) = max([s[i] fo i < n])
Parameters
----------
trial_scores: Array of float
the score of i th trial
Returns
-------
curve: Array of float
function values
"""
ret = np.empty(len(trial_scores))
keep = -1e9
for i, score in enumerate(trial_scores):
keep = max(keep, score)
ret[i] = keep
return ret
def mean_curve(trial_scores):
""" f(n) = mean([s[i] fo i < n])
Parameters
----------
trial_scores: Array of float
the score of i th trial
Returns
-------
curve: Array of float
function values
"""
ret = np.empty(len(trial_scores))
keep = 0
for i, score in enumerate(trial_scores):
keep += score
ret[i] = keep / (i+1)
return ret
def recall_curve(trial_ranks, top=None):
"""
if top is None, f(n) = sum([I(rank[i] < n) for i < n]) / n
if top is K, f(n) = sum([I(rank[i] < K) for i < n]) / K
Parameters
----------
trial_ranks: Array of int
the rank of i th trial in labels
top: int or None
top-n recall
Returns
-------
curve: Array of float
function values
"""
if not isinstance(trial_ranks, np.ndarray):
trial_ranks = np.array(trial_ranks)
ret = np.zeros(len(trial_ranks))
if top is None:
for i in range(len(trial_ranks)):
ret[i] = np.sum(trial_ranks[:i] <= i) / (i+1)
else:
for i in range(len(trial_ranks)):
ret[i] = 1.0 * np.sum(trial_ranks[:i] < top) / top
return ret
def cover_curve(trial_ranks):
"""
f(n) = max k s.t. {1,2,...,k} is a subset of {ranks[i] for i < n}
Parameters
----------
trial_ranks: Array of int
the rank of i th trial in labels
Returns
-------
curve: Array of float
function values
"""
ret = np.empty(len(trial_ranks))
keep = -1
cover = set()
for i, rank in enumerate(trial_ranks):
cover.add(rank)
while keep+1 in cover:
keep += 1
ret[i] = keep + 1
return ret / len(trial_ranks)
def average_recall(preds, labels, N):
"""evaluate average recall-n for predictions and labels"""
trials = np.argsort(preds)[::-1]
ranks = get_rank(labels[trials])
curve = recall_curve(ranks)
return np.sum(curve[:N]) / N
# pylint: disable=no-else-return,invalid-name,consider-using-enumerate,abstract-method
"""Base class for model-based tuner
This type of tuner will fit a cost model and use some optimization methods to
find optimums points of cost model in space.
"""
import gc
import numpy as np
from .tuner import Tuner
class FeatureCache(object):
"""Feature cache manager for cache sharing between different cost models"""
def __init__(self):
self.feature_cache = {}
def get(self, key):
""" Get feature cache dictionary for a key
Parameters
----------
key: str
The key of a feature type
Returns
-------
fea_cache: dict
cache dictionary
"""
if key not in self.feature_cache:
self.feature_cache[key] = {}
return self.feature_cache[key]
def size(self, key):
"""" Get the size of a feature cache dictionary
Parameters
----------
key: str
The key of a feature type
Returns
-------
n: int
"""
return len(self.feature_cache.get(key, tuple()))
def clear(self, key):
"""Clear feature cache for a key
Parameters
----------
key: str
The key of a feature type
"""
del self.feature_cache[key]
self.feature_cache[key] = {}
gc.collect()
class CostModel(object):
"""Cost model to predict the speed of a config"""
def __init__(self):
pass
def fit(self, xs, ys, plan_size):
"""Fit to training data
Parameters
----------
xs: Array of int
indexes of configs in the config space
ys: Array of float
The speed (flop, float number operations per second)
plan_size: int
The plan size of tuner
"""
raise NotImplementedError()
def fit_log(self, records, plan_size):
"""Fit training data from log.
Parameters
----------
records: Array of Tuple(MeasureInput, MeasureResult)
The tuning records
plan_size: int
The plan size of tuner
"""
raise NotImplementedError()
def predict(self, xs, output_margin=False):
"""Predict the speed of configs
Parameters
----------
xs: Array of int
The indexes of configs to predict
output_margin: bool, optional
Whether output the untransformed margin.
When a model is used as base model, it should output untransformed margin
Returns
-------
preds: Array of float
The prediction
"""
raise NotImplementedError()
def load_basemodel(self, base_model):
"""Load base model for transfer learning
Parameters
----------
base_model: CostModel
base model
"""
raise NotImplementedError()
def clone_new(self):
"""Clone a new model with the same parameters.
This function will only copy hyperparameters of the tuner, not all the trained model
This is used for deriving a base model conveniently
Returns
-------
model: CostModel
A model with the same hyperparameter (argument)
"""
raise NotImplementedError()
class ModelOptimizer(object):
"""Optimizer used to find optimal points of cost model"""
def __init__(self):
pass
def find_maximums(self, model, num, exclusive):
"""Find maximum of a cost model
Note we use cost model to predict GFLOPS, so we should find the maximum
Parameters
----------
model: CostModel
Cost model
num: int
The number of returned maximum points
exclusive: set, optional
The excluded set of this optimizer. Return results won't include any
elements in this set.
"""
raise NotImplementedError()
class ModelBasedTuner(Tuner):
"""Base class for model based tuner
This type of tuner will fit a cost model and use an optimizer to
find the maximums of the cost model as next trials
Parameters
----------
task: autotvm.task.Task
The tuning task
cost_model: CostModel
The cost model that predicts the speed of a config (IR)
model_optimizer:
The optimizer to find local optimum points of cost model in tuning search space
plan_size: int
Tuner will re-fit model per `plan_size` new measure samples
diversity_filter_ratio: int or float, optional
If is not None, the tuner will first select
top-(plan_size * diversity_filter_ratio) candidates according to the cost model
and then pick plan_size of them according to the diversity metric.
"""
def __init__(self, task, cost_model, model_optimizer, plan_size, diversity_filter_ratio=None):
super(ModelBasedTuner, self).__init__(task)
# space
self.task = task
self.target = task.target
self.plan_size = plan_size
self.space = task.config_space
self.space_len = len(task.config_space)
self.dims = [len(x) for x in self.space.space_map.values()]
self.cost_model = cost_model
self.model_optimizer = model_optimizer
self.diversity_filter_ratio = diversity_filter_ratio
if self.diversity_filter_ratio:
assert self.diversity_filter_ratio >= 1, "Diversity filter ratio " \
"must be larger than one"
# trial plan
self.trials = []
self.trial_pt = 0
self.visited = set()
# observed samples
self.xs = []
self.ys = []
self.flops_max = 0.0
self.train_ct = 0
def next_batch(self, batch_size):
ret = []
counter = 0
while counter < batch_size:
if len(self.visited) >= len(self.space):
break
while self.trial_pt < len(self.trials):
index = self.trials[self.trial_pt]
if index not in self.visited:
break
self.trial_pt += 1
if self.trial_pt >= len(self.trials): # trial list is empty, choose randomly
index = np.random.randint(len(self.space))
while index in self.visited:
index = np.random.randint(len(self.space))
ret.append(self.space.get(index))
self.visited.add(index)
counter += 1
return ret
def update(self, inputs, results):
for inp, res in zip(inputs, results):
index = inp.config.index
if res.error_no == 0:
self.xs.append(index)
flops = inp.task.flop / np.mean(res.costs)
self.flops_max = max(self.flops_max, flops)
self.ys.append(flops)
else:
self.xs.append(index)
self.ys.append(0)
# if we have enough new training samples
if len(self.xs) >= self.plan_size * (self.train_ct + 1) \
and self.flops_max > 1e-6:
self.cost_model.fit(self.xs, self.ys, self.plan_size)
if self.diversity_filter_ratio:
candidate = self.model_optimizer.find_maximums(
self.cost_model, self.plan_size * self.diversity_filter_ratio, self.visited)
scores = self.cost_model.predict(candidate)
knobs = [point2knob(x, self.dims) for x in candidate]
pick_index = submodular_pick(0 * scores, knobs, self.plan_size, knob_weight=1)
maximums = np.array(candidate)[pick_index]
else:
maximums = self.model_optimizer.find_maximums(
self.cost_model, self.plan_size, self.visited)
self.trials = maximums
self.trial_pt = 0
self.train_ct += 1
def load_history(self, data_set):
base_model = self.cost_model.clone_new()
base_model.fit_log(data_set, self.plan_size)
if not self.trials:
# no plan yet, use base model to select initial trials
maximums = self.model_optimizer.find_maximums(base_model, self.visited)
self.trials = maximums
self.trial_pt = 0
self.cost_model.load_basemodel(base_model)
def has_next(self):
return len(self.visited) < len(self.space)
def point2knob(p, dims):
"""convert point form (single integer) to knob form (vector)"""
knob = []
for dim in dims:
knob.append(p % dim)
p //= dim
return knob
def knob2point(knob, dims):
"""convert knob form (vector) to point form (single integer)"""
p = 0
for j, k in enumerate(knob):
p += int(np.prod(dims[:j])) * k
return p
def submodular_pick(scores, knobs, n_pick, knob_weight=1.0):
"""Run greedy optimization to pick points with regard to both score and diversity.
DiversityScore = knob_weight * number of unique knobs in the selected set
Obj = sum(scores[i] for i in pick) + DiversityScore
Note that this objective function is a monotone submodular function.
Parameters
----------
scores: Array of float
score of every points
knobs: Array of Array of int
feature vector (tunable knobs) of every points
n_pick: int
number of points to pick
knob_weight: float
weight of an unique knob feature
"""
n = len(scores)
assert n == len(knobs)
n_knobs = len(knobs[0])
knobs_set = [set() for _ in range(n_knobs)]
ret = []
remain = list(range(len(scores)))
for _ in range(n_pick):
max_x = -1
max_delta = -1e9
for x in remain:
tmp_delta = scores[x]
for i in range(n_knobs):
if knobs[x][i] not in knobs_set[i]:
tmp_delta += knob_weight
if tmp_delta > max_delta:
max_delta, max_x = tmp_delta, x
ret.append(max_x)
remain.remove(max_x)
for i in range(n_knobs):
knobs_set[i].add(knobs[max_x][i])
return ret
# pylint: disable=consider-using-enumerate
"""
Cost model optimizer based on simulated annealing
"""
import heapq
import logging
import time
import numpy as np
from ..util import sample_ints
from .model_based_tuner import ModelOptimizer, knob2point, point2knob
class SimulatedAnnealingOptimizer(ModelOptimizer):
"""parallel simulated annealing optimization algorithm
Parameters
----------
task: Task
The tuning task
n_iter: int
The number of iterations of simulated annealing
temp: float or Array of float
If is a single float, then use a constant temperature.
If is an Array, then perform linear cooling from temp[0] to temp[1]
early_stop: int, optional
Stop iteration if the optimal set do not change in `early_stop` rounds
verbose: int, optional
Print log every `verbose` iterations
"""
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=30, verbose=50):
super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task
self.dims = [len(x) for x in self.task.config_space.space_map.values()]
self.n_iter = n_iter
self.temp = temp
self.persistent = persistent
self.parallel_size = parallel_size
self.early_stop = early_stop
self.verbose = verbose
self.points = None
def find_maximums(self, model, num, exclusive):
tic = time.time()
temp, n_iter, early_stop, verbose = self.temp, self.n_iter, self.early_stop, self.verbose
if self.persistent and self.points is not None:
points = self.points
else:
points = np.array(sample_ints(0, len(self.task.config_space), self.parallel_size))
scores = model.predict(points)
# build heap and insert initial points
heap_items = [(float('-inf'), -i) for i in range(num)]
heapq.heapify(heap_items)
in_heap = set(exclusive)
in_heap.update([-i for i in range(num)])
for s, p in zip(scores, points):
if s > heap_items[0][0] and p not in in_heap:
pop = heapq.heapreplace(heap_items, (s, p))
in_heap.remove(pop[1])
in_heap.add(p)
k = 0
k_last_modify = 0
if isinstance(temp, (tuple, list, np.ndarray)):
t = temp[0]
cool = 1.0 * (temp[0] - temp[1]) / (n_iter + 1)
else:
t = temp
cool = 0
while k < n_iter and k < k_last_modify + early_stop:
new_points = np.empty_like(points)
for i, p in enumerate(points):
new_points[i] = random_walk(p, self.dims)
new_scores = model.predict(new_points)
ac_prob = np.exp((new_scores - scores) / t)
ac_index = np.random.random(len(ac_prob)) < ac_prob
points[ac_index] = new_points[ac_index]
scores[ac_index] = new_scores[ac_index]
for s, p in zip(new_scores, new_points):
if s > heap_items[0][0] and p not in in_heap:
pop = heapq.heapreplace(heap_items, (s, p))
in_heap.remove(pop[1])
in_heap.add(p)
k_last_modify = k
k += 1
t -= cool
if verbose >= 1 and k % verbose == 0:
t_str = "%.2f" % t
logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
heap_items.sort(key=lambda item: -item[0])
if verbose:
logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.info("SA Maximums: %s", heap_items)
if self.persistent:
self.points = points
return [x[1] for x in heap_items]
def random_walk(p, dims):
"""random walk as local transition
Parameters
----------
p: int
index of the ConfigEntity
dims: Array of int
sizes of each dimension
Returns
-------
new_p: int
new neighborhood index
"""
# transform to knob form
old = point2knob(p, dims)
new = list(old)
# mutate
while new == old:
from_i = np.random.randint(len(old))
to_v = np.random.randint(dims[from_i])
new[from_i] = to_v
# transform to index form
return knob2point(new, dims)
# pylint: disable=unused-argument, no-self-use, invalid-name
"""Base class of tuner"""
import logging
import numpy as np
from ..measure import MeasureInput
from ..measure import create_measure_batch
class Tuner(object):
"""Base class for tuners
Parameters
----------
task: autotvm.task.Task
Tuning Task
"""
def __init__(self, task, **kwargs):
self.param = kwargs
self.recorder = None
self.task = task
# keep the current best
self.best_config = None
self.best_flops = 0
self.best_measure_pair = None
def has_next(self):
"""Whether has next untried config in the space
Returns
-------
has_next: bool
"""
raise NotImplementedError()
def next_batch(self, batch_size):
"""get the next batch of configs to be measure on real hardware
Parameters
----------
batch_size: int
The size of the batch
Returns
-------
a batch of configs
"""
raise NotImplementedError()
def update(self, inputs, results):
"""Update parameters of the tuner according to measurement results
Parameters
----------
inputs: Array of autotvm.measure.MeasureInput
The input for measurement
results: Array of autotvm.measure.MeasureResult
result for measurement
"""
pass
def tune(self, n_trial, measure_option, verbose=1, callbacks=()):
"""Begin tuning
Parameters
----------
n_trial: int
Maximum number of configs to try (measure on real hardware)
measure_option: dict
The options for how to measure generated code.
You should use the return value ot autotvm.measure_option for this argument.
verbose: int
0: silent mode, no output
1: print every measurement result
callbacks: List of callable
A list of callback functions. The signature of callback function is
(Tuner, List of MeasureInput, List of MeasureResult)
with no return value. These callback functions will be called on
every measurement pair. See autotvm/tuner/callback.py for some examples.
"""
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
i = 0
while i < n_trial:
if not self.has_next():
break
configs = self.next_batch(min(parallel_num, n_trial - i))
inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
results = measure_batch(inputs)
# print info
if verbose >= 1:
for k, (inp, res) in enumerate(zip(inputs, results)):
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
else:
flops = 0
if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
i += len(results)
self.update(inputs, results)
for callback in callbacks:
callback(self, inputs, results)
del measure_batch
def reset(self):
"""reset the status of tuner"""
self.best_config = None
self.best_flops = 0
self.best_measure_pair = None
def load_history(self, data_set):
"""load history data for transfer learning
Parameters
----------
data_set: Array of (MeasureInput, MeasureResult) pair
Previous tuning records
"""
raise NotImplementedError()
# pylint: disable=invalid-name
"""XGBoost as cost model"""
import multiprocessing
import logging
import time
import numpy as np
try:
import xgboost as xgb
except ImportError:
xgb = None
from .. import feature
from ..util import get_rank
from .metric import max_curve, recall_curve, cover_curve
from .model_based_tuner import CostModel, FeatureCache
class XGBoostCostModel(CostModel):
"""XGBoost as cost model
Parameters
----------
task: Task
The tuning task
feature_type: str, optional
If is 'itervar', use features extracted from IterVar (loop variable).
If is 'knob', use flatten ConfigEntity directly.
If is 'curve', use sampled curve feature (relation feature).
Note on choosing feature type:
For single task tuning, 'itervar' and 'knob' is good.
'itervar' is more accurate but 'knob' is much faster.
For cross-shape tuning (e.g. many convolutions with different shapes),
'itervar' and 'curve' has better transferability,
'knob' is faster.
For cross-device or cross-operator tuning, you can use 'curve' only.
loss_type: str
If is 'reg', use regression loss to train cost model.
The cost model predicts the normalized flops.
If is 'rank', use pairwise rank loss to train cost model.
The cost model predicts relative rank score.
num_threads: int, optional
The number of threads.
verbose: int, optional
If is not none, the cost model will print training log every `verbose` iterations.
"""
def __init__(self, task, feature_type, loss_type, num_threads=None, verbose=20):
super(XGBoostCostModel, self).__init__()
if xgb is None:
raise RuntimeError("XGBoost is required for XGBoostCostModel. "
"Please install its python package first. "
"Help: (https://xgboost.readthedocs.io/en/latest/) ")
self.task = task
self.target = task.target
self.space = task.config_space
self.fea_type = feature_type
self.loss_type = loss_type
self.num_threads = num_threads
self.verbose = verbose
if loss_type == 'reg':
self.xgb_params = {
'max_depth': 3,
'gamma': 0.0001,
'min_child_weight': 1,
'subsample': 1.0,
'eta': 0.3,
'lambda': 1.00,
'alpha': 0,
'objective': 'reg:linear',
}
elif loss_type == 'rank':
self.xgb_params = {
'max_depth': 3,
'gamma': 0.0001,
'min_child_weight': 1,
'subsample': 1.0,
'eta': 0.3,
'lambda': 1.00,
'alpha': 0,
'objective': 'rank:pairwise',
}
else:
raise RuntimeError("Invalid loss type: " + loss_type)
self.xgb_params['silent'] = 1
if num_threads:
self.xgb_params['nthread'] = num_threads
self.bst = None
if feature_type == 'itervar':
self.feature_extract_func = _extract_itervar_feature_index
elif feature_type == 'knob':
self.feature_extract_func = _extract_knob_feature_index
elif feature_type == 'curve':
self.feature_extract_func = _extract_curve_feature_index
else:
raise RuntimeError("Invalid feature type " + feature_type)
self.feature_cache = FeatureCache()
self.feature_extra_ct = 0
self.pool = None
self.base_model = None
self._reset_pool()
def _reset_pool(self):
# reset processing pool for feature extraction
if self.pool:
self.pool.terminate()
self.pool.join()
del self.pool
# use global variable to pass common arguments
global _extract_space, _extract_target, _extract_task
_extract_space = self.space
_extract_target = self.target
_extract_task = self.task
self.pool = multiprocessing.Pool(self.num_threads)
def fit(self, xs, ys, plan_size):
tic = time.time()
self._reset_pool()
x_train = self._get_feature(xs)
y_train = np.array(ys)
y_train /= np.max(y_train)
valid_index = y_train > 1e-6
index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index])
if self.base_model:
dtrain.set_base_margin(self.base_model.predict(xs, output_margin=True))
self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=8000,
callbacks=[custom_callback(
stopping_rounds=20,
metric='tr-a-recall@%d' % plan_size,
evals=[(dtrain, 'tr')],
maximize=True,
fevals=[
xgb_average_recalln_curve_score(plan_size),
],
verbose_eval=self.verbose)])
logging.info("train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
time.time() - tic, len(xs),
len(xs) - np.sum(valid_index),
self.feature_cache.size(self.fea_type))
def fit_log(self, records, plan_size):
tic = time.time()
self._reset_pool()
args = list(records)
if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob':
feature_extract_func = _extract_knob_feature_log
elif self.fea_type == 'curve':
feature_extract_func = _extract_curve_feature_log
else:
raise RuntimeError("Invalid feature type: " + self.fea_type)
res = self.pool.map(feature_extract_func, args)
xs, ys = zip(*res)
xs, ys = np.array(xs), np.array(ys)
x_train = xs
y_train = ys
y_train /= np.max(y_train)
index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index])
plan_size *= 2
self.bst = xgb.train(self.xgb_params, dtrain,
num_boost_round=200,
callbacks=[custom_callback(
stopping_rounds=100,
metric='tr-a-recall@%d' % plan_size,
evals=[(dtrain, 'tr')],
maximize=True,
fevals=[
xgb_average_recalln_curve_score(plan_size),
],
verbose_eval=self.verbose)])
logging.info("train: %.2f\tobs: %d", time.time() - tic, len(xs))
def predict(self, xs, output_margin=False):
feas = self._get_feature(xs)
dtest = xgb.DMatrix(feas)
if self.base_model:
dtest.set_base_margin(self.base_model.predict(xs, output_margin=True))
return self.bst.predict(dtest, output_margin=output_margin)
def load_basemodel(self, base_model):
self.base_model = base_model
def clone_new(self):
return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
self.num_threads, self.verbose)
def _get_feature(self, indexes):
"""get features for indexes, run extraction if we do not have cache for them"""
# free feature cache
if self.feature_cache.size(self.fea_type) >= 100000:
self.feature_cache.clear(self.fea_type)
fea_cache = self.feature_cache.get(self.fea_type)
indexes = np.array(indexes)
need_extract = [x for x in indexes if x not in fea_cache]
if need_extract:
feas = self.pool.map(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas):
fea_cache[i] = fea
ret = np.empty((len(indexes), fea_cache[indexes[0]].shape[-1]), dtype=np.float32)
for i, ii in enumerate(indexes):
ret[i, :] = fea_cache[ii]
return ret
_extract_space = None
_extract_target = None
_extract_task = None
def _extract_itervar_feature_index(index):
"""extract iteration var feature for an index in extract_space"""
config = _extract_space.get(index)
with _extract_target:
sch, args = _extract_task.instantiate(config)
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return fea
def _extract_itervar_feature_log(arg):
"""extract iteration var feature for log items"""
inp, res = arg
config = inp.config
with inp.target:
sch, args = inp.task.instantiate(config)
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
x = np.concatenate((fea, list(config.get_other_option().values())))
if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0
return x, y
def _extract_knob_feature_index(index):
"""extract knob feature for an index in extract_space"""
config = _extract_space.get(index)
return config.get_flatten_feature()
def _extract_knob_feature_log(arg):
"""extract knob feature for log items"""
inp, res = arg
config = inp.config
x = config.get_flatten_feature()
if res.error_no == 0:
with inp.target: # necessary, for calculating flops of this task
inp.task.instantiate(config)
y = inp.task.flop / np.mean(res.costs)
else:
y = 0
return x, y
def _extract_curve_feature_index(index):
"""extract sampled curve feature for an index in extract_space"""
config = _extract_space.get(index)
with _extract_target:
sch, args = _extract_task.instantiate(config)
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return np.array(fea)
def _extract_curve_feature_log(arg):
"""extract sampled curve feature for log items"""
inp, res = arg
config = inp.config
with inp.target:
sch, args = inp.task.instantiate(config)
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
x = np.concatenate((fea, list(config.get_other_option().values())))
if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0
return x, y
def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
save_file="xgb_checkpoint", save_every=None,
maximize=False, verbose_eval=True):
"""callback function for xgboost to support multiple custom evaluation functions"""
from xgboost.core import EarlyStopException
from xgboost.callback import _fmt_metric
from xgboost.training import aggcv
state = {}
metric_shortname = metric.split("-")[1]
def init(env):
"""internal function"""
bst = env.model
state['maximize_score'] = maximize
state['best_iteration'] = 0
if maximize:
state['best_score'] = float('-inf')
else:
state['best_score'] = float('inf')
if bst is not None:
if bst.attr('best_score') is not None:
state['best_score'] = float(bst.attr('best_score'))
state['best_iteration'] = int(bst.attr('best_iteration'))
state['best_msg'] = bst.attr('best_msg')
else:
bst.set_attr(best_iteration=str(state['best_iteration']))
bst.set_attr(best_score=str(state['best_score']))
else:
assert env.cvfolds is not None
def callback(env):
"""internal function"""
if not state:
init(env)
bst = env.model
i = env.iteration
cvfolds = env.cvfolds
res_dict = {}
##### evaluation #####
if cvfolds is not None:
for feval in fevals:
tmp = aggcv([f.eval(i, feval) for f in cvfolds])
for k, mean, std in tmp:
res_dict[k] = [mean, std]
else:
for feval in fevals:
bst_eval = bst.eval_set(evals, i, feval)
res = [x.split(':') for x in bst_eval.split()]
for kv in res[1:]:
res_dict[kv[0]] = [float(kv[1])]
eval_res = []
keys = list(res_dict.keys())
keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
for key in keys:
v = res_dict[key]
eval_res.append([key] + v)
##### print eval result #####
infos = ["XGB iter: %3d" % i]
for item in eval_res:
if 'null' in item[0]:
continue
infos.append("%s: %.6f" % (item[0], item[1]))
if not isinstance(verbose_eval, bool) and i % verbose_eval == 0:
logging.info("\t".join(infos))
if log_file:
with open(log_file, "a") as fout:
fout.write("\t".join(infos) + '\n')
##### save model #####
if save_every and i % save_every == 0:
filename = save_file + ".%05d.bst" % i
logging.info("save model to %s ...", filename)
bst.save_model(filename)
##### choose score and do early stopping #####
score = None
for item in eval_res:
if item[0] == metric:
score = item[1]
break
assert score is not None
best_score = state['best_score']
best_iteration = state['best_iteration']
maximize_score = state['maximize_score']
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
msg = '[%d] %s' % (
env.iteration,
'\t'.join([_fmt_metric(x) for x in eval_res]))
state['best_msg'] = msg
state['best_score'] = score
state['best_iteration'] = env.iteration
# save the property to attributes, so they will occur in checkpoint.
if env.model is not None:
env.model.set_attr(best_score=str(state['best_score']),
best_iteration=str(state['best_iteration']),
best_msg=state['best_msg'])
elif env.iteration - best_iteration >= stopping_rounds:
best_msg = state['best_msg']
if verbose_eval and env.rank == 0:
logging.info("Stopping. Best iteration: %s ", best_msg)
raise EarlyStopException(best_iteration)
return callback
# feval wrapper for xgboost
def xgb_max_curve_score(N):
"""evaluate max curve score for xgb"""
def feval(preds, labels):
labels = labels.get_label()
trials = np.argsort(preds)[::-1]
scores = labels[trials]
curve = max_curve(scores)
return "Smax@%d" % N, curve[N] / np.max(labels)
return feval
def xgb_recalln_curve_score(N):
"""evaluate recall-n curve score for xgb"""
def feval(preds, labels):
labels = labels.get_label()
trials = np.argsort(preds)[::-1]
ranks = get_rank(labels[trials])
curve = recall_curve(ranks)
return "recall@%d" % N, curve[N]
return feval
def xgb_average_recalln_curve_score(N):
"""evaluate average recall-n curve score for xgb"""
def feval(preds, labels):
labels = labels.get_label()
trials = np.argsort(preds)[::-1]
ranks = get_rank(labels[trials])
curve = recall_curve(ranks)
return "a-recall@%d" % N, np.sum(curve[:N]) / N
return feval
def xgb_recallk_curve_score(N, topk):
"""evaluate recall-k curve score for xgb"""
def feval(preds, labels):
labels = labels.get_label()
trials = np.argsort(preds)[::-1]
ranks = get_rank(labels[trials])
curve = recall_curve(ranks, topk)
return "recall@%d" % topk, curve[N]
return feval
def xgb_cover_curve_score(N):
"""evaluate cover curve score for xgb"""
def feval(preds, labels):
labels = labels.get_label()
trials = np.argsort(preds)[::-1]
ranks = get_rank(labels[trials])
curve = cover_curve(ranks)
return "cover@%d" % N, curve[N]
return feval
def xgb_null_score(_):
"""empty score function for xgb"""
def feval(__, ___):
return "null", 0
return feval
"""Tuner that uses xgboost as cost model"""
from .model_based_tuner import ModelBasedTuner, ModelOptimizer
from .xgboost_cost_model import XGBoostCostModel
from .sa_model_optimizer import SimulatedAnnealingOptimizer
class XGBTuner(ModelBasedTuner):
"""Tuner that uses xgboost as cost model
Parameters
----------
task: Task
The tuning task
plan_size: int
The size of a plan. After `plan_size` trials, the tuner will refit a new cost model
and do planing for the next `plan_size` trials.
feature_type: str, optional
If is 'itervar', use features extracted from IterVar (loop variable).
If is 'knob', use flatten ConfigEntity directly.
If is 'curve', use sampled curve feature (relation feature).
Note on choosing feature type:
For single task tuning, 'itervar' and 'knob' is good.
'itervar' is more accurate but 'knob' is much faster.
For cross-shape tuning (e.g. many convolutions with different shapes),
'itervar' and 'curve' has better transferability,
'knob' is faster.
For cross-device or cross-operator tuning, you can use 'curve' only.
loss_type: str
If is 'reg', use regression loss to train cost model.
The cost model predicts the normalized flops.
If is 'rank', use pairwise rank loss to train cost model.
The cost model predicts relative rank score.
num_threads: int, optional
The number of threads.
optimizer: str or ModelOptimizer, optional
If is 'sa', use a default simulated annealing optimizer.
Otherwise it should be a ModelOptimizer object.
diversity_filter_ratio: int or float, optional
If is not None, the tuner will first select
top-(plan_size * diversity_filter_ratio) candidates according to the cost model
and then pick batch_size of them according to the diversity metric.
"""
def __init__(self, task, plan_size=32,
feature_type='itervar', loss_type='rank', num_threads=None,
optimizer='sa', diversity_filter_ratio=None):
cost_model = XGBoostCostModel(task,
feature_type=feature_type,
loss_type=loss_type,
num_threads=num_threads)
if optimizer == 'sa':
optimizer = SimulatedAnnealingOptimizer(task)
else:
assert isinstance(optimizer, ModelOptimizer), "Optimizer must be " \
"a supported name string" \
"or a ModelOptimizer object."
super(XGBTuner, self).__init__(task, cost_model, optimizer,
plan_size, diversity_filter_ratio)
# pylint: disable=invalid-name
"""Utilities"""
import logging
import multiprocessing
import time
import numpy as np
from .. import expr, ir_pass
def get_rank(values):
"""get rank of items
Parameters
----------
values: Array
Returns
-------
ranks: Array of int
the rank of this item in the input (the largest value ranks first)
"""
tmp = np.argsort(-values)
ranks = np.empty_like(tmp)
ranks[tmp] = np.arange(len(tmp))
return ranks
def sample_ints(low, high, m):
"""
Sample m different integer numbers from [low, high) without replacement
This function is an alternative of `np.random.choice` when (high - low) > 2 ^ 32, in
which case numpy does not work.
Parameters
----------
low: int
low point of sample range
high: int
high point of sample range
m: int
The number of sampled int
Returns
-------
ints: an array of size m
"""
vis = set()
assert m <= high - low
while len(vis) < m:
new = np.random.randint(low, high)
while new in vis:
new = np.random.randint(low, high)
vis.add(new)
return list(vis)
def pool_map(func, args, batch_size, verbose=False, pool=None):
"""A wrapper of multiprocessing.pool.Pool.map to support small-batch mapping
for large argument list. This can reduce memory usage
Parameters
----------
func: Func(arg) -> np.ndarray
mapping function
args: List
list of arguments
batch_size: int
batch size in mapping
verbose: bool, optional
whether print progress
pool: multiprocessing.Pool, optional
pool objection
Returns
-------
converted numpy array
"""
ret = None
tic = time.time()
local_pool = pool or multiprocessing.Pool()
if verbose:
logging.info("mapping begin")
for i in range(0, len(args), batch_size):
if verbose:
logging.info("mapping %d/%d elapsed %.2f", i, len(args),
time.time() - tic)
tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
ret = tmp if ret is None else np.concatenate((ret, tmp))
if verbose:
logging.info("mapping done")
if not pool:
local_pool.close()
return ret
def get_func_name(func):
"""Get name of a function
Parameters
----------
func: Function
The function
Returns
-------
name: str
The name
"""
return func.func_name if hasattr(func, 'func_name') else func.__name__
def get_const_int(exp):
"""Verifies expr is integer and get the constant value.
Parameters
----------
exp : tvm.Expr or int
The input expression.
Returns
-------
out_value : int
The output.
"""
if isinstance(exp, int):
return exp
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
exp = ir_pass.Simplify(expr)
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
raise ValueError("Expect value to be constant int")
return exp.value
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
Parameters
----------
in_tuple : tuple of Expr
The input.
Returns
-------
out_tuple : tuple of int
The output.
"""
return tuple(get_const_int(x) for x in in_tuple)
...@@ -229,8 +229,14 @@ class TrackerSession(object): ...@@ -229,8 +229,14 @@ class TrackerSession(object):
res += "----------------------------\n" res += "----------------------------\n"
res += "key\tfree\tpending\n" res += "key\tfree\tpending\n"
res += "----------------------------\n" res += "----------------------------\n"
for k, v in data["queue_info"].items(): queue_info = data['queue_info']
res += "%s\t%d\t%g\n" % (k, v["free"], v["pending"]) keys = list(queue_info.keys())
if keys:
keys.sort()
max_key_len = max([len(k) for k in keys])
for k in keys:
res += ("%%-%d" % max_key_len + "s\t%d\t%g\n") % \
(k, queue_info[k]["free"], queue_info[k]["pending"])
res += "----------------------------\n" res += "----------------------------\n"
return res return res
......
/*!
* Copyright (c) 2018 by Contributors
* \file feature_visitor.cc
* \brief Base class for feature extractor.
* These features are used for machine learning cost model
*/
#include "feature_visitor.h"
namespace tvm {
namespace autotvm {
// for loop
void FeatureVisitor::Visit_(const For *op) {
const auto *extent = op->extent.as<IntImm>();
int64_t loop_extent = -1;
if (extent != nullptr)
loop_extent = extent->value;
AnnotationType ann = kSerial;
switch (op->for_type) {
case ForType ::Parallel:
ann = kParallel;
break;
case ForType::Unrolled:
ann = kUnrolled;
break;
case ForType::Vectorized:
ann = kVectorized;
break;
case ForType::Serial:
ann = kSerial;
break;
}
if (EnterItervar_(op->loop_var, loop_extent, ann)) {
IRVisitor::Visit_(op);
ExitItervar_();
}
}
// parallel axis, virtual thread
void FeatureVisitor::Visit_(const AttrStmt *op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImm>();
CHECK(extent);
std::string name = var.get()->name_hint;
AnnotationType ann = kParallel;
if (op->attr_key == attr::thread_extent) {
if (name == "blockIdx.x")
ann = kBlockX;
else if (name == "blockIdx.y")
ann = kBlockY;
else if (name == "blockIdx.z")
ann = kBlockZ;
else if (name == "threadIdx.x")
ann = kThreadX;
else if (name == "threadIdx.y")
ann = kThreadY;
else if (name == "threadIdx.z")
ann = kThreadZ;
else
LOG(FATAL) << "invalid thread itervar " + name;
} else {
ann = kVirtualThread;
}
if (EnterItervar_(var, extent->value, ann)) {
IRVisitor::Visit_(op);
ExitItervar_();
}
} else {
IRVisitor::Visit_(op);
}
}
// memory access
void FeatureVisitor::Visit_(const Load *op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
ExitMem_();
}
void FeatureVisitor::Visit_(const Store *op) {
EnterMem_(op->buffer_var, op->index);
IRVisitor::Visit_(op);
ExitMem_();
}
} // namespace autotvm
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file feature_visitor.h
* \brief Base class for feature extractor.
* These features are used for machine learning cost model
*/
#ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_
#define TVM_AUTOTVM_FEATURE_VISITOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <string>
namespace tvm {
namespace autotvm {
using namespace tvm::ir;
/*!
* \brief Type of for loop, used as one-hot encoding in features
*/
enum AnnotationType {
kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ,
kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread,
kNum,
};
/*!
* \brief A base class for feature extractor, used for processing
* for loop and memory access in the IR
*/
class FeatureVisitor : public IRVisitor {
public:
// for loop
void Visit_(const For *op);
void Visit_(const AttrStmt *op);
// memory access
void Visit_(const Load *op);
void Visit_(const Store *op);
protected:
/*!
* \brief Enter a for loop node
* \param var The expression to be printed.
* \param length The output stream
* \param ann_type The type for the for loop
* \return skip Whether skip this node
*/
virtual bool EnterItervar_(tvm::VarExpr var, int64_t length, AnnotationType ann_type) = 0;
/*! \brief Exit a for loop subtree */
virtual void ExitItervar_() = 0;
/*!
* \brief Enter a memory access node
* \param buffer_var The buffer to access.
* \param index Index expression
*/
virtual void EnterMem_(tvm::VarExpr buffer_var, tvm::Expr index) = 0;
/*! \brief Exit a memory access node */
virtual void ExitMem_() = 0;
};
} // namespace autotvm
} // namespace tvm
#endif // TVM_AUTOTVM_FEATURE_VISITOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file touch_extractor.cc
* \brief Extract feature of touch pattern of axes in lowered IR
*/
#include "touch_extractor.h"
#include <set>
#include <algorithm>
#include <cmath>
namespace tvm {
namespace autotvm {
int ParallelLevel(AnnotationType ann) {
switch (ann) {
case kBlockX: case kBlockY: case kBlockZ:
return 2;
case kThreadX: case kThreadY: case kThreadZ: case kParallel:
return 1;
default:
return 0;
}
}
// get touch pattern from index expression
class IndexParser: public IRVisitor {
public:
void Parse(Expr expr) {
pattern_map.clear();
this->Visit(expr);
}
void Visit_(const Variable *op) {
// TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern();
pattern_map[op].stride = next_stride_;
next_stride_ = 1;
}
}
void Visit_(const Mul *op) {
if (op->a.as<Variable>()) {
if (const auto stride = op->b.as<IntImm>()) {
next_stride_ = stride->value;
}
}
IRVisitor::Visit_(op);
}
std::unordered_map<const Variable*, TouchPattern> pattern_map;
private:
int64_t next_stride_ = 1;
};
// extract iter vars and their touch pattern from ir
bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type) {
// do not insert duplicated occurrences of virtual thread
if (ann_type == kVirtualThread && itervar_map.count(var) != 0) {
skip_stack_size_.push_back(itervar_stack_.size());
return true;
} else {
itervar_stack_.push_back(var);
topdown_product_ *= length;
if (itervar_map.count(var) != 0) {
// find two duplicated axes
// these happens when we create tvm.thread_axis("threadIdx.x") once and
// bind it twice. Here we treat them as two axes
// so we create a snapshot for the old one and freeze it
VarExpr old = VarExpr(var.get()->name_hint);
itervar_map.insert({old, itervar_map[var]});
itervar_map.erase(var);
}
itervar_map.insert({var, ItervarFeature(var, length,
static_cast<int>(itervar_stack_.size()),
ann_type,
topdown_product_,
static_cast<int>(itervar_counter_++))});
}
return true;
}
void TouchExtractor::ExitItervar_() {
if (!skip_stack_size_.empty() && skip_stack_size_.back() == itervar_stack_.size()) {
skip_stack_size_.pop_back();
return;
}
VarExpr var = itervar_stack_.back();
// update count and reuse ratio for upper iter vars (includes self)
for (auto kv : itervar_map[var].touch_feature) {
if (kv.second.stride != 0) { // multiply count
for (auto stack_var : itervar_stack_) {
auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
touch_pattern->second.count *= itervar_map[var].length;
}
} else { // multiply reuse ratio
for (auto stack_var : itervar_stack_) {
auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
touch_pattern->second.reuse *= itervar_map[var].length;
}
}
}
itervar_stack_.pop_back();
topdown_product_ /= itervar_map[var].length;
int64_t bottomup_product = -1;
for (auto kv : itervar_map[var].touch_feature) {
bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse);
}
itervar_map[var].bottomup_product = bottomup_product;
// push base to upper parallel axis
int para_level = ParallelLevel(itervar_map[var].ann);
// if is the separate line of parallel level, push the base to upper parallel level
if (!itervar_stack_.empty() &&
ParallelLevel(itervar_map[itervar_stack_.back()].ann) == para_level + 1) {
for (auto kv : itervar_map[var].touch_feature) {
for (auto stack_var : itervar_stack_) {
if (ParallelLevel(itervar_map[stack_var].ann) == para_level + 1) {
auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
touch_pattern->second.thread_reuse = -kv.second.reuse;
touch_pattern->second.thread_count = -kv.second.count;
// NOTE: use minus as a flag to denote it is a base,
// indicating it is not the final value
}
}
}
}
for (auto kv : itervar_map[var].touch_feature) {
if (kv.second.thread_count < 0) {
itervar_map[var].touch_feature[kv.first].thread_count =
kv.second.count / (-kv.second.thread_count);
itervar_map[var].touch_feature[kv.first].thread_reuse =
kv.second.reuse / (-kv.second.thread_reuse);
}
}
}
void TouchExtractor::EnterMem_(VarExpr buffer_var, Expr index) {
std::string name = buffer_var.get()->name_hint;
TouchedBuffer buf = name + "_" + std::to_string(buffer_counter_[name]++);
// extract touch pattern from index
IndexParser parser;
parser.Parse(index);
// push up mem access info
for (auto var : itervar_stack_) {
auto x = parser.pattern_map.find(var.get());
if (x != parser.pattern_map.end()) {
itervar_map[var].touch_feature[buf] = x->second;
} else {
itervar_map[var].touch_feature[buf] = TouchPattern();
}
}
}
void TouchExtractor::ExitMem_() {
}
/*!
* \brief Get axis-based feature for all axes
* \param stmt The statement to be extracted
* \param bool Whether take log for numerical feature
* \param ret_feature The buffer where the return value is stored
*
* \note The format of return value is
* ((
* ('_itervar_', var),
* ('_attr_', length, nest_level, topdown, bottomup, one_hot_annotation),
* ('_arith_', add_ct, mul_ct, div_ct),
* ('data_vec_0', stride, mod, count, reuse, thread_count, thread_reuse),
* ('conv_0', stride, mod, count, reuse, thread_count, thread_reuse),
* ),
* (
* ('_itervar_', var2),
* ('_attr_', length, nest_level, one_hot_annotation),
* ('_arith_', add_ct, mul_ct, div_ct),
* ('kernel_vec_0', stride, mod, count, reuse, thread_count, thread_reuse),
* ('conv_1', stride, mod, count, reuse, thread_count, thread_reuse),
* ))
*
* Itervars are sorted according to their first occurrence position in IR.
* Buffers touched by an itervar are sorted by their unique names.
*
* \note If you want to flatten these features as the input of your model,
* You can use the faster one GetItervarFeatureFlatten below.
*/
void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *ret_feature) {
// extract
TouchExtractor touch_analyzer;
touch_analyzer.Analyze(stmt);
// sort according to order
std::vector<VarExpr> vars;
for (auto kv : touch_analyzer.itervar_map) {
vars.push_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
});
// whether take log for numerical feature
std::function<double(int64_t)> trans;
if (take_log) {
trans = [](int64_t x) {
if (x < 0)
return -std::log(-x+1) / std::log(2);
x = x + 1;
return std::log(x) / std::log(2);
};
} else {
trans = [](int64_t x) {
return x;
};
}
// serialize for front end
for (auto var : vars) {
Array<Array<Expr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
Array<Expr> attr{std::string("_attr_"),
FloatImm::make(Float(32), trans(fea.length)),
IntImm::make(Int(32), fea.nest_level),
FloatImm::make(Float(32), trans(fea.topdown_product)),
FloatImm::make(Float(32), trans(fea.bottomup_product)),
};
// one hot annotation
for (int i = 0; i < kNum; i++) {
attr.push_back(i == fea.ann);
}
feature_row.push_back(attr);
// arithmetic
feature_row.push_back(Array<Expr>{std::string("_arith_"),
FloatImm::make(Float(32), trans(fea.add_ct)),
FloatImm::make(Float(32), trans(fea.mul_ct)),
FloatImm::make(Float(32), trans(fea.div_ct)),
});
// touch map
std::vector<TouchedBuffer> bufs;
for (auto kv : fea.touch_feature) {
bufs.push_back(kv.first);
}
std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(Array<Expr>{k,
FloatImm::make(Float(32), trans(v.stride)),
FloatImm::make(Float(32), trans(v.mod)),
FloatImm::make(Float(32), trans(v.count)),
FloatImm::make(Float(32), trans(v.reuse)),
FloatImm::make(Float(32), trans(v.thread_count)),
FloatImm::make(Float(32), trans(v.thread_reuse)),
});
}
ret_feature->push_back(feature_row);
}
}
/*!
* \brief Get axis-based feature for all axes and flatten them into a one-dimensional vector.
* \param stmt The statement to be extracted
* \param bool Whether take log for numerical feature
* \param ret_feature The buffer where the return value is stored
*
* \note See GetItervarFeature for more details about the return value.
* This is an optimized version of GetItervarFeature + Flatten. This runs much faster.
*/
void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float> *ret_feature) {
// extract touch feature
TouchExtractor touch_analyzer;
touch_analyzer.Analyze(stmt);
// sort according to order
std::vector<VarExpr> vars;
for (auto kv : touch_analyzer.itervar_map) {
vars.push_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
});
// whether take log for numerical feature
std::function<float(int64_t)> trans;
if (take_log) {
trans = [](int64_t x) {
if (x < 0)
return -std::log(-x+1) / std::log(2);
x = x + 1;
return std::log(x) / std::log(2);
};
} else {
trans = [](int64_t x) {
return x;
};
}
// serialize for front end
for (auto var : vars) {
ItervarFeature &fea = touch_analyzer.itervar_map[var];
ret_feature->push_back(trans(fea.length));
ret_feature->push_back(fea.nest_level);
ret_feature->push_back(trans(fea.topdown_product));
ret_feature->push_back(trans(fea.bottomup_product));
// one hot annotation
for (int i = 0; i < kNum; i++) {
ret_feature->push_back(i == fea.ann);
}
// arithmetic
ret_feature->push_back(trans(fea.add_ct));
ret_feature->push_back(trans(fea.mul_ct));
ret_feature->push_back(trans(fea.div_ct));
// touch map
std::vector<TouchedBuffer> bufs;
for (auto kv : fea.touch_feature) {
bufs.push_back(kv.first);
}
std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
ret_feature->push_back(trans(v.stride));
ret_feature->push_back(trans(v.mod));
ret_feature->push_back(trans(v.count));
ret_feature->push_back(trans(v.reuse));
ret_feature->push_back(trans(v.thread_count));
ret_feature->push_back(trans(v.thread_reuse));
}
}
}
/*!
* \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional vector.
* \param stmt The statement to be extracted
* \param sample_n The number of points used for sampling a curve (along one dimension)
* \param ret_feature The buffer where the return value is stored
*/
void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float> *ret_feature) {
// extract touch feature
TouchExtractor touch_ext;
touch_ext.Analyze(stmt);
// sort according to order
std::vector<VarExpr> vars;
for (auto kv : touch_ext.itervar_map) {
vars.push_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order;
});
int max_depth = 0;
std::map<TouchedBuffer, std::vector<double> > reuse_curve;
std::map<TouchedBuffer, std::vector<double> > count_curve;
std::map<TouchedBuffer, std::vector<double> > topdown_curve;
std::map<TouchedBuffer, std::vector<double> > bottomup_curve;
std::set<TouchedBuffer> innermost_buffers;
std::set<std::string> added;
// find maximum depth of loop nest
for (auto var : vars) {
ItervarFeature &fea = touch_ext.itervar_map[var];
max_depth = std::max(max_depth, fea.nest_level);
}
// mark inner most buffer
for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) {
auto var = *iter;
ItervarFeature &fea = touch_ext.itervar_map[var];
if (fea.nest_level == max_depth) {
for (auto kv : fea.touch_feature) {
// delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A')
std::string raw_name = kv.first.substr(0, kv.first.rfind("_"));
// delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A')
size_t pos = raw_name.find(".");
if (pos < kv.first.size())
raw_name = raw_name.substr(0, pos);
// If there are multiple innermost buffers that are derived from a same raw buffer
// We only record the last occurrence (note the `iter` is in reverse order)
// e.g. `A.local`, `A.shared` are derived from `A`, if they all occurred at the inner most
// level, we will only record the last occurrence,
if (added.find(raw_name) == added.end()) {
innermost_buffers.insert(kv.first);
added.insert(raw_name);
}
}
}
}
// pad the first point (zero) for all curves
for (auto buf : innermost_buffers) {
reuse_curve[buf].push_back(0);
count_curve[buf].push_back(0);
topdown_curve[buf].push_back(0);
bottomup_curve[buf].push_back(0);
}
// extract curves
for (auto var : vars) {
ItervarFeature &fea = touch_ext.itervar_map[var];
for (auto kv : fea.touch_feature) {
if (innermost_buffers.find(kv.first) != innermost_buffers.end()) {
reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2));
count_curve[kv.first].emplace_back(std::log(kv.second.count) / std::log(2));
topdown_curve[kv.first].emplace_back(std::log(fea.topdown_product) / std::log(2));
bottomup_curve[kv.first].emplace_back(std::log(fea.bottomup_product) / std::log(2));
}
}
}
// sample relation in the curve
auto sample_curve = [&](const std::vector<double> &x, const std::vector<double> &y,
double weight) {
for (int i = 0; i < sample_n; i++) {
double xx = i * weight;
for (int j = static_cast<int>(x.size()) - 1; j >= 0; j--) {
if (xx > x[j] - 1e-6) {
ret_feature->emplace_back(y[j]);
ret_feature->emplace_back(xx - x[j]);
break;
}
}
}
};
// serialize to frontend
for (auto k : innermost_buffers) {
std::vector<double> &count = count_curve[k];
std::vector<double> &reuse = reuse_curve[k];
std::vector<double> &top_down = topdown_curve[k];
std::sort(count.begin(), count.end());
std::sort(reuse.begin(), reuse.end());
std::sort(top_down.begin(), top_down.end());
sample_curve(count, reuse, 1);
sample_curve(reuse, count, 1);
sample_curve(count, top_down, 1);
sample_curve(top_down, count, 1);
}
}
// register API for front end
TVM_REGISTER_API("autotvm.feature.GetItervarFeature")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
bool take_log = args[1];
Array<Array<Array<Expr > > > ret_feature;
GetItervarFeature(stmt, take_log, &ret_feature);
*ret = ret_feature;
});
TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
bool take_log = args[1];
std::vector<float> ret_feature;
GetItervarFeatureFlatten(stmt, take_log, &ret_feature);
TVMByteArray arr;
arr.size = sizeof(float) * ret_feature.size();
arr.data = reinterpret_cast<char *>(ret_feature.data());
*ret = arr;
});
TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
bool take_log = args[1];
std::vector<float> ret_feature;
GetCurveSampleFeatureFlatten(stmt, take_log, &ret_feature);
TVMByteArray arr;
arr.size = sizeof(float) * ret_feature.size();
arr.data = reinterpret_cast<char *>(ret_feature.data());
*ret = arr;
});
} // namespace autotvm
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file touch_extractor.h
* \brief Extract feature of touch pattern of axes in lowered IR
*/
#ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <stack>
#include <vector>
#include <map>
#include <string>
#include <deque>
#include "feature_visitor.h"
namespace tvm {
namespace autotvm {
using TouchedBuffer = std::string;
// touch pattern buf[(stride * var) % mod) + other]
struct TouchPattern {
int64_t stride{0};
int64_t mod{-1}; // -1 for +inf
int64_t count{1};
int64_t reuse{1};
int64_t thread_count{0}; // count when move thread axis into innermost
int64_t thread_reuse{0}; // reuse ratio move thread axis into innermost
};
// all the feature of an iter var
struct ItervarFeature {
ItervarFeature(VarExpr var,
int64_t extent,
int nest,
AnnotationType ann_type,
int64_t topdown,
int counter)
: length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {}
ItervarFeature() {}
// Axis Attributes
int64_t length;
int nest_level;
AnnotationType ann; // one-hot axis type
int64_t topdown_product; // accumulative product of axis length, in top-down order
int64_t bottomup_product; // accumulative product of axis length, in bottom-up order
// bottomup_product = reuse * count for any touched buffer
int order; // used for soring axis
// Arithmetic feature
int add_ct{0};
int mul_ct{0};
int div_ct{0};
// Memory Touch Feature
std::unordered_map<TouchedBuffer, TouchPattern> touch_feature;
};
// extract iter vars and their touch pattern from ir
class TouchExtractor : public FeatureVisitor {
public:
void Analyze(Stmt stmt) {
this->Visit(stmt);
}
// arithmetic stats
void Visit_(const Add *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Sub *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Mul *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Div *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Mod *op) {
if (op->type.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
}
std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
private:
bool EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type);
void ExitItervar_();
void EnterMem_(VarExpr buffer_var, Expr index);
void ExitMem_();
int64_t topdown_product_{1};
std::map<std::string, size_t> buffer_counter_;
size_t itervar_counter_{0};
std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
std::deque<size_t> skip_stack_size_;
using IRVisitor::Visit_;
};
} // namespace autotvm
} // namespace tvm
#endif // TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
...@@ -73,7 +73,7 @@ Target CreateTarget(const std::string& target_name, ...@@ -73,7 +73,7 @@ Target CreateTarget(const std::string& target_name,
} else { } else {
t->device_type = kDLROCM; t->device_type = kDLROCM;
} }
t->keys_array.push_back(ir::StringImm::make("rocm")); t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
if (t->device_name == "intel_graphics") { if (t->device_name == "intel_graphics") {
...@@ -195,11 +195,7 @@ Target Target::create(const std::string& target_str) { ...@@ -195,11 +195,7 @@ Target Target::create(const std::string& target_str) {
options.push_back(item); options.push_back(item);
} }
if (device_name == "rasp") { return CreateTarget(target_name, options);
return target::rasp(options);
} else {
return CreateTarget(target_name, options);
}
} }
/*! \brief Entry to hold the Target context stack. */ /*! \brief Entry to hold the Target context stack. */
......
...@@ -18,13 +18,13 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -18,13 +18,13 @@ class GPUCodeVerifier : public IRVisitor {
bool Verify(tvm::Stmt stmt, bool Verify(tvm::Stmt stmt,
int64_t max_local_memory_per_block, int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_thread_per_block, int64_t max_threads_per_block,
int64_t max_thread_x, int64_t max_thread_x,
int64_t max_thread_y, int64_t max_thread_y,
int64_t max_thread_z) { int64_t max_thread_z) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block); max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block); max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_thread_per_block_ = static_cast<size_t>(max_thread_per_block); max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x); max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y); max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z); max_thread_z_ = static_cast<size_t>(max_thread_z);
...@@ -52,7 +52,7 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -52,7 +52,7 @@ class GPUCodeVerifier : public IRVisitor {
if (nest_level_ == 0) { if (nest_level_ == 0) {
// exit a kernel, check the validity // exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_thread_per_block_; valid_ &= thread_per_block_ <= max_threads_per_block_;
valid_ &= local_memory_per_block_ <= max_local_memory_per_block_; valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_; valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
...@@ -117,7 +117,7 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -117,7 +117,7 @@ class GPUCodeVerifier : public IRVisitor {
size_t max_local_memory_per_block_; size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_; size_t max_shared_memory_per_block_;
size_t max_thread_per_block_; size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_; size_t max_thread_x_, max_thread_y_, max_thread_z_;
bool valid_{true}; bool valid_{true};
...@@ -137,26 +137,34 @@ bool VerifyGPUCode(Stmt stmt, ...@@ -137,26 +137,34 @@ bool VerifyGPUCode(Stmt stmt,
Map<std::string, Expr> constraints) { Map<std::string, Expr> constraints) {
GPUCodeVerifier verifier; GPUCodeVerifier verifier;
auto get_int = [&constraints](std::string key, int64_t def) { int64_t max_local_memory_per_block = INT64_MAX;
auto iter = constraints.find(key); int64_t max_shared_memory_per_block = INT64_MAX;
if (iter != constraints.end()) { int64_t max_threads_per_block = INT64_MAX;
return ((*iter).second).as<IntImm>()->value; int64_t max_thread_x = INT64_MAX;
} else { int64_t max_thread_y = INT64_MAX;
return def; int64_t max_thread_z = INT64_MAX;
}
}; for (auto iter : constraints) {
if (iter.first == "max_local_memory_per_block")
int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX); max_local_memory_per_block = (iter.second).as<IntImm>()->value;
int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX); else if (iter.first == "max_shared_memory_per_block")
int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX); max_shared_memory_per_block = (iter.second).as<IntImm>()->value;
int64_t max_thread_x = get_int("max_thread_x", INT64_MAX); else if (iter.first == "max_threads_per_block")
int64_t max_thread_y = get_int("max_thread_y", INT64_MAX); max_threads_per_block = (iter.second).as<IntImm>()->value;
int64_t max_thread_z = get_int("max_thread_z", INT64_MAX); else if (iter.first == "max_thread_x")
max_thread_x = (iter.second).as<IntImm>()->value;
else if (iter.first == "max_thread_y")
max_thread_y = (iter.second).as<IntImm>()->value;
else if (iter.first == "max_thread_z")
max_thread_z = (iter.second).as<IntImm>()->value;
else
LOG(FATAL) << "Invalid check item: " << iter.first;
}
return verifier.Verify(stmt, return verifier.Verify(stmt,
max_local_memory_per_block, max_local_memory_per_block,
max_shared_memory_per_block, max_shared_memory_per_block,
max_thread_per_block, max_threads_per_block,
max_thread_x, max_thread_x,
max_thread_y, max_thread_y,
max_thread_z); max_thread_z);
......
"""
Test the tuner
"""
import logging
import time
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner
@autotvm.template
def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
"""An example template for testing"""
assert N == 1, "Only consider batch_size = 1 in this template"
data = tvm.placeholder((N, CI, H, W), name='data')
kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
rc = tvm.reduce_axis((0, CI), name='rc')
ry = tvm.reduce_axis((0, KH), name='ry')
rx = tvm.reduce_axis((0, KW), name='rx')
conv = tvm.compute(
(N, CO, H - KH + 1, W - KW + 1),
lambda nn, ff, yy, xx: tvm.sum(
data[nn, rc, yy + ry, xx + rx] * kernel[ff, rc, ry, rx],
axis=[rc, ry, rx]), tag="conv2d_nchw")
s = tvm.create_schedule([conv.op])
output = conv
OL = s.cache_write(conv, 'local')
# create cache stage
AA = s.cache_read(data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
AL = s.cache_read(AA, 'local', [OL])
WL = s.cache_read(WW, 'local', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
kernel_scope = n # this is the scope to attach global config inside this kernel
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# tile and bind reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
s[AL].compute_at(s[OL], rxm)
s[WL].compute_at(s[OL], rxm)
# cooperative fetching
for load in [AA, WW]:
n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# tune unroll
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
cfg.define_knob("unroll_explicit", [0, 1])
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
return s, [data, kernel, conv]
def get_sample_task(target=tvm.target.cuda(), target_host=None):
"""return a sample task for testing"""
task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3),
target=target, target_host=target_host)
return task, target
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
def measure_batch(inputs):
from tvm.autotvm import MeasureResult
results = []
for inp in inputs:
tic = time.time()
# do nothing
time.sleep(0.001)
results.append(MeasureResult([time.time() - tic], 0,
time.time() - tic, time.time()))
return results
measure_option = autotvm.measure_option(mode='custom',
custom_measure_batch=measure_batch)
logging.info("%s", task.config_space)
# new tuner and recorder
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
def test_tuning_with_measure():
def check(target, target_host):
ctx = tvm.context(target, 0)
if not ctx.exist:
logging.info("Skip test because %s is not available" % target)
return
# init task
task, target = get_sample_task(target, target_host)
logging.info("%s", task.config_space)
measure_option = autotvm.measure_option(mode='local',
timeout=4,
number=2)
tuner = RandomTuner(task)
tuner.tune(n_trial=10, measure_option=measure_option)
check("cuda", None)
check("opencl", None)
if __name__ == "__main__":
# only print log when invoked from main
logging.basicConfig(level=logging.INFO)
test_task_tuner_without_measurement()
test_tuning_with_measure()
"""Common utilities for testing autotvm"""
import time
import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult
@autotvm.template
def matmul(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
##### define space begin #####
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
##### define space end #####
# schedule according to config
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
def get_sample_task(n=128):
"""return a sample task for testing"""
target = tvm.target.create("llvm")
task = autotvm.task.create(matmul, args=(n, n, n, 'float32'), target=target)
return task, target
def get_sample_records(n):
"""get sample records for testing"""
tsk, target = get_sample_task()
inps, ress = [], []
for i in range(n):
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
ress.append(MeasureResult((i+1,), 0, i, time.time()))
return list(zip(inps, ress))
"""Test database"""
import copy
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records
def test_save_load():
logging.info("test basic db load/save ...")
records = get_sample_records(3)
inp1, res1 = records[0]
inp2, res2 = records[1]
inp3, _ = records[2]
_db = database.DummyDatabase()
_db.flush()
_db.save(inp1, res1)
_db.save(inp2, res2)
load1 = _db.load(inp1)
load2 = _db.load(inp2)
load3 = _db.load(inp3)
assert load1 == res1
assert load2 == res2
assert load3 is None
assert load1 != load2
TRIAL_LIMIT = 2
def test_db_filter():
logging.info("test db filter ...")
# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
batch_size = 2
measure_option = autotvm.measure_option(mode='local-nofork', timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
del measure_batch
db = database.DummyDatabase()
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option(mode='local-nofork', timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])
for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"
del measure_batch
def test_db_hash():
logging.info("test db hash check ...")
inp1, res1 = get_sample_records(1)[0]
inp2 = copy.deepcopy(inp1)
inp1.config.code_hash = 'cafecafe'
inp2.config.code_hash = 'dbffdbff'
res2l = list(tuple(res1))
# set timestamp
res2l[-1] = -1
res2 = MeasureResult(*res2l)
_db = database.DummyDatabase()
_db.flush()
_db.save(inp1, res1, extend=True)
_db.save(inp2, res2, extend=True)
load1 = _db.load(inp1)
load2 = _db.load(inp2)
assert load1 != load2
assert load1.timestamp != -1
assert load2.timestamp == -1
def test_db_latest_all():
logging.info("test db load w/ multiple results ...")
inp1, res1 = get_sample_records(1)[0]
lis1 = list(tuple(res1))
lis2 = list(tuple(res1))
lis3 = list(tuple(res1))
# set timestamp
lis1[-1] = 0.0
lis2[-1] = 1.1
lis3[-1] = 9999.9999
res1 = MeasureResult(*lis1)
res2 = MeasureResult(*lis2)
res3 = MeasureResult(*lis3)
_db = database.DummyDatabase()
_db.flush()
_db.save(inp1, res1, extend=True)
load1 = _db.load(inp1)
assert load1.timestamp == 0.0
_db.save(inp1, res2, extend=True)
load2 = _db.load(inp1)
assert load2.timestamp == 1.1
_db.save(inp1, res3, extend=True)
load3 = _db.load(inp1)
assert load3.timestamp == 9999.9999
load4 = _db.load(inp1, get_all=True)
assert encode(inp1, load4[0]) == encode(inp1, res1)
assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option(mode='local-nofork',
timeout=2,
replay_db=_db, save_to_replay_db=True)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)
for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
del measure_batch
def test_check_hashmismatch():
logging.info("test hash mismatch check")
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option(mode='local-nofork')
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))
try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass
del measure_batch
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_save_load()
test_db_filter()
test_db_hash()
test_db_latest_all()
test_db_save_replay()
test_check_hashmismatch()
"""Test dispatcher.
The dispatcher can choose which template to use according
to the parameters of workload"""
from collections import namedtuple
from tvm.autotvm.task import dispatcher, DispatchContext
SimpleWorkload = namedtuple("SimpleWorkload", ["key"])
SimpleConfig = namedtuple("SimpleConfig", ["template_key"])
def test_dispatch():
@dispatcher
def my_dispatcher(a, b):
return SimpleWorkload(key=a + b)
@my_dispatcher.register("spatial_pack")
def _sp_pack_add(cfg, a, b):
return b + 100
@my_dispatcher.register("im2col")
def _im2col_add(cfg, a, b):
return a + 1
class SimpleDispatcher(DispatchContext):
def query(self, target, workload):
tkey = "spatial_pack" if workload.key > 2 else "im2col"
return SimpleConfig(tkey)
with SimpleDispatcher():
# im2col
assert my_dispatcher(1, 0) == 2
# spack
assert my_dispatcher(1, 100) == 200
if __name__ == "__main__":
test_dispatch()
"""Test local executor"""
import time
from tvm.autotvm.measure import LocalExecutor, executor
def slow(n):
r = 0
for i in range(0, n+1):
r += i
return r
def fast(n):
return n*(n+1)//2
def test_local_measure_async():
ex = LocalExecutor()
f1 = ex.submit(slow, 9999999)
f2 = ex.submit(fast, 9999999)
t1 = 0
t2 = 0
while True:
if t1 == 0 and f1.done():
t1 = time.time()
if t2 == 0 and f2.done():
t2 = time.time()
if t1 != 0 and t2 != 0:
break
assert t2 < t1, "Expected fast async job to finish first!"
assert f1.get() == f2.get()
def timeout_job(n):
time.sleep(n * 1.5)
def test_timeout():
timeout = 0.5
ex = LocalExecutor(timeout=timeout)
f1 = ex.submit(timeout_job, timeout)
while not f1.done():
pass
res = f1.get()
assert isinstance(res, executor.TimeoutError)
if __name__ == "__main__":
test_local_measure_async()
test_timeout()
"""Test feature extraction"""
import numpy as np
import tvm
from tvm.autotvm import feature
def test_iter_feature_gemm():
N = 128
k = tvm.reduce_axis((0, N), 'k')
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
C = tvm.compute(
A.shape,
lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
name='C')
s = tvm.create_schedule(C.op)
feas = feature.get_itervar_feature(s, [A, B, C], take_log=False)
expected = [
{
'_attr_': [128, 1, 128, 2097152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
'A_0': [128, -1, 16384, 128, 0, 0], 'B_0': [0, -1, 16384, 128, 0, 0],
'C_0': [128, -1, 16384, 128, 0, 0], 'C_1': [128, -1, 16384, 128, 0, 0],
},
{
'_attr_': [128, 2, 16384, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
'A_0': [0, -1, 128, 128, 0, 0], 'B_0': [1, -1, 16384, 1, 0, 0],
'C_0': [1, -1, 128, 128, 0, 0], 'C_1': [1, -1, 128, 128, 0, 0],
},
{
'_attr_': [128, 3, 2097152, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
'A_0': [1, -1, 128, 1, 0, 0], 'B_0': [128, -1, 128, 1, 0, 0],
'C_1': [0, -1, 1, 128, 0, 0], 'C_2': [0, -1, 1, 128, 0, 0],
}
]
for ans, row in zip(expected, feas):
for pair in row:
if pair[0] not in ans:
continue
assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:])
def test_feature_shape():
"""test the dimensions of flatten feature are the same"""
N = 1024
n_sample = 100
def get_gemm_feature(target):
k = tvm.reduce_axis((0, N), 'k')
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
C = tvm.compute(A.shape, lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
name='C')
s = tvm.create_schedule(C.op)
y, x = s[C].op.axis
axes = list(s[C].tile(y, x, 8, 8)) + [k]
perm = np.random.permutation(5)
axes = [axes[x] for x in perm]
s[C].reorder(*axes)
if "gpu" in target.keys:
pick = []
# filter out reduction axis
for i in range(len(perm)):
if perm[i] != 4:
pick.append(axes[i])
s[C].bind(pick[0], tvm.thread_axis("blockIdx.x"))
s[C].bind(pick[1], tvm.thread_axis("vthread"))
s[C].bind(pick[2], tvm.thread_axis("threadIdx.y"))
with target:
feas = feature.get_itervar_feature(s, [A, B, C])
feas = feature.flatten_itervar_feature(feas)
return feas
targets = [
tvm.target.cuda(),
tvm.target.mali(),
tvm.target.rasp(),
]
for target in targets:
dim = len(get_gemm_feature(target))
for i in range(n_sample):
assert dim == len(get_gemm_feature(target)), "dimensions of feature do not match" \
" for different configurations"
if __name__ == "__main__":
test_iter_feature_gemm()
test_feature_shape()
"""Test flop calculation"""
import tvm
import numpy as np
from tvm.autotvm.task.task import compute_flop
def test_conv():
for i in range(5):
N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
D = tvm.placeholder((N, CI, H, W))
K = tvm.placeholder((CO, CI, KH, KW))
KH = min(H, KH)
KW = min(W, KW)
ci = tvm.reduce_axis((0, CI))
kh = tvm.reduce_axis((0, KH))
kw = tvm.reduce_axis((0, KW))
OH = (H - KH) + 1
OW = (W - KW) + 1
C = tvm.compute((N, CO, OH, OW), lambda n, co, h, w:
tvm.sum(D[n][ci][h][w] * K[co][ci][h][w], axis=[ci, kh, kw]))
s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * CO * OH * OW * CI * KH * KW
def test_pack_gemm():
for i in range(5):
N, L, M = [np.random.randint(10, 128) * 4 for _ in range(3)]
A = tvm.placeholder((N, L))
B = tvm.placeholder((M, L))
k = tvm.reduce_axis((0, L))
bn = 4
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii] * B_pack[j, k, jj], axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])
s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M
def test_outer_dot():
for i in range(5):
N, M = [np.random.randint(10, 128) * 4 for _ in range(2)]
A = tvm.placeholder((N,))
B = tvm.placeholder((M,))
C = tvm.compute((N, M), lambda i, j: A[i] * B[j])
s = tvm.create_schedule([C.op])
assert compute_flop(s) == N * M
def test_move():
"""No float number operation in simple move. So the estimator should raise an error """
N = 1024
A = tvm.placeholder((N,))
C = tvm.compute((N,), lambda i: A[i])
s = tvm.create_schedule([C.op])
try:
compute_flop(s)
assert False
except RuntimeError:
pass
if __name__ == '__main__':
test_conv()
test_pack_gemm()
test_outer_dot()
test_move()
"""test the correctness of dump and load of data log"""
import time
import tvm
from tvm.contrib import util
from tvm import autotvm
from tvm.autotvm.measure import MeasureInput, MeasureResult, MeasureErrorNo
from tvm.autotvm.record import encode, decode, ApplyHistoryBest, measure_str_key
from test_autotvm_common import get_sample_task
def test_load_dump():
task, target = get_sample_task()
inp = MeasureInput(target, task, task.config_space.get(0))
result = MeasureResult((2.0, 2.23, 0.23, 0.123, 0.234, 0.123), MeasureErrorNo.NO_ERROR,
2.3, time.time())
for protocol in ['json', 'pickle']:
row = encode(inp, result, protocol=protocol)
inp_2, result_2 = decode(row, protocol=protocol)
assert measure_str_key(inp) == measure_str_key(inp_2), \
"%s vs %s" % (measure_str_key(inp), measure_str_key(inp_2))
assert result.costs == result_2.costs
assert result.error_no == result_2.error_no
assert result.timestamp == result_2.timestamp
def test_file_io():
temp = util.tempdir()
file_path = temp.relpath("temp.log")
tsk, target = get_sample_task()
inputs = [MeasureInput(target, tsk, tsk.config_space.get(i)) for i in range(0, 10)]
results = [MeasureResult((i, ), 0, 0, 0) for i in range(0, 10)]
with open(file_path, "w") as fo:
cb = autotvm.callback.log_to_file(fo)
cb(None, inputs, results)
ref = zip(inputs, results)
for x, y in zip(ref, autotvm.record.load_from_file(file_path)):
assert x[1] == y[1]
def test_apply_history_best():
tsk, target = get_sample_task()
records = [
(MeasureInput(target, tsk, tsk.config_space.get(0)), MeasureResult((0.1,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(1)), MeasureResult((0.3,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(2)), MeasureResult((0.01,), 0, 2.3, 0)),
(MeasureInput(target, tsk, tsk.config_space.get(4)), MeasureResult((0.4,), 0, 2.3, 0))
]
hist_best = ApplyHistoryBest(records)
x = hist_best.query(target, tsk.workload)
assert str(x) == str(tsk.config_space.get(2))
if __name__ == "__main__":
test_load_dump()
test_apply_history_best()
test_file_io()
"""Test space definition primitives"""
import tvm
from tvm.autotvm.task.space import ConfigSpace
def gemm_func(cfg, N):
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
k = tvm.reduce_axis((0, N), name='k')
C = tvm.compute((N, N), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=[k]), name='C')
s = tvm.create_schedule([C.op])
y, x = s[C].op.axis
cfg.define_split('tile_y', cfg.axis(y), num_outputs=2)
cfg.define_split('tile_x', cfg.axis(x), num_outputs=2)
return s, [A, B, C]
def test_split():
cfg = ConfigSpace()
gemm_func(cfg, 128)
assert len(cfg) == 64
assert len(cfg.space_map['tile_y']) == 8
if __name__ == '__main__':
test_split()
...@@ -31,14 +31,14 @@ def test_shared_memory(): ...@@ -31,14 +31,14 @@ def test_shared_memory():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=4 * M - 1, max_shared_memory_per_block=4 * M - 1,
max_thread_per_block=M))]}): max_threads_per_block=M))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=4 * M, max_shared_memory_per_block=4 * M,
max_thread_per_block=M))]}): max_threads_per_block=M))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
...@@ -66,14 +66,14 @@ def test_local_memory(): ...@@ -66,14 +66,14 @@ def test_local_memory():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_local_memory_per_block=4 * M - 1, max_local_memory_per_block=4 * M - 1,
max_thread_per_block=1))]}): max_threads_per_block=1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_local_memory_per_block=4 * M, max_local_memory_per_block=4 * M,
max_thread_per_block=1))]}): max_threads_per_block=1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
...@@ -101,21 +101,21 @@ def test_num_thread(): ...@@ -101,21 +101,21 @@ def test_num_thread():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N - 1))]}): max_threads_per_block=N - 1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N))]}): max_threads_per_block=N))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N, max_threads_per_block=N,
max_thread_y=M-1))]}): max_thread_y=M-1))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert not valid[0] assert not valid[0]
...@@ -123,7 +123,7 @@ def test_num_thread(): ...@@ -123,7 +123,7 @@ def test_num_thread():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N, max_threads_per_block=N,
max_thread_y=M))]}): max_thread_y=M))]}):
tvm.build(s, [A, B], target) tvm.build(s, [A, B], target)
assert valid[0] assert valid[0]
...@@ -151,14 +151,14 @@ def test_multiple_kernels(): ...@@ -151,14 +151,14 @@ def test_multiple_kernels():
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N - 1))]}): max_threads_per_block=N - 1))]}):
tvm.build(s, [A, C], target) tvm.build(s, [A, C], target)
assert not valid[0] assert not valid[0]
with tvm.build_config(**{"add_lower_pass": [ with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, (2, get_verify_pass(valid,
max_shared_memory_per_block=0, max_shared_memory_per_block=0,
max_thread_per_block=N))]}): max_threads_per_block=N))]}):
tvm.build(s, [A, C], target) tvm.build(s, [A, C], target)
assert valid[0] assert valid[0]
......
...@@ -8,7 +8,7 @@ make doc ...@@ -8,7 +8,7 @@ make doc
jsdoc web/tvm_runtime.js web/README.md || exit -1 jsdoc web/tvm_runtime.js web/README.md || exit -1
mv out docs/_build/html/jsdoc || exit -1 mv out docs/_build/html/jsdoc || exit -1
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
cd docs cd docs
PYTHONPATH=`pwd`/../python make html || exit -1 PYTHONPATH=`pwd`/../python make html || exit -1
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
export PYTHONPATH=python:apps/extension/python export PYTHONPATH=python:apps/extension/python
export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH} export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH}
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
# Test TVM # Test TVM
make cython || exit -1 make cython || exit -1
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
export PYTHONPATH=python:topi/python export PYTHONPATH=python:topi/python
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1 TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1
......
Auto tuning
-------------
"""
How to get high performance convolution kernel on NVIDIA GPU by auto-tuning
=========================================================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
This is an advanced tutorial for writing high performance tunable template for
NVIDIA GPU. By running auto-tuner on this template, we can outperform the
vendor provided library CuDNN in many cases.
"""
import logging
import sys
import tvm
import topi
from tvm import autotvm
######################################################################
# Step 1: Define the search space
# ---------------------------------
# There are plenty of useful schedule primitives in tvm. You can also find
# some tutorials that describe them in more details, such as
# (1). :doc:``Optimizing Conv2d on NVIDIA GPU <../optimize/opt_conv_cuda>`
# (2). `Optimizing DepthwiseConv on NVIDIA GPU <https://tvm.ai/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example.html>`_
#
# However, their implementations are manually tuned for some special input
# shapes. In this section, we build a large enough space to cover
# the techniques used in these tutorials. Then we rely on the efficient auto-tuner
# to search through this space and pick some good configurations.
#
# If you are familiar with writing cuda schedule, you can find the following
# template is very general. Actually this template can be easily modified
# to tune other operators such as depthwise convolution and gemm.
# In order to fully understand this template, you should be familiar with
# the schedule primitives and auto tuning API. You can refer to the above
# tutorials and :doc:`autotvm tutorial <tune_simple_template>`
#
# It is worth noting that the search space for a conv2d operator
# can be very large (at the level of 10^9 for some input shapes)
#
@autotvm.template
def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
assert N == 1, "Only consider batch_size = 1 in this template"
data = tvm.placeholder((N, CI, H, W), name='data')
kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, 'float32')
s = tvm.create_schedule([conv.op])
# inline padding
pad_data = s[conv].op.input_tensors[0]
s[pad_data].compute_inline()
data, raw_data = pad_data, data
output = conv
OL = s.cache_write(conv, 'local')
# create cache stage
AA = s.cache_read(data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
AL = s.cache_read(AA, 'local', [OL])
WL = s.cache_read(WW, 'local', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
kernel_scope = n # this is the scope to attach global config inside this kernel
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# tile and bind reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)
s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
s[AL].compute_at(s[OL], rxm)
s[WL].compute_at(s[OL], rxm)
# cooperative fetching
for load in [AA, WW]:
n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# tune unroll
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
cfg.define_knob("unroll_explicit", [0, 1])
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
return s, [raw_data, kernel, conv]
######################################################################
# Step 2: Search through the space
# ---------------------------------
# We pick the last layer on resnet as test case.
# Since our space is very large, :code:`XGBoostTuner` is most suitable
# for our case. Here we only do 20 trials for demonstration.
# In practice, making 1000 trials usually can find some good kernels
# for this template
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# the last layer in resnet
task = autotvm.task.create(conv2d_no_batching,
args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
target='cuda')
print(task.config_space)
# use local gpu, measure 5 times for every config to reduce variance
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option(mode='local',
number=10,
parallel_num=8,
timeout=20)
# begin tuning, log records to file `cache.tsv`
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
# get best config from cache file
dispatch_context = autotvm.apply_history_best("cache.tsv")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
"""
Writing tunable template and Using auto-tuner
=============================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_
This is an introduction tutorial to the auto-tuning module in tvm.
There are two steps in auto-tuning.
The first step is defining a search space.
The second step is running a search algorithm to explore through this space.
In this tutorial, you can learn how to perform these two steps in tvm.
The whole workflow is illustrated by a matrix multiplication example.
"""
import logging
import sys
import numpy as np
import tvm
# the module is called `autotvm`
from tvm import autotvm
######################################################################
# Step 1: Define the search space
# ---------------------------------
# In this section, we will rewrite a deterministic tvm schedule code to a
# tunable schedule template. You can regard the process of search space definition
# as the parametrization of our exiting schedule code.
#
# To begin with, here is how we implement a blocked matrix multiplication in tvm
# Matmul V0: Constant tiling factor
def matmul_v0(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
yo, yi = s[C].split(y, 8)
xo, xi = s[C].split(x, 8)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
#####################################################################
# Parametrize the schedule
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# In the previous schedule code, we use a constant "8" as tiling factor.
# However, it might not be the best one because the best tiling factor depends
# on real hardware environment and input shape.
#
# If you want the schedule code to be portable across a wider range of input shapes
# and target hardware, it is better to define a set of candidate values and
# pick the best one according to the measurement results on target hardware.
#
# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
# Matmul V1: List candidate values
@autotvm.template # 1. use a decorator
def matmul_v1(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
# 2. get the config object
cfg = autotvm.get_config()
# 3. define search space
cfg.define_knob("tile_y", [1, 2, 4, 8, 16])
cfg.define_knob("tile_x", [1, 2, 4, 8, 16])
# 4. schedule according to config
yo, yi = s[C].split(y, cfg['tile_y'].val)
xo, xi = s[C].split(x, cfg['tile_x'].val)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
###############################################################################
# Here we make four modifications to the previous schedule code and get
# a tunable "template". We can explain the modifications one by one.
#
# 1. Use a decorator to mark this function as a simple template
# 2. Get a config object:
# You can regard this :code:`cfg` as an argument of this function but
# we obtain it in a different way. With this argument, this function is no longer
# a deterministic schedule code. Instead, we can pass different configurations to
# this function and get different schedules, so this function is a "template".
#
# To make the template function more compact, we do two things in a single function.
# (1) define a search space and (2) schedule according to an entity in this space.
# To achieve this, we make :code:`cfg` be either
# a :any:`ConfigSpace` or a :any:`ConfigEntity` object.
#
# When it is a :any:`ConfigSpace`, it will collect all tunable knobs in this function and
# build the search space.
# When it is a :any:`ConfigEntity`, it will ignore all space definition API
# (namely, :code:`cfg.define_XXXXX(...)`). Instead, it stores deterministic values for
# all tunable knobs, and we schedule according to these values.
#
# During auto-tuning, we will first call this template with a :any:`ConfigSpace`
# object to build the search space. Then we call this template with different :any:`ConfigEntity`
# in the built space to get different schedules. Finally we will measure the code generated by
# different schedules and pick the best one.
#
# 3. Define two tunable knobs. The first one is :code:`tile_y` with
# 5 possible values. The second one is :code:`tile_x` with a same
# list of possible values. These two knobs are independent, so they
# span a search space with size = 5x5 = 25
# 4. Schedule according to the deterministic values in :code:`cfg`
#
#####################################################################
# Use better space definition API
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# In the previous template, we manually list all possible values for a knob.
# This is the lowest level API to define the space.
# However, we also provide another set of API to make the space definition
# easier and smarter. It is recommended to use this set of high level API.
#
# In the flowing example, we use :any:`ConfigSpace.define_split` to define a split
# knob. It will enumerate all the possible ways to split an axis and construct
# the space.
#
# We also have :any:`ConfigSpace.define_reorder` for reorder knob and
# :any:`ConfigSpace.define_annotate` for annotation like unroll, vectorization,
# thread binding.
# When the high level API cannot meet your requirement, you can always fall
# back to use low level API.
@autotvm.template
def matmul(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
##### define space begin #####
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
##### define space end #####
# schedule according to config
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
######################################################################
# .. note:: More Explanation on :code:`cfg.defile_split`
#
# In this template, :code:`cfg.define_split("tile_y", y, num_outputs=2)` will enumerate
# all possible combinations that can split axis y into two axes with factors of the length of y.
# For example, if the length of y is 32 and we want to split it into two axes
# using factors of 32, then there are 6 possible values for
# (length of outer axis, length of inner axis) pair, namely
# (32, 1), (16, 2), (8, 4), (4, 8), (2, 16) or (1, 32).
# They are just the 6 possible values of `tile_y`.
#
# During schedule, :code:`cfg["tile_y"]` is a :code:`SplitEntity` object.
# We stores the lengths of outer axes and inner axes in :code:`cfg['tile_y'].size`
# (a tuple with two elements).
# In this template, we apply it by using :code:`yo, yi = cfg['tile_y'].apply(s, C, y)`.
# Actually, this is equivalent to
# :code:`yo, yi = s[C].split(y, cfg["tile_y"].size[1])`
# or :code:`yo, yi = s[C].split(y, nparts=cfg['tile_y"].size[0])`
#
# The advantage of using cfg.apply API is that it makes multi-level split
# (when num_outputs >= 3) easier.
######################################################################
# Step 2: Search through the space
# ---------------------------------
# In step 1, we build the search space by extending our old schedule code
# into a template. The next step is to pick a tuner and explore in this space.
#
# Auto-tuners in tvm
# ^^^^^^^^^^^^^^^^^^
# The job for a tuner can be described by following pseudo code
#
# .. code-block:: c
#
# ct = 0
# while ct < max_number_of_trials:
# propose a batch of configs
# measure this batch of configs on real hardware and get results
# ct += batch_size
#
# When proposing the next batch of configs, the tuner can take different strategies. We
# provide four tuners with different strategies in autotvm.
#
# * :any:`RandomTuner`: Enumerate the space in a random order
# * :any:`GridSearchTuner`: Enumerate the space in a grid search order
# * :any:`GATuner`: Using genetic algorithm to search through the space
# * :any:`XGBTuner`: Uses a model based method. Train a XGBoost model to predict the speed of lowered IR and pick the next batch according to the prediction.
#
# You can choose the tuner according to the size of your space, your time budget and other factors.
# For example, if your space is very small (less than 1000), a gridsearch tuner or a
# random tuner is good enough. If your space is at the level of 10^9 (this is the space
# size of a conv2d operator on CUDA GPU), XGBoostTuner can explore more efficiently
# and find better configs.
################################################################
# Begin tuning
# ^^^^^^^^^^^^
# Here we continue our matrix multiplication example.
# First we should create a tuning task.
# We can also inspect the initialized search space.
# In this case, for a 512x512 square matrix multiplication, the space size
# is 10x10=100
N, L, M = 512, 512, 512
task = autotvm.task.create(matmul, args=(N, L, M, 'float32'), target='llvm')
print(task.config_space)
################################################################
# Then we need to define how to measure the generated code and pick a tuner.
# Since our space is small, a random tuner is just okay.
#
# We only make 10 trials in this tutorial for demonstration. In practice,
# you can do more trials according to your time budget.
# We will log the tuning results into a cache file. This file can be
# used to get the best config later.
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option(mode='local',
number=5)
# begin tuning, log records to file `cache.tsv`
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=10,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('cache.tsv')])
#########################################################################
# Finally we apply history best from the cache file and check its correctness.
# We can call the function :code:`matmul` directly under the
# :any:`autotvm.apply_history_best` context. When we call this function,
# it will query the dispatch context with its argument and get the best config
# with the same argument.
# apply history best from log file
with autotvm.apply_history_best('cache.tsv'):
with tvm.target.create("llvm"):
s, arg_bufs = matmul(N, L, M, 'float32')
func = tvm.build(s, arg_bufs)
# check correctness
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = a_np.dot(b_np)
c_tvm = tvm.nd.empty(c_np.shape)
func(tvm.nd.array(a_np), tvm.nd.array(b_np), c_tvm)
np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment