Commit 12839e6d by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Decouple build and run in measurement (#1661)

parent 38203a86
......@@ -16,6 +16,11 @@ tvm.autotvm.measure
.. autofunction:: tvm.autotvm.measure.create_measure_batch
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner
tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~
......
......@@ -22,7 +22,8 @@ from . import env
from . import tophub
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
register_topi_compute, register_topi_schedule, \
......
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
create_measure_batch
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .executor import Executor
from .local_executor import LocalExecutor
from .executor import Future, Executor
......@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
res = exc
queue.put(res)
def timeout_monitor(queue, timeout, func, args, kwargs):
def call_with_timeout(queue, timeout, func, args, kwargs):
"""A wrapper to support timeout of a function call"""
# start a new process for timeout (cannot use thread because we have c function)
......@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
p.start()
p.join(timeout=timeout)
alive = p.is_alive()
queue.put(executor.TimeoutError())
kill_child_processes(p.pid)
p.terminate()
p.join()
if alive:
queue.put(executor.TimeoutError())
else:
if queue.empty():
queue.put(executor.ExecutionError("Fatal error in local executor"))
class LocalFuture(executor.Future):
"""Local wrapper for the future
......@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(2)
process = Process(target=timeout_monitor,
process = Process(target=call_with_timeout,
args=(queue, self.timeout, func, args, kwargs))
process.start()
return LocalFuture(process, queue)
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import multiprocessing
from collections import namedtuple
class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
......@@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
Specific configuration.
"""
class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
"""
Stores all the results of a measurement
......@@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
Parameters
----------
costs: Array of float or Array of Exception
If no error occurs for this measurement, it is an array of measured running times.
If some error occurs during the measurement, it is an array of the exception objections.
If no error occurs during measurement, it is an array of measured running times.
If an error occurs during measurement, it is an array of the exception objections.
error_no: int
Denote error type, defined by MeasureErrorNo
all_cost: float
......@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
class MeasureErrorNo(object):
"""Error type for MeasureResult"""
NO_ERROR = 0 # no error
INSTANTIATION_ERROR = 1 # error when calling template function
INSTANTIATION_ERROR = 1 # actively detected error in instantiating a template with a config
COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build)
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device)
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. OpenCL JIT on the device)
RUNTIME_DEVICE = 4 # error when run program on device
WRONG_ANSWER = 5 # answer is wrong when compared to a golden output
FLEET_ERROR = 6 # error of measure infrastructure
BUILD_TIMEOUT = 6 # timeout during compilation
RUN_TIMEOUT = 7 # timeout during run
UNKNOWN_ERROR = 8 # unknown error
def measure_option(measure_func,
number=1,
repeat=1,
timeout=60,
n_parallel=1,
do_fork=True,
build_func='default',
check_correctness=False,
replay_db=None):
"""Configure how to do measurement
class Builder(object):
"""Builder that builds programs in tuning
Parameters
----------
measure_func: str or callable
'local': use the local device for measurement. The tuner will start a tracker
and a RPC server silently for the user.
callable: It is a callable function for measurement.
See the return value of measure/measure_methods.py::rpc for example.
number : int, optional
Number of times to do the measurement for average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
timeout: int, optional
Timeout for a whole batch. TimeoutError will be returned as the result if a
task timeouts.
timeout: float, optional
The timeout of a build task
n_parallel: int, optional
The number of measurement task that can run in parallel.
Set this according to the number of cpu cores (for compilation) and
the number of devices you have (for measuring generate code).
do_fork: bool, optional
Whether use multiprocessing (based on fork) for running measure jobs in parallel.
Set this to False if you want to debug (see trackback) or using fork is not suitable.
NOTE: If this is False, parallel and timeout do not work.
build_func: str or callable, optional
'default': call default builder. This works for normal target (llvm, cuda)
'ndk': use Android NDK to create shared library. Use this for android target.
callable: customized build function for other backends (e.g. VTA).
See measure/measure_methods.py::default_build_func for example.
check_correctness: bool, optional
Whether check correctness after measurement. This will use llvm cpu target to generate
reference output.
replay_db : Database, optional
The database that we retrieve saved MeasureResult from.
The number of tasks submitted in parallel
By default it will use all cpu cores
"""
def __init__(self, timeout=10, n_parallel=None):
self.timeout = timeout
self.n_parallel = n_parallel or multiprocessing.cpu_count()
self.build_kwargs = {}
self.task = None
def set_task(self, task, build_kwargs=None):
"""
Initialize for a new tuning task
Parameters
----------
task: Task
The tuning task
build_kwargs: dict, optional
The additional kwargs for build function
"""
self.task = task
self.build_kwargs = build_kwargs
def build(self, measure_inputs):
"""Build programs
Parameters
----------
measure_inputs: List of MeasureInput
The measure input
Returns
-------
options: dict
A dict to store all options
Note
----
To support customized measure, you can pass callable `measure_func` or
`build_func` in. The `measure_func` will call `build_func` to build binary library
and handle the logic of measurement.
Signature:
* measure_func (see the return value of measure/measure_methods.py::rpc for example)
def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
return measure_results
* build_func (see measure/measure_methods.py::default_build_func for example)
def build_func(inp, tmp_dir, **kwargs):
return func, args, filename
"""
return {
'measure_func': measure_func,
'number': number,
'repeat': repeat,
'timeout': timeout,
'n_parallel': n_parallel,
'do_fork': do_fork,
'build_func': build_func,
'check_correctness': check_correctness,
'replay_db': replay_db,
build_results: List of BuildResult
The build result.
"""
raise NotImplementedError()
class Runner(object):
"""Runner that runs and measures the time cost of a generated program in tuning
Parameters
----------
timeout: float, optional
The timeout of a build task
n_parallel: int, optional
The number of tasks submitted in parallel
By default it will use all cpu cores
"""
def __init__(self, timeout=5, n_parallel=None):
self.timeout = timeout
self.n_parallel = n_parallel or multiprocessing.cpu_count()
self.task = None
def set_task(self, task):
"""
Initialize for a new tuning task
Parameters
----------
task: Task
The tuning task
"""
self.task = task
def get_build_kwargs(self):
"""
Get device specific build arguments (e.g. maximum shared memory size)
Returns
----------
kwargs: dict
The additional keyword arguments
"""
raise NotImplementedError()
def run(self, measure_inputs, build_results):
"""Run amd measure built programs
Parameters
----------
measure_inputs: List of MeasureInput
The raw measure input
build_results: List of BuildResults
The build results
Returns
-------
measure_results: List of MeasureResult
The final results of measurement
"""
raise NotImplementedError()
def measure_option(builder, runner):
"""
Set options for measure. To measure a config, we will build it and run it.
So we have to set options for these two steps.
They have their own options on timeout, parallel, etc.
Parameters
----------
builder: Builder
Specify how to build programs
runner: Runner
Specify how to run programs
"""
from .measure_methods import LocalBuilder, LocalRunner
if isinstance(builder, str):
if builder == 'local':
builder = LocalBuilder()
else:
raise ValueError("Invalid builder: " + builder)
if isinstance(runner, str):
if runner == 'local':
runner = LocalRunner()
else:
raise ValueError("Invalid runner: " + runner)
opt = {
'builder': builder,
'runner': runner,
}
return opt
def create_measure_batch(task, option):
"""Get a standard measure_batch function.
Parameters
----------
task: tvm.autotvm.task.Task
The tuning task
option: dict
The option for measuring generated code.
You should use the return value of function :any:`measure_option` for this argument.
Returns
-------
measure_batch: callable
a callback function to measure a batch of configs
"""
builder = option['builder']
runner = option['runner']
attach_objects = runner.set_task(task)
# feed device related information from runner to builder
# (e.g. max shared memory for validity checking)
build_kwargs = runner.get_build_kwargs()
builder.set_task(task, build_kwargs)
def measure_batch(measure_inputs):
build_results = builder.build(measure_inputs)
results = runner.run(measure_inputs, build_results)
return results
measure_batch.n_parallel = builder.n_parallel
measure_batch.attach_objects = attach_objects
return measure_batch
# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
"""
Functions that run on executor for measurement.
These functions are responsible for building tvm module, uploading it to
remote devices, recording the running time costs and checking the correctness of output
These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.
"""
import logging
import shutil
import os
import threading
import time
from random import getrandbits
import threading
from collections import namedtuple
import tempfile
import numpy as np
from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \
target as _target, rpc as _rpc
from ...contrib import nvcc, util, ndk
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
rpc as _rpc, target as _target
from ...contrib import nvcc, ndk
from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError
from .measure import MeasureResult, MeasureErrorNo
from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
from .local_executor import LocalExecutor
logger = logging.getLogger('autotvm')
class HashMismatchError(ValueError):
"""Raised when the code hash of a submitted config doesn't match that on the
measure side """
pass
def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
"""request a remote session
class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
"""
Stores all the necessary inputs for a measurement.
Parameters
----------
device_key: string
device key of registered device in tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format.
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this session (units: seconds)
Returns
------
session: RPCSession
filename : str
The filename of generated library
arg_info : Tuple
The shape and dtype information of tvm tensor arguments
error : Exception
The error happens during compilation.
time_cost : float
The time cost of building
"""
# connect to the tracker
if tracker_addr:
host = tracker_addr[0] or os.environ['TVM_TRACKER_HOST']
port = tracker_addr[1] or int(os.environ['TVM_TRACKER_PORT'])
else:
host = os.environ['TVM_TRACKER_HOST']
port = int(os.environ['TVM_TRACKER_PORT'])
tracker = _rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority,
session_timeout=timeout)
return remote
def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
"""
Check the availability of a remote device
class LocalBuilder(Builder):
"""Run compilation on local machine
Parameters
----------
target: Target
The wanted compilation target
device_key: string
device key of registered device in tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format.
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this check (units: seconds).
If time is out, a RuntimeError will be raised.
timeout: float
The timeout of a compilation
n_parallel: int
The number of tasks run in parallel. "None" will use all cpu cores
build_func: callable or str
If is 'default', use default build function
If is 'ndk', use function for android ndk
If is callable, use it as custom build function
"""
def _check():
remote = request_remote(device_key, tracker_addr, priority)
remote.context(str(target))
t = threading.Thread(target=_check,)
t.start()
t.join(timeout)
return not t.is_alive()
def __init__(self, timeout=10, n_parallel=None, build_func='default'):
super(LocalBuilder, self).__init__(timeout, n_parallel)
def create_measure_batch(task, option):
"""Get a standard measure_batch function.
if isinstance(build_func, str):
if build_func == 'default':
build_func = default_build_func
elif build_func == 'ndk':
build_func = android_ndk_build_func
else:
raise ValueError("Invalid build_func" + build_func)
Parameters
----------
task: tvm.autotvm.task.Task
The tuning task
option: dict
The option for measuring generated code.
You should use the return value of function :any:`measure_option` for this argument.
self.build_func = build_func
self.tmp_dir = tempfile.mkdtemp()
self.executor = LocalExecutor(timeout=timeout)
Returns
-------
measure_batch: callable
a callback function to measure a batch of configs
"""
from ..database import filter_inputs
def build(self, measure_inputs):
results = []
measure_func = option['measure_func']
number, repeat = option['number'], option['repeat']
timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork']
build_func = option['build_func']
check_correctness = option['check_correctness']
replay_db = option['replay_db']
for i in range(0, len(measure_inputs), self.n_parallel):
futures = []
for inp in measure_inputs[i:i + self.n_parallel]:
ret = self.executor.submit(self.build_func,
inp,
self.tmp_dir,
**self.build_kwargs)
futures.append(ret)
executor = LocalExecutor(timeout=timeout, do_fork=do_fork)
for future in futures:
res = future.get()
if isinstance(res, Exception):
# timeout or fleet error, return MeasureResult directly
results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
self.timeout, time.time()))
elif res.error is not None:
# instantiation errorD
if isinstance(res.error, InstantiationError):
results.append(MeasureResult((res.error,),
MeasureErrorNo.INSTANTIATION_ERROR,
res.time_cost, time.time()))
else:
if "InstantiationError" in str(res.error):
msg = str(res.error)
try:
msg = msg.split('\n')[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
results.append(MeasureResult((InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
res.time_cost, time.time()))
else: # tvm error
results.append(MeasureResult((res.error,),
MeasureErrorNo.COMPILE_HOST,
res.time_cost, time.time()))
else:
# return BuildResult
results.append(res)
# convert convenient string to function object
attach_objects = None
if measure_func == 'local':
# start temporary rpc tracker and rpc server for the user
from ...rpc.tracker import Tracker
from ...rpc.server import Server
return results
tracker = Tracker('localhost', port=9000, port_end=10000, silent=True)
device_key = '$local$device$%d' % tracker.port
server = Server('localhost', port=9000, port_end=10000,
key=device_key,
use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port))
def __del__(self):
shutil.rmtree(self.tmp_dir)
measure_func = rpc(device_key, tracker.host, tracker.port)
attach_objects = (server, tracker)
build_kwargs = {}
if build_func == 'default':
build_func = default_build_func
if build_func == 'ndk':
build_func = default_build_func
build_kwargs['use_ndk'] = True
class RPCRunner(Runner):
"""Run generated code on remove devices.
This function will ask a RPC Tracker to get device for measurement.
# check the availability of remote devices
if hasattr(measure_func, 'rpc_info'):
rpc_info = measure_func.rpc_info
if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
Parameters
----------
timeout: float
The timeout of a compilation
n_parallel: int
The number of tasks run in parallel. "None" will use all cpu cores
key: str
The key of the device registered in the tracker
host: str
The host address of RPC Tracker
port: int
The port of RPC Tracker
number : int, optional
Number of times to do measurement for tasking average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
min_repeat_ms : float, optional
Minimum duration of a timer measurement in milliseconds.
When the run time of a measurement trial falls below this time, the
`number` parameter will be automatically increased.
Set this to improve the accuracy of perf measurement, e.g., when timers
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
cooldown_interval: float, optional
The cool down interval between two measurements.
check_correctness: bool, optional
Whether check correctness after measurement. This will use llvm cpu target to
call your template and get the reference output.
This can work for TOPI templates, but may not work for your custom template.
"""
def __init__(self,
key, host, port, priority=1,
timeout=10, n_parallel=None,
number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
check_correctness=False):
super(RPCRunner, self).__init__(timeout, n_parallel)
self.key = key
self.host = host
self.port = port
self.priority = priority
self.timeout = timeout
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms
self.cur_number = number
self.ref_input = None
self.ref_output = None
self.check_correctness = check_correctness
self.cooldown_interval = cooldown_interval
self.executor = LocalExecutor()
def set_task(self, task):
self.task = task
self.cur_number = self.number
if check_remote(task.target, self.key, self.host, self.port):
logger.info("Get devices for measurement successfully!")
else:
raise RuntimeError("Cannot get remote devices from the tracker. "
......@@ -155,225 +198,261 @@ def create_measure_batch(task, option):
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.")
# add device info of cuda and opencl target
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
and hasattr(measure_func, 'rpc_info'):
rpc_info = measure_func.rpc_info
add_gpu_target_info(task.target, rpc_info["key"], (rpc_info["host"], rpc_info["port"]),
build_kwargs)
if check_correctness:
if self.check_correctness:
# use llvm cpu to generate a reference input/output
# this option works for tuning topi, but might not work for you custom op
with _target.create("llvm"):
s, arg_bufs = task.instantiate(task.config_space.get(0))
ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in ref_input]
tvm_buf = [nd.array(x) for x in self.ref_input]
func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf]
else:
ref_input = ref_output = None
self.ref_output = [x.asnumpy() for x in tvm_buf]
def get_build_kwargs(self):
kwargs = {}
if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys:
remote = request_remote(self.key, self.host, self.port)
ctx = remote.context(str(self.task.target), 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
if 'cuda' in self.task.target.keys:
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines"""
if replay_db is not None:
partial_results, measure_inputs = \
filter_inputs(replay_db, measure_inputs, retry=False)
return kwargs
# launch measure jobs in parallel
pack_size = getattr(measure_func, "pack_size", 1) # measure `pack_size` inputs in one job
def run(self, measure_inputs, build_results):
results = []
remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
for i in range(0, len(measure_inputs), self.n_parallel):
futures = []
for i in range(0, len(measure_inputs), pack_size):
input_pack = measure_inputs[i:i + pack_size]
ret = executor.submit(
measure_func,
input_pack,
build_func,
build_kwargs,
number,
repeat,
ref_input,
ref_output)
for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
build_results[i:i+self.n_parallel]):
ret = self.executor.submit(run_through_rpc,
measure_inp,
build_res,
self.cur_number,
self.repeat,
self.cooldown_interval,
remote_args,
self.ref_input,
self.ref_output)
futures.append(ret)
# transform results
results = []
for future in futures:
result = future.get()
if isinstance(result, Exception):
tstamp = time.time()
results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
timeout, tstamp)] * pack_size)
res = future.get()
if isinstance(res, Exception): # executor error or timeout
results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
self.timeout, time.time()))
else:
results.extend(result)
if replay_db is not None:
result_idx = 0
for i in range(len(partial_results)):
if partial_results[i] is None:
partial_results[i] = results[result_idx]
result_idx += 1
return partial_results
return results
measure_batch.n_parallel = n_parallel
# attach server and tracker object to avoid them of being garbage-collected
measure_batch.attach_objects = attach_objects
return measure_batch
results.append(res)
# If some runs were too fast, do remeasure for them
# to meet the requirement of `min_repeat_ms`
remeasure = np.zeros((len(measure_inputs),), dtype=np.bool)
pre_number = next_number = self.cur_number
min_repeat_duration = self.min_repeat_ms / 1000.0
for i, res in enumerate(results):
if res.error_no == MeasureErrorNo.NO_ERROR:
if np.mean(res.costs) * pre_number <= min_repeat_duration:
next_number = max(next_number,
int(np.ceil(min_repeat_duration / np.mean(res.costs))))
remeasure[i] = True
if pre_number != next_number:
self.cur_number = next_number
msg = "increasing number to %d" % self.cur_number
logger.info(msg)
re_measure_inputs = [x for i, x in enumerate(measure_inputs) if remeasure[i]]
re_build_results = [x for i, x in enumerate(build_results) if remeasure[i]]
re_res = self.run(re_measure_inputs, re_build_results)
ct = 0
for i, rerun in enumerate(remeasure):
if rerun:
results[i] = re_res[ct]
ct += 1
def rpc(key,
host=None,
port=None,
priority=1,
session_timeout=60,
pack_size=1):
"""
Create a standard measure_func which uses RPC Tracker for measurement.
This measure_func will request a device from the RPC Tracker and
upload the built binary library to that device for measurement.
return results
Parameters
----------
key: str
The registered key of the device in tracker. The tuner will request devices for
measurement by this key.
host: str, optional
The hostname of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_HOST"
port: int, optional
The port of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
Priority of this task, used by scheduler in tracker
session_timeout: int, optional
Timeout of rpc session
pack_size: int, optional
The number of configs measure in one RPC session.
Usually this can be set to 1. If your device has high overhead to establish a
rpc connection, set this higher.
"""
def fmeasure(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
"""Do measurement for a list of inputs inside a same RPC session.
class LocalRunner(RPCRunner):
"""Run generated code on local devices.
Parameters
----------
input_pack: List of MeasureInput
The inputs of measurement
build_func: callable
Function for building the code. see :any:`default_build_func` for example
build_kwargs: dict
Extra arguments for build_func
timeout: float
The timeout of a compilation
number : int, optional
Number of times to do the measurement for average
Number of times to do measurement for tasking average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
ref_input: List of numpy array
Reference input for correctness check
ref_output: List of numpy array
Reference output for correctness check
min_repeat_ms : float, optional
Minimum duration of a timer measurement in milliseconds.
When the run time of a measurement trial falls below this time, the
`number` parameter will be automatically increased.
Set this to improve the accuracy of perf measurement, e.g., when timers
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
cooldown_interval: float, optional
The cool down interval between two measurements.
check_correctness: bool, optional
Whether check correctness after measurement. This will use llvm cpu target to
call your template and get the reference output.
This can work for TOPI templates, but may not work for your custom template.
Note
----
This is a "fake" local mode. We start a silent rpc tracker and rpc server
for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure.
"""
def __init__(self,
timeout=10,
number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
check_correctness=False):
super(LocalRunner, self).__init__('', None, None, 0,
timeout=timeout, n_parallel=1,
number=number, repeat=repeat,
min_repeat_ms=min_repeat_ms,
cooldown_interval=cooldown_interval,
check_correctness=check_correctness)
self.tracker = None
self.server = None
def set_task(self, task):
self.task = task
Returns
-------
results: List of MeasureResult
The results for input_pack
from ...rpc.tracker import Tracker
from ...rpc.server import Server
tracker = Tracker('localhost', port=9000, port_end=10000, silent=True)
device_key = '$local$device$%d' % tracker.port
server = Server('localhost', port=9000, port_end=10000,
key=device_key,
use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port))
self.key = device_key
self.host = tracker.host
self.port = tracker.port
super(LocalRunner, self).set_task(task)
return server, tracker
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
with target:
s, args = task.instantiate(config)
# check invalidity of template and code hash consistency
if not config.valid():
raise InstantiationError(config.errors)
opts = build_option or {}
if check_gpu: # Add verify pass to filter out invalid configs in advance.
opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
if cuda_arch:
set_cuda_target_arch(cuda_arch)
with build_config(**opts):
func = build(s, args, target_host=task.target_host)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
def default_build_func(measure_input, tmp_dir, **kwargs):
"""
remote_args = (key, (host, port), priority, session_timeout)
Default build func. This can work for cuda, opencl, llvm backend
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input, ref_output,
remote_args)
return res
fmeasure.pack_size = pack_size
fmeasure.rpc_info = {"key": key, "host": host, "port": port}
return fmeasure
def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
"""
Build function for android device using ndk.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, ndk.create_shared)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input=None, ref_output=None, remote_args=None):
"""Measure the time cost for a pack of inputs.
(Note: A pack is a list of inputs which will be measured inside a same RPC session)
def run_through_rpc(measure_input, build_result,
number, repeat, cooldown_interval,
remote_args, ref_input=None, ref_output=None):
"""Run a generated library through rpc
Parameters
----------
input_pack : list of MeasureInput
The inputs we need to evaluate
build_func : function takes MeasureInput returns tuple of (time_func, ctx, args)
The build function used to build each input.
build_kwargs: Dict
The extra keyword arguments to build_func
measure_input: MeasureInput
The raw measure input
build_result: BuildResult
The result returned from Builder. This contains the path to the generated library.
number : int, optional
Number of times to do the measurement for average
Number of times to do measurement for tasking average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
ref_input: Array of np.ndarray, optional
Reference input for checking correctness
ref_output: Array of np.ndarray, optional
Reference output for checking correctness
remote_args: Tuple, optional
The arguments to request_remote. If is not None, will use remote rpc devices.
Returns
-------
res_pack : Array of MeasureResult
The list of results of measurement.
cooldown_interval: float
The cool down interval between two measurements
remote_args: Tuple
The argument for request_remote
ref_input: List of np.ndarray
The reference input used for checking correctness
ref_output: List of np.ndarray
The reference output used for checking correctness
"""
res_pack = []
tmp_dir = util.tempdir() if remote_args else None
assert len(input_pack) == 1, "Only supports input_pack == 1 for now"
if isinstance(build_result, MeasureResult):
return build_result
for inp in input_pack:
tic = time.time()
# build function
try:
func, arg_bufs, filename = build_func(inp, tmp_dir, **build_kwargs)
except TVMError as exc:
tstamp = time.time()
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
if "InstantiationError" in msg:
try:
msg = msg.split('\n')[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
res_pack.append(MeasureResult((InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
else:
res_pack.append(MeasureResult((RuntimeError(msg),),
MeasureErrorNo.COMPILE_HOST,
tstamp - tic, tstamp))
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((InstantiationError(str(e)),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
# measure time
errno = MeasureErrorNo.NO_ERROR
try:
# upload built module
if remote_args:
remote = request_remote(*remote_args)
remote.upload(tmp_dir.relpath(filename))
func = remote.load_module(filename)
ctx = remote.context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
else:
ctx = context(str(inp.target), 0)
remote.upload(build_result.filename)
func = remote.load_module(os.path.split(build_result.filename)[1])
ctx = remote.context(str(measure_input.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
......@@ -381,8 +460,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if ref_input:
args = [nd.array(x, ctx=ctx) for x in ref_input]
else:
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
for x in arg_bufs]
args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
......@@ -402,91 +480,78 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
msg = msg[:msg.index("Stack trace returned")]
if "CUDA Source" in msg:
msg = msg[:msg.index("CUDA Source")]
costs = (RuntimeError(msg),)
costs = (RuntimeError(msg[:1024]),)
errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time()
res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
return res_pack
time.sleep(cooldown_interval)
return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
def default_build_func(inp, tmp_dir=None, **kwargs):
"""Build function module. Exception will be raised when any error occurs
def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
"""Request a remote session
Parameters
----------
inp: MeasureInput
The input of this measurement
tmp_dir: tvm.contrib.util.TempDirectory, optional
The temporary directory for exporting built binary library.
If is not None (in RPC mode), the library in this directory will be uploaded to
remote devices.
kwargs: Dict, optional
Other extra arguments
device_key: string
The device key of registered device in tracker
host: host, optional
The host address of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_HOST"
port: int, optional
The port of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this session (units: second)
Returns
-------
func: Function
TVM built function. Typically this is the return value of tvm.build.
args: Array of Buffer or Tensor
The argument list for the function. Typically this is the second argument of tvm.build.
filename: str
The filename of the output build library
------
session: RPCSession
"""
# build function
with inp.target:
s, args = inp.task.instantiate(inp.config)
# check invalidity of template and code hash consistency
if not inp.config.valid():
raise InstantiationError(inp.config.errors)
code_hash = getattr(s, 'code_hash', None)
if inp.config.code_hash != code_hash:
raise HashMismatchError('got {0:s}, expected {1:s}'
.format(str(inp.config.code_hash), str(code_hash)))
opts = {}
if "check_gpu" in kwargs: # Add verify pass to filter out invalid configs in advance.
opts["add_lower_pass"] = [(2, gpu_verify_pass(**kwargs['check_gpu']))]
if 'cuda_arch' in kwargs:
set_cuda_target_arch(kwargs['cuda_arch'])
with build_config(**opts):
func = build(s, args, target_host=inp.task.target_host)
# export library to temp directory
if tmp_dir:
if kwargs.get('use_ndk', False): # for Android NDK
filename = "tmp_func_%0x.so" % getrandbits(64)
func.export_library(tmp_dir.relpath(filename), ndk.create_shared)
else:
filename = "tmp_func_%0x.tar" % getrandbits(64)
func.export_library(tmp_dir.relpath(filename))
else:
filename = None
# connect to the tracker
host = host or os.environ['TVM_TRACKER_HOST']
port = port or int(os.environ['TVM_TRACKER_PORT'])
return func, args, filename
tracker = _rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority,
session_timeout=timeout)
return remote
def add_gpu_target_info(target, device_key, rpc_tracker_addr, kwargs):
"""Add device info for gpu target.
The info will be used to check the validity of generated code."""
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context(str(target), 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10):
"""
Check the availability of a remote device
if 'cuda' in target.keys:
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
Parameters
----------
target: Target
The wanted compilation target
device_key: string
device key of registered device in tracker
host: host, optional
The host address of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_HOST"
port: int, optional
The port address of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this check (units: seconds).
def set_cuda_target_arch(arch):
"""set target architecture of nvcc compiler"""
AutotvmGlobalScope.current.cuda_target_arch = arch
Returns
-------
available: bool
True if can find available device
"""
def _check():
remote = request_remote(device_key, host, port, priority)
remote.context(str(target))
t = threading.Thread(target=_check,)
t.start()
t.join(timeout)
return not t.is_alive()
@register_func
......@@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code):
return ptx
def set_cuda_target_arch(arch):
"""set target architecture of nvcc compiler
Parameters
----------
arch: str
The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
"""
AutotvmGlobalScope.current.cuda_target_arch = arch
def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
......
......@@ -22,7 +22,7 @@ class GATuner(Tuner):
mutation_prob: float
probability of mutation of a knob in a gene
"""
def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
super(GATuner, self).__init__(task)
# algorithm configurations
......
......@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
new_scores = model.predict(new_points)
ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
ac_index = np.random.random(len(ac_prob)) < ac_prob
points[ac_index] = new_points[ac_index]
......
......@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
target=target, target_host=target_host)
return task, target
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
def custom_measure(input_pack, build_func, build_args, number, repeat,
ref_input, ref_output):
from tvm.autotvm import MeasureResult
results = []
for inp in input_pack:
tic = time.time()
# do nothing
time.sleep(0.001)
results.append(MeasureResult([time.time() - tic], 0,
time.time() - tic, time.time()))
return results
measure_option = autotvm.measure_option(custom_measure)
logging.info("%s", task.config_space)
# new tuner and recorder
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_tuning_with_measure():
def test_tuning():
def check(target, target_host):
ctx = tvm.context(target, 0)
if not ctx.exist:
......@@ -141,12 +114,12 @@ def test_tuning_with_measure():
task, target = get_sample_task(target, target_host)
logging.info("%s", task.config_space)
measure_option = autotvm.measure_option('local',
timeout=4,
number=2)
measure_option = autotvm.measure_option(
autotvm.LocalBuilder(),
autotvm.LocalRunner())
tuner = RandomTuner(task)
tuner.tune(n_trial=10, measure_option=measure_option)
tuner.tune(n_trial=20, measure_option=measure_option)
check("cuda", None)
check("opencl", None)
......@@ -155,6 +128,4 @@ if __name__ == "__main__":
# only print log when invoked from main
logging.basicConfig(level=logging.DEBUG)
test_task_tuner_without_measurement()
test_tuning_with_measure()
test_tuning()
......@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
return s, [A, B, C]
@autotvm.template
def bad_matmul(N, L, M, dtype):
if 'bad_device' in tvm.target.current_target().keys:
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L-1), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
return s, [A, B, C]
return matmul(N, L, M, dtype)
def get_sample_task(n=128):
"""return a sample task for testing"""
target = tvm.target.create("llvm")
......
"""Test database"""
import copy
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from tvm.autotvm.record import encode, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records
from test_autotvm_common import get_sample_records
def test_save_load():
logging.info("test basic db load/save ...")
......@@ -35,66 +29,6 @@ def test_save_load():
TRIAL_LIMIT = 2
def test_db_filter():
logging.info("test db filter ...")
# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
batch_size = 2
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
del measure_batch
db = database.DummyDatabase()
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])
for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"
del measure_batch
def test_db_hash():
logging.info("test db hash check ...")
inp1, res1 = get_sample_records(1)[0]
......@@ -149,89 +83,8 @@ def test_db_latest_all():
assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local',
do_fork=False,
timeout=2,
replay_db=_db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
callback = autotvm.callback.log_to_database(_db)
callback(None, all_inputs, all_results)
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)
for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
del measure_batch
def test_check_hashmismatch():
logging.info("test hash mismatch check")
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local', do_fork=False)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))
try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass
del measure_batch
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_save_load()
test_db_filter()
test_db_hash()
test_db_latest_all()
test_db_save_replay()
test_check_hashmismatch()
"""Test builder and runner"""
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)
def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]
def get_build_kwargs(self):
return {}
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=DummyRunner()
)
logging.info("%s", task.config_space)
for tuner_class in [autotvm.tuner.RandomTuner,
autotvm.tuner.GridSearchTuner,
autotvm.tuner.GATuner,
autotvm.tuner.XGBTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_check_correctness():
task, target = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(check_correctness=True)
)
def _callback_correct(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == 0
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_correct])
# a bad template
n = 128
target = tvm.target.create("llvm -device=bad_device")
task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
def _callback_wrong(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == MeasureErrorNo.WRONG_ANSWER
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_wrong])
def test_min_repeat_ms():
task, target = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
)
def _callback(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
if res.error_no != 0:
continue
assert 1000 * np.mean(res.costs) * \
measure_option['runner'].cur_number >= 100
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=5, measure_option=measure_option,
callbacks=[_callback])
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_task_tuner_without_measurement()
test_check_correctness()
test_min_repeat_ms()
......@@ -137,7 +137,10 @@ if __name__ == '__main__':
print(task.config_space)
measure_option = autotvm.measure_option(
measure_func='local', number=10, n_parallel=8, timeout=20)
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)
log_name = 'gemm_int8.log'
if DO_TUNING:
tuner = autotvm.tuner.XGBTuner(task)
......
......@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
target='cuda')
print(task.config_space)
# use local gpu, measure 5 times for every config to reduce variance
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option('local',
number=5,
n_parallel=8,
timeout=20)
# use local gpu, measure 10 times for every config to reduce variance
# The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)
# begin tuning, log records to file `conv2d.log`
tuner = autotvm.tuner.XGBTuner(task)
......
......@@ -65,15 +65,20 @@ def get_network(name, batch_size):
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if name =='resnet-18':
net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
elif name =='mobilenet':
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
elif name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif name =='squeezenet v1.1':
elif name == 'squeezenet_v1.1':
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
elif name =='vgg-16':
net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
elif name =='custom':
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
elif name == 'custom':
# an example for custom network
from nnvm.testing import utils
net = nnvm.sym.Variable('data')
......@@ -92,6 +97,7 @@ def get_network(name, batch_size):
return net, params, input_shape, output_shape
#################################################################
# Start RPC Tracker
# -----------------
......@@ -158,6 +164,8 @@ def get_network(name, batch_size):
# rk3399 2 2 0
# rpi3b 11 11 0
# ----------------------------------
#
# You can register multiple devices to the tracker to accelerate the measurement in tuning.
###########################################
# Set Tuning Options
......@@ -188,14 +196,16 @@ tuning_option = {
'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 250,
'early_stopping': 400,
'measure_option': autotvm.measure_option(
autotvm.measure.rpc(device_key, host='localhost', port=9190),
number=4,
n_parallel=1,
timeout=10,
build_func='ndk' if use_android else 'default',
builder=autotvm.LocalBuilder(
build_func='ndk' if use_android else 'default'),
runner=autotvm.RPCRunner(
device_key, host='localhost', port=9190,
number=5,
timeout=4,
),
),
}
......@@ -203,15 +213,9 @@ tuning_option = {
#
# .. note:: How to set tuning options
#
# In general, the default value provided here works well. It is the same
# value that we used to generate pre-tuned parameters.
# If you have multiple devices, you can set :code:`n_parallel` to
# the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
# boards to the tracker).
# In general, the default value provided here works well.
# If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
# which makes the tuning run longer.
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# consider setting timeout larger.
#
###################################################################
......@@ -219,7 +223,7 @@ tuning_option = {
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tune them in sequential order.
# This function is just an initial implementation which tunes them in sequential order.
# Later we will bring more sophisticated tuner scheduler.
# You can skip the implementation of this function for this tutorial.
......@@ -236,7 +240,9 @@ def tune_tasks(tasks,
try: # try winograd template
tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
tasks[i].target, tasks[i].target_host, 'winograd')
tasks.append(tsk)
input_channel = tsk.workload[1][1]
if input_channel >= 64:
tasks[i] = tsk
except Exception:
pass
......@@ -245,8 +251,8 @@ def tune_tasks(tasks,
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(tasks):
prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
......@@ -280,7 +286,7 @@ def tune_tasks(tasks,
########################################################################
# Finally we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate():
def tune_and_evaluate(tuning_opt):
# extract workloads from nnvm graph
print("Extract tasks...")
net, params, input_shape, out_shape = get_network(network, batch_size=1)
......@@ -290,19 +296,18 @@ def tune_and_evaluate():
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_option)
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
graph, lib, params = nnvm.compiler.build(
net, target=target,
shape={'data': input_shape}, params=params, dtype=dtype)
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
# export library
tmp = tempdir()
if tuning_option['measure_option']['build_func'] == 'ndk': # for android
if use_android:
from tvm.contrib import ndk
filename = "net.so"
lib.export_library(tmp.relpath(filename), ndk.create_shared)
......@@ -312,8 +317,7 @@ def tune_and_evaluate():
# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key,
tracker_addr=('localhost', 9190),
remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
......@@ -328,47 +332,44 @@ def tune_and_evaluate():
# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
ftimer = module.module.time_evaluator("run", ctx, number=8, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.
# tune_and_evaluate()
# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to train xgboost models and use them for prediction.
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended.
# It takes about 2 hours on a 32T AMD Ryzen CPU.
# One sample output is
# One sample output is listed below.
# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/16] Current/Best: 18.85/ 19.67 GFLOPS | Progress: (353/1000) | 387.05 s Done.
# [Task 2/16] Current/Best: 16.10/ 23.50 GFLOPS | Progress: (444/1000) | 379.99 s Done.
# [Task 3/16] Current/Best: 5.49/ 13.96 GFLOPS | Progress: (610/1000) | 485.87 s Done.
# [Task 4/16] Current/Best: 10.07/ 20.48 GFLOPS | Progress: (430/1000) | 391.66 s Done.
# [Task 5/16] Current/Best: 11.50/ 15.50 GFLOPS | Progress: (374/1000) | 356.03 s Done.
# [Task 6/16] Current/Best: 10.76/ 23.77 GFLOPS | Progress: (526/1000) | 526.42 s Done.
# [Task 7/16] Current/Best: 12.71/ 22.03 GFLOPS | Progress: (341/1000) | 322.96 s Done.
# [Task 8/16] Current/Best: 8.60/ 17.91 GFLOPS | Progress: (272/1000) | 236.08 s Done.
# [Task 9/16] Current/Best: 15.37/ 23.62 GFLOPS | Progress: (275/1000) | 275.18 s Done.
# [Task 10/16] Current/Best: 6.62/ 23.01 GFLOPS | Progress: (330/1000) | 315.02 s Done.
# [Task 11/16] Current/Best: 1.85/ 21.39 GFLOPS | Progress: (281/1000) | 239.19 s Done.
# [Task 12/16] Current/Best: 15.41/ 24.02 GFLOPS | Progress: (258/1000) | 270.82 s Done.
# [Task 13/16] Current/Best: 17.96/ 25.79 GFLOPS | Progress: (380/1000) | 738.29 s Done.
# [Task 14/16] Current/Best: 14.81/ 31.17 GFLOPS | Progress: (413/1000) | 799.21 s Done.
# [Task 15/16] Current/Best: 24.39/ 40.97 GFLOPS | Progress: (355/1000) | 700.25 s Done.
# [Task 16/16] Current/Best: 9.42/ 49.90 GFLOPS | Progress: (348/1000) | 603.84 s Done.
# [Task 1/12] Current/Best: 22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
# [Task 2/12] Current/Best: 6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
# [Task 3/12] Current/Best: 4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
# [Task 4/12] Current/Best: 11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
# [Task 5/12] Current/Best: 1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
# [Task 6/12] Current/Best: 2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
# [Task 7/12] Current/Best: 14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
# [Task 8/12] Current/Best: 1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
# [Task 9/12] Current/Best: 14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
# [Task 10/12] Current/Best: 3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
# [Task 11/12] Current/Best: 11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
# [Task 12/12] Current/Best: 8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
# Compile...
# Upload...
# Evaluate inference time cost...
# Mean inference time (std dev): 157.29 ms (1.74 ms)
# Mean inference time (std dev): 162.59 ms (0.06 ms)
######################################################################
#
......
......@@ -271,9 +271,12 @@ print(task.config_space)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option('local',
number=5)
# There are two steps for measuring a config: build and run.
# By default, we use all cpu cores to compile program. Then measure them sequentially.
# We measure 5 times and take average to reduce variance.
measure_option = autotvm.measure_option(
builder='local',
runner=autotvm.LocalRunner(number=5))
# begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment