Commit 12839e6d by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Decouple build and run in measurement (#1661)

parent 38203a86
...@@ -16,6 +16,11 @@ tvm.autotvm.measure ...@@ -16,6 +16,11 @@ tvm.autotvm.measure
.. autofunction:: tvm.autotvm.measure.create_measure_batch .. autofunction:: tvm.autotvm.measure.create_measure_batch
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner
tvm.autotvm.tuner tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~
......
...@@ -22,7 +22,8 @@ from . import env ...@@ -22,7 +22,8 @@ from . import env
from . import tophub from . import tophub
# some shortcuts # some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \ from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
register_topi_compute, register_topi_schedule, \ register_topi_compute, register_topi_schedule, \
......
"""Distributed executor infrastructure to scale up the tuning""" """Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc create_measure_batch
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .executor import Executor
from .local_executor import LocalExecutor from .local_executor import LocalExecutor
from .executor import Future, Executor
...@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs): ...@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
res = exc res = exc
queue.put(res) queue.put(res)
def timeout_monitor(queue, timeout, func, args, kwargs):
def call_with_timeout(queue, timeout, func, args, kwargs):
"""A wrapper to support timeout of a function call""" """A wrapper to support timeout of a function call"""
# start a new process for timeout (cannot use thread because we have c function) # start a new process for timeout (cannot use thread because we have c function)
...@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs): ...@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
p.start() p.start()
p.join(timeout=timeout) p.join(timeout=timeout)
alive = p.is_alive() queue.put(executor.TimeoutError())
kill_child_processes(p.pid) kill_child_processes(p.pid)
p.terminate() p.terminate()
p.join() p.join()
if alive:
queue.put(executor.TimeoutError())
else:
if queue.empty():
queue.put(executor.ExecutionError("Fatal error in local executor"))
class LocalFuture(executor.Future): class LocalFuture(executor.Future):
"""Local wrapper for the future """Local wrapper for the future
...@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor): ...@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
return LocalFutureNoFork(func(*args, **kwargs)) return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(2) queue = Queue(2)
process = Process(target=timeout_monitor, process = Process(target=call_with_timeout,
args=(queue, self.timeout, func, args, kwargs)) args=(queue, self.timeout, func, args, kwargs))
process.start() process.start()
return LocalFuture(process, queue) return LocalFuture(process, queue)
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name # pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code""" """User facing API for specifying how to measure the generated code"""
import multiprocessing
from collections import namedtuple from collections import namedtuple
class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])): class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
...@@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])): ...@@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
Specific configuration. Specific configuration.
""" """
class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])): class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
""" """
Stores all the results of a measurement Stores all the results of a measurement
...@@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost" ...@@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
Parameters Parameters
---------- ----------
costs: Array of float or Array of Exception costs: Array of float or Array of Exception
If no error occurs for this measurement, it is an array of measured running times. If no error occurs during measurement, it is an array of measured running times.
If some error occurs during the measurement, it is an array of the exception objects. If an error occurs during measurement, it is an array of the exception objects.
error_no: int error_no: int
Denote error type, defined by MeasureErrorNo Denote error type, defined by MeasureErrorNo
all_cost: float all_cost: float
...@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost" ...@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
class MeasureErrorNo(object): class MeasureErrorNo(object):
"""Error type for MeasureResult""" """Error type for MeasureResult"""
NO_ERROR = 0 # no error NO_ERROR = 0 # no error
INSTANTIATION_ERROR = 1 # error when calling template function INSTANTIATION_ERROR = 1 # actively detected error in instantiating a template with a config
COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build) COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build)
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device) COMPILE_DEVICE = 3 # error when compiling code on device (e.g. OpenCL JIT on the device)
RUNTIME_DEVICE = 4 # error when run program on device RUNTIME_DEVICE = 4 # error when run program on device
WRONG_ANSWER = 5 # answer is wrong when compared to a golden output WRONG_ANSWER = 5 # answer is wrong when compared to a golden output
FLEET_ERROR = 6 # error of measure infrastructure BUILD_TIMEOUT = 6 # timeout during compilation
RUN_TIMEOUT = 7 # timeout during run
UNKNOWN_ERROR = 8 # unknown error
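# A minimal sketch of how a MeasureResult is typically consumed: `costs` holds
# floats on success and exception objects on failure, so callers gate on
# error_no first. The helper name below is only for illustration.
def _example_mean_cost(res):
    """Illustrative only: return the mean measured cost, or None on failure."""
    if res.error_no != MeasureErrorNo.NO_ERROR:
        return None
    return sum(res.costs) / len(res.costs)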
class Builder(object):
"""Builder that builds programs in tuning
def measure_option(measure_func, Parameters
number=1, ----------
repeat=1, timeout: float, optional
timeout=60, The timeout of a build task
n_parallel=1, n_parallel: int, optional
do_fork=True, The number of tasks submitted in parallel
build_func='default', By default it will use all cpu cores
check_correctness=False, """
replay_db=None): def __init__(self, timeout=10, n_parallel=None):
"""Configure how to do measurement self.timeout = timeout
self.n_parallel = n_parallel or multiprocessing.cpu_count()
self.build_kwargs = {}
self.task = None
def set_task(self, task, build_kwargs=None):
"""
Initialize for a new tuning task
Parameters
----------
task: Task
The tuning task
build_kwargs: dict, optional
The additional kwargs for build function
"""
self.task = task
self.build_kwargs = build_kwargs
def build(self, measure_inputs):
"""Build programs
Parameters
----------
measure_inputs: List of MeasureInput
The measure input
Returns
-------
build_results: List of BuildResult
The build result.
"""
raise NotImplementedError()
class Runner(object):
"""Runner that runs and measures the time cost of a generated program in tuning
Parameters Parameters
---------- ----------
measure_func: str or callable timeout: float, optional
'local': use the local device for measurement. The tuner will start a tracker The timeout of a run task
and a RPC server silently for the user.
callable: It is a callable function for measurement.
See the return value of measure/measure_methods.py::rpc for example.
number : int, optional
Number of times to do the measurement for average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
timeout: int, optional
Timeout for a whole batch. TimeoutError will be returned as the result if a
task timeouts.
n_parallel: int, optional n_parallel: int, optional
The number of measurement task that can run in parallel. The number of tasks submitted in parallel
Set this according to the number of cpu cores (for compilation) and By default it will use all cpu cores
the number of devices you have (for measuring generate code). """
do_fork: bool, optional def __init__(self, timeout=5, n_parallel=None):
Whether use multiprocessing (based on fork) for running measure jobs in parallel. self.timeout = timeout
Set this to False if you want to debug (see trackback) or using fork is not suitable. self.n_parallel = n_parallel or multiprocessing.cpu_count()
NOTE: If this is False, parallel and timeout do not work. self.task = None
build_func: str or callable, optional
'default': call default builder. This works for normal target (llvm, cuda) def set_task(self, task):
"""
'ndk': use Android NDK to create shared library. Use this for android target. Initialize for a new tuning task
callable: customized build function for other backends (e.g. VTA). Parameters
See measure/measure_methods.py::default_build_func for example. ----------
check_correctness: bool, optional task: Task
Whether check correctness after measurement. This will use llvm cpu target to generate The tuning task
reference output. """
replay_db : Database, optional self.task = task
The database that we retrieve saved MeasureResult from.
def get_build_kwargs(self):
"""
Get device specific build arguments (e.g. maximum shared memory size)
Returns
----------
kwargs: dict
The additional keyword arguments
"""
raise NotImplementedError()
def run(self, measure_inputs, build_results):
"""Run amd measure built programs
Parameters
----------
measure_inputs: List of MeasureInput
The raw measure input
build_results: List of BuildResults
The build results
Returns
-------
measure_results: List of MeasureResult
The final results of measurement
"""
raise NotImplementedError()
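# A minimal Runner sketch: only run() and get_build_kwargs() need to be
# implemented. This mirrors the DummyRunner that the new unit test in this
# commit uses to drive tuners without a real device; the class name and the
# local imports are only there to keep the sketch self-contained.
class _ExampleDummyRunner(Runner):
    """Illustrative only: fake a measurement by returning random costs."""
    def __init__(self):
        super(_ExampleDummyRunner, self).__init__(1, 1)

    def run(self, measure_inputs, build_results):
        # pretend every config ran and took a random amount of time
        import time
        import numpy as np
        return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
                for _ in range(len(measure_inputs))]

    def get_build_kwargs(self):
        # the fake runner needs no device-specific build arguments
        return {}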
def measure_option(builder, runner):
"""
Set options for measurement. To measure a config, we build it and then run it.
So we have to set options for these two steps.
Each step has its own options for timeout, parallelism, etc.
Parameters
----------
builder: Builder
Specify how to build programs
runner: Runner
Specify how to run programs
"""
from .measure_methods import LocalBuilder, LocalRunner
if isinstance(builder, str):
if builder == 'local':
builder = LocalBuilder()
else:
raise ValueError("Invalid builder: " + builder)
if isinstance(runner, str):
if runner == 'local':
runner = LocalRunner()
else:
raise ValueError("Invalid runner: " + runner)
opt = {
'builder': builder,
'runner': runner,
}
return opt
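# A usage sketch mirroring the tutorial updates later in this commit: build and
# run on the local machine. The helper name is only for illustration and the
# import is deferred to keep the sketch self-contained.
def _example_local_measure_option():
    """Illustrative only: build locally, run locally with re-measurement."""
    from tvm import autotvm
    return autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4))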
def create_measure_batch(task, option):
"""Get a standard measure_batch function.
Parameters
----------
task: tvm.autotvm.task.Task
The tuning task
option: dict
The option for measuring generated code.
You should use the return value of function :any:`measure_option` for this argument.
Returns Returns
------- -------
options: dict measure_batch: callable
A dict to store all options a callback function to measure a batch of configs
Note
----
To support customized measure, you can pass callable `measure_func` or
`build_func` in. The `measure_func` will call `build_func` to build binary library
and handle the logic of measurement.
Signature:
* measure_func (see the return value of measure/measure_methods.py::rpc for example)
def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
return measure_results
* build_func (see measure/measure_methods.py::default_build_func for example)
def build_func(inp, tmp_dir, **kwargs):
return func, args, filename
""" """
return { builder = option['builder']
'measure_func': measure_func, runner = option['runner']
'number': number,
'repeat': repeat, attach_objects = runner.set_task(task)
'timeout': timeout,
'n_parallel': n_parallel, # feed device related information from runner to builder
'do_fork': do_fork, # (e.g. max shared memory for validity checking)
'build_func': build_func, build_kwargs = runner.get_build_kwargs()
'check_correctness': check_correctness, builder.set_task(task, build_kwargs)
'replay_db': replay_db,
} def measure_batch(measure_inputs):
build_results = builder.build(measure_inputs)
results = runner.run(measure_inputs, build_results)
return results
measure_batch.n_parallel = builder.n_parallel
measure_batch.attach_objects = attach_objects
return measure_batch
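# A sketch of driving create_measure_batch() by hand, roughly as the removed
# database tests did; tuners normally call it internally. The helper name is
# only for illustration.
def _example_measure_first_configs(task, n=4):
    """Illustrative only: measure the first `n` configs of a task locally."""
    from tvm import autotvm
    option = autotvm.measure_option(builder=autotvm.LocalBuilder(),
                                    runner=autotvm.LocalRunner())
    measure_batch = create_measure_batch(task, option)
    inputs = [MeasureInput(task.target, task, task.config_space.get(i))
              for i in range(n)]
    return measure_batch(inputs)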
# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args # pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
""" """
Functions that run on executor for measurement. Functions that run on executor for measurement.
These functions are responsible for building tvm module, uploading it to
remote devices, recording the running time costs and checking the correctness of output These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.
""" """
import logging import logging
import shutil
import os import os
import threading
import time import time
from random import getrandbits from random import getrandbits
import threading from collections import namedtuple
import tempfile
import numpy as np import numpy as np
from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \ from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
target as _target, rpc as _rpc rpc as _rpc, target as _target
from ...contrib import nvcc, util, ndk from ...contrib import nvcc, ndk
from ..util import get_const_tuple from ..util import get_const_tuple
from ..env import AutotvmGlobalScope from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError from ..task.space import InstantiationError
from .measure import MeasureResult, MeasureErrorNo from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
from .local_executor import LocalExecutor from .local_executor import LocalExecutor
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
class HashMismatchError(ValueError): class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
"""Raised when the code hash of a submitted config doesn't match that on the """
measure side """ Stores all the necessary inputs for a measurement.
pass
Parameters
----------
filename : str
The filename of generated library
arg_info : Tuple
The shape and dtype information of tvm tensor arguments
error : Exception
The error that happened during compilation.
time_cost : float
The time cost of building
"""
def request_remote(device_key, tracker_addr=None, priority=1, timeout=60): class LocalBuilder(Builder):
"""request a remote session """Run compilation on local machine
Parameters Parameters
---------- ----------
device_key: string timeout: float
device key of registered device in tracker The timeout of a compilation
tracker_addr: Tuple(string, int), optional n_parallel: int
The address of rpc tracker in (host, port) format. The number of tasks run in parallel. "None" will use all cpu cores
If is none, will use environment variable "TVM_TRACKER_HOST" build_func: callable or str
and "TVM_TRACKER_PORT" If is 'default', use default build function
priority: int, optional If is 'ndk', use function for android ndk
The priority of this request, larger is more prior If is callable, use it as custom build function
timeout: float, optional
The timeout of this session (units: seconds)
Returns
------
session: RPCSession
""" """
# connect to the tracker def __init__(self, timeout=10, n_parallel=None, build_func='default'):
if tracker_addr: super(LocalBuilder, self).__init__(timeout, n_parallel)
host = tracker_addr[0] or os.environ['TVM_TRACKER_HOST']
port = tracker_addr[1] or int(os.environ['TVM_TRACKER_PORT']) if isinstance(build_func, str):
else: if build_func == 'default':
host = os.environ['TVM_TRACKER_HOST'] build_func = default_build_func
port = int(os.environ['TVM_TRACKER_PORT']) elif build_func == 'ndk':
build_func = android_ndk_build_func
else:
raise ValueError("Invalid build_func" + build_func)
tracker = _rpc.connect_tracker(host, port) self.build_func = build_func
remote = tracker.request(device_key, priority=priority, self.tmp_dir = tempfile.mkdtemp()
session_timeout=timeout) self.executor = LocalExecutor(timeout=timeout)
return remote
def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10): def build(self, measure_inputs):
""" results = []
Check the availability of a remote device
for i in range(0, len(measure_inputs), self.n_parallel):
futures = []
for inp in measure_inputs[i:i + self.n_parallel]:
ret = self.executor.submit(self.build_func,
inp,
self.tmp_dir,
**self.build_kwargs)
futures.append(ret)
for future in futures:
res = future.get()
if isinstance(res, Exception):
# timeout or fleet error, return MeasureResult directly
results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
self.timeout, time.time()))
elif res.error is not None:
# instantiation error
if isinstance(res.error, InstantiationError):
results.append(MeasureResult((res.error,),
MeasureErrorNo.INSTANTIATION_ERROR,
res.time_cost, time.time()))
else:
if "InstantiationError" in str(res.error):
msg = str(res.error)
try:
msg = msg.split('\n')[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
results.append(MeasureResult((InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
res.time_cost, time.time()))
else: # tvm error
results.append(MeasureResult((res.error,),
MeasureErrorNo.COMPILE_HOST,
res.time_cost, time.time()))
else:
# return BuildResult
results.append(res)
return results
def __del__(self):
shutil.rmtree(self.tmp_dir)
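# A sketch of a custom build_func for LocalBuilder, following the same contract
# as default_build_func below (MeasureInput plus a temp dir in, BuildResult
# out); this variant exports a plain .so instead of a .tar archive. The
# function name is only for illustration.
def _example_so_build_func(measure_input, tmp_dir, **kwargs):
    """Illustrative only: usable as LocalBuilder(build_func=...)."""
    tic = time.time()
    try:
        filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
        func, arg_info = _build_func_common(measure_input, **kwargs)
        func.export_library(filename)
    except Exception as e:  # pylint: disable=broad-except
        return BuildResult(None, None, e, time.time() - tic)
    return BuildResult(filename, arg_info, None, time.time() - tic)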
class RPCRunner(Runner):
"""Run generated code on remove devices.
This function will ask a RPC Tracker to get device for measurement.
Parameters Parameters
---------- ----------
target: Target timeout: float
The wanted compilation target The timeout of a run task
device_key: string n_parallel: int
device key of registered device in tracker The number of tasks run in parallel. "None" will use all cpu cores
tracker_addr: Tuple(string, int), optional key: str
The address of rpc tracker in (host, port) format. The key of the device registered in the tracker
If is none, will use environment variable "TVM_TRACKER_HOST" host: str
and "TVM_TRACKER_PORT" The host address of RPC Tracker
priority: int, optional port: int
The priority of this request, larger is more prior The port of RPC Tracker
timeout: float, optional number : int, optional
The timeout of this check (units: seconds). Number of times to do measurement for taking average
If time is out, a RuntimeError will be raised. repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
min_repeat_ms : float, optional
Minimum duration of a timer measurement in milliseconds.
When the run time of a measurement trial falls below this time, the
`number` parameter will be automatically increased.
Set this to improve the accuracy of perf measurement, e.g., when timers
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
cooldown_interval: float, optional
The cool down interval between two measurements.
check_correctness: bool, optional
Whether to check correctness after measurement. This will use llvm cpu target to
call your template and get the reference output.
This can work for TOPI templates, but may not work for your custom template.
""" """
def _check(): def __init__(self,
remote = request_remote(device_key, tracker_addr, priority) key, host, port, priority=1,
remote.context(str(target)) timeout=10, n_parallel=None,
t = threading.Thread(target=_check,) number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
t.start() check_correctness=False):
t.join(timeout) super(RPCRunner, self).__init__(timeout, n_parallel)
return not t.is_alive()
self.key = key
self.host = host
self.port = port
self.priority = priority
self.timeout = timeout
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms
self.cur_number = number
self.ref_input = None
self.ref_output = None
self.check_correctness = check_correctness
self.cooldown_interval = cooldown_interval
self.executor = LocalExecutor()
def set_task(self, task):
self.task = task
self.cur_number = self.number
if check_remote(task.target, self.key, self.host, self.port):
logger.info("Get devices for measurement successfully!")
else:
raise RuntimeError("Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.")
def create_measure_batch(task, option): if self.check_correctness:
"""Get a standard measure_batch function. # use llvm cpu to generate a reference input/output
# this option works for tuning topi, but might not work for your custom op
with _target.create("llvm"):
s, arg_bufs = task.instantiate(task.config_space.get(0))
self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in self.ref_input]
func(*tvm_buf)
self.ref_output = [x.asnumpy() for x in tvm_buf]
def get_build_kwargs(self):
kwargs = {}
if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys:
remote = request_remote(self.key, self.host, self.port)
ctx = remote.context(str(self.task.target), 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
if 'cuda' in self.task.target.keys:
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
return kwargs
def run(self, measure_inputs, build_results):
results = []
remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
for i in range(0, len(measure_inputs), self.n_parallel):
futures = []
for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
build_results[i:i+self.n_parallel]):
ret = self.executor.submit(run_through_rpc,
measure_inp,
build_res,
self.cur_number,
self.repeat,
self.cooldown_interval,
remote_args,
self.ref_input,
self.ref_output)
futures.append(ret)
for future in futures:
res = future.get()
if isinstance(res, Exception): # executor error or timeout
results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
self.timeout, time.time()))
else:
results.append(res)
# If some runs were too fast, do remeasure for them
# to meet the requirement of `min_repeat_ms`
remeasure = np.zeros((len(measure_inputs),), dtype=np.bool)
pre_number = next_number = self.cur_number
min_repeat_duration = self.min_repeat_ms / 1000.0
for i, res in enumerate(results):
if res.error_no == MeasureErrorNo.NO_ERROR:
if np.mean(res.costs) * pre_number <= min_repeat_duration:
next_number = max(next_number,
int(np.ceil(min_repeat_duration / np.mean(res.costs))))
remeasure[i] = True
if pre_number != next_number:
self.cur_number = next_number
msg = "increasing number to %d" % self.cur_number
logger.info(msg)
re_measure_inputs = [x for i, x in enumerate(measure_inputs) if remeasure[i]]
re_build_results = [x for i, x in enumerate(build_results) if remeasure[i]]
re_res = self.run(re_measure_inputs, re_build_results)
ct = 0
for i, rerun in enumerate(remeasure):
if rerun:
results[i] = re_res[ct]
ct += 1
return results
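# A worked example of the adjustment above: with min_repeat_ms=100 and
# number=4, a config whose mean cost is 0.5 ms yields 4 * 0.0005 s = 0.002 s of
# timed work, below the 0.1 s floor, so `number` is raised to
# ceil(0.1 / 0.0005) = 200 and that config is measured again.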
class LocalRunner(RPCRunner):
"""Run generated code on local devices.
Parameters Parameters
---------- ----------
task: tvm.autotvm.task.Task timeout: float
The tuning task The timeout of a run task
option: dict number : int, optional
The option for measuring generated code. Number of times to do measurement for taking average
You should use the return value of function :any:`measure_option` for this argument. repeat : int, optional
Number of times to repeat the measurement.
Returns In total, the generated code will be run (1 + number x repeat) times,
------- where the first one is warm up. The returned result contains `repeat` costs,
measure_batch: callable each of which is the average of `number` test run.
a callback function to measure a batch of configs min_repeat_ms : float, optional
Minimum duration of a timer measurement in milliseconds.
When the run time of a measurement trial falls below this time, the
`number` parameter will be automatically increased.
Set this to improve the accuracy of perf measurement, e.g., when timers
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
cooldown_interval: float, optional
The cool down interval between two measurements.
check_correctness: bool, optional
Whether to check correctness after measurement. This will use llvm cpu target to
call your template and get the reference output.
This can work for TOPI templates, but may not work for your custom template.
Note
----
This is a "fake" local mode. We start a silent rpc tracker and rpc server
for the user. In this way we reuse the timeout/isolation mechanism of the RPC infrastructure.
""" """
from ..database import filter_inputs def __init__(self,
timeout=10,
measure_func = option['measure_func'] number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
number, repeat = option['number'], option['repeat'] check_correctness=False):
timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork'] super(LocalRunner, self).__init__('', None, None, 0,
build_func = option['build_func'] timeout=timeout, n_parallel=1,
check_correctness = option['check_correctness'] number=number, repeat=repeat,
replay_db = option['replay_db'] min_repeat_ms=min_repeat_ms,
cooldown_interval=cooldown_interval,
check_correctness=check_correctness)
self.tracker = None
self.server = None
def set_task(self, task):
self.task = task
executor = LocalExecutor(timeout=timeout, do_fork=do_fork)
# convert convenient string to function object
attach_objects = None
if measure_func == 'local':
# start temporary rpc tracker and rpc server for the user
from ...rpc.tracker import Tracker from ...rpc.tracker import Tracker
from ...rpc.server import Server from ...rpc.server import Server
...@@ -133,360 +343,215 @@ def create_measure_batch(task, option): ...@@ -133,360 +343,215 @@ def create_measure_batch(task, option):
key=device_key, key=device_key,
use_popen=True, silent=True, use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port)) tracker_addr=(tracker.host, tracker.port))
self.key = device_key
self.host = tracker.host
self.port = tracker.port
measure_func = rpc(device_key, tracker.host, tracker.port) super(LocalRunner, self).set_task(task)
attach_objects = (server, tracker) return server, tracker
build_kwargs = {}
if build_func == 'default':
build_func = default_build_func
if build_func == 'ndk':
build_func = default_build_func
build_kwargs['use_ndk'] = True
# check the availability of remote devices def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
if hasattr(measure_func, 'rpc_info'): """Common part for building a configuration"""
rpc_info = measure_func.rpc_info target, task, config = measure_input
if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
logger.info("Get devices for measurement successfully!")
else:
raise RuntimeError("Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.")
# add device info of cuda and opencl target with target:
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \ s, args = task.instantiate(config)
and hasattr(measure_func, 'rpc_info'):
rpc_info = measure_func.rpc_info # check invalidity of template and code hash consistency
add_gpu_target_info(task.target, rpc_info["key"], (rpc_info["host"], rpc_info["port"]), if not config.valid():
build_kwargs) raise InstantiationError(config.errors)
if check_correctness: opts = build_option or {}
# use llvm cpu to generate a reference input/output if check_gpu: # Add verify pass to filter out invalid configs in advance.
# this option works for tuning topi, but might not work for you custom op opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
with _target.create("llvm"): if cuda_arch:
s, arg_bufs = task.instantiate(task.config_space.get(0)) set_cuda_target_arch(cuda_arch)
ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf]
else:
ref_input = ref_output = None
def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines"""
if replay_db is not None:
partial_results, measure_inputs = \
filter_inputs(replay_db, measure_inputs, retry=False)
# launch measure jobs in parallel
pack_size = getattr(measure_func, "pack_size", 1) # measure `pack_size` inputs in one job
futures = []
for i in range(0, len(measure_inputs), pack_size):
input_pack = measure_inputs[i:i + pack_size]
ret = executor.submit(
measure_func,
input_pack,
build_func,
build_kwargs,
number,
repeat,
ref_input,
ref_output)
futures.append(ret)
# transform results
results = []
for future in futures:
result = future.get()
if isinstance(result, Exception):
tstamp = time.time()
results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
timeout, tstamp)] * pack_size)
else:
results.extend(result)
if replay_db is not None:
result_idx = 0
for i in range(len(partial_results)):
if partial_results[i] is None:
partial_results[i] = results[result_idx]
result_idx += 1
return partial_results
return results
measure_batch.n_parallel = n_parallel with build_config(**opts):
# attach server and tracker object to avoid them of being garbage-collected func = build(s, args, target_host=task.target_host)
measure_batch.attach_objects = attach_objects return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
return measure_batch
def rpc(key, def default_build_func(measure_input, tmp_dir, **kwargs):
host=None,
port=None,
priority=1,
session_timeout=60,
pack_size=1):
""" """
Create a standard measure_func which uses RPC Tracker for measurement. Default build func. This can work for cuda, opencl, llvm backends
This measure_func will request a device from the RPC Tracker and
upload the built binary library to that device for measurement.
Parameters Parameters
---------- ----------
key: str measure_input: MeasureInput
The registered key of the device in tracker. The tuner will request devices for The input of measurement
measurement by this key. tmp_dir: str
host: str, optional The path of temporary directory to export generated library
The hostname of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_HOST" """
port: int, optional tic = time.time()
The port of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_PORT" try:
priority: int, optional filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
Priority of this task, used by scheduler in tracker func, arg_info = _build_func_common(measure_input, **kwargs)
session_timeout: int, optional func.export_library(filename)
Timeout of rpc session except Exception as e: # pylint: disable=broad-except
pack_size: int, optional return BuildResult(None, None, e, time.time() - tic)
The number of configs measure in one RPC session. return BuildResult(filename, arg_info, None, time.time() - tic)
Usually this can be set to 1. If your device has high overhead to establish a
rpc connection, set this higher.
def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
"""
Build function for Android devices using the NDK.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
""" """
def fmeasure(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output): tic = time.time()
"""Do measurement for a list of inputs inside a same RPC session. try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
Parameters func, arg_info = _build_func_common(measure_input, **kwargs)
---------- func.export_library(filename, ndk.create_shared)
input_pack: List of MeasureInput except Exception as e: # pylint: disable=broad-except
The inputs of measurement return BuildResult(None, None, e, time.time() - tic)
build_func: callable return BuildResult(filename, arg_info, None, time.time() - tic)
Function for building the code. see :any:`default_build_func` for example
build_kwargs: dict
Extra arguments for build_func def run_through_rpc(measure_input, build_result,
number : int, optional number, repeat, cooldown_interval,
Number of times to do the measurement for average remote_args, ref_input=None, ref_output=None):
repeat : int, optional """Run a generated library through rpc
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run.
ref_input: List of numpy array
Reference input for correctness check
ref_output: List of numpy array
Reference output for correctness check
Returns
-------
results: List of MeasureResult
The results for input_pack
"""
remote_args = (key, (host, port), priority, session_timeout)
res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input, ref_output,
remote_args)
return res
fmeasure.pack_size = pack_size
fmeasure.rpc_info = {"key": key, "host": host, "port": port}
return fmeasure
def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input=None, ref_output=None, remote_args=None):
"""Measure the time cost for a pack of inputs.
(Note: A pack is a list of inputs which will be measured inside a same RPC session)
Parameters Parameters
---------- ----------
input_pack : list of MeasureInput measure_input: MeasureInput
The inputs we need to evaluate The raw measure input
build_func : function takes MeasureInput returns tuple of (time_func, ctx, args) build_result: BuildResult
The build function used to build each input. The result returned from Builder. This contains the path to the generated library.
build_kwargs: Dict
The extra keyword arguments to build_func
number : int, optional number : int, optional
Number of times to do the measurement for average Number of times to do measurement for taking average
repeat : int, optional repeat : int, optional
Number of times to repeat the measurement. Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times, In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs, where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run. each of which is the average of `number` test run.
ref_input: Array of np.ndarray, optional cooldown_interval: float
Reference input for checking correctness The cool down interval between two measurements
ref_output: Array of np.ndarray, optional remote_args: Tuple
Reference output for checking correctness The argument for request_remote
remote_args: Tuple, optional ref_input: List of np.ndarray
The arguments to request_remote. If is not None, will use remote rpc devices. The reference input used for checking correctness
ref_output: List of np.ndarray
Returns The reference output used for checking correctness
-------
res_pack : Array of MeasureResult
The list of results of measurement.
""" """
res_pack = [] if isinstance(build_result, MeasureResult):
tmp_dir = util.tempdir() if remote_args else None return build_result
assert len(input_pack) == 1, "Only supports input_pack == 1 for now"
tic = time.time()
for inp in input_pack: errno = MeasureErrorNo.NO_ERROR
tic = time.time() try:
# upload built module
# build function remote = request_remote(*remote_args)
try: remote.upload(build_result.filename)
func, arg_bufs, filename = build_func(inp, tmp_dir, **build_kwargs) func = remote.load_module(os.path.split(build_result.filename)[1])
except TVMError as exc: ctx = remote.context(str(measure_input.target), 0)
tstamp = time.time() time_f = func.time_evaluator(
msg = str(exc) func.entry_name, ctx, number=number, repeat=repeat)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")] # set input
if "InstantiationError" in msg: if ref_input:
try: args = [nd.array(x, ctx=ctx) for x in ref_input]
msg = msg.split('\n')[-2].split(": ")[1] else:
except Exception: # pylint: disable=broad-except args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
pass
res_pack.append(MeasureResult((InstantiationError(msg),), costs = time_f(*args).results
MeasureErrorNo.INSTANTIATION_ERROR, if len(costs) > 2: # remove largest and smallest value to reduce variance
tstamp - tic, tstamp)) costs = list(costs)
else: costs.sort()
res_pack.append(MeasureResult((RuntimeError(msg),), costs = tuple(costs[1:-1])
MeasureErrorNo.COMPILE_HOST,
tstamp - tic, tstamp)) # check correctness of output
continue if ref_output:
except InstantiationError as e: for expected, real in zip(ref_output, args):
tstamp = time.time() if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
res_pack.append(MeasureResult((InstantiationError(str(e)),), logger.warning("Wrong Answer!")
MeasureErrorNo.INSTANTIATION_ERROR, errno = MeasureErrorNo.WRONG_ANSWER
tstamp - tic, tstamp)) except TVMError as exc:
continue msg = str(exc)
if "Stack trace returned" in msg:
# measure time msg = msg[:msg.index("Stack trace returned")]
errno = MeasureErrorNo.NO_ERROR if "CUDA Source" in msg:
try: msg = msg[:msg.index("CUDA Source")]
# upload built module costs = (RuntimeError(msg[:1024]),)
if remote_args: errno = MeasureErrorNo.RUNTIME_DEVICE
remote = request_remote(*remote_args) tstamp = time.time()
remote.upload(tmp_dir.relpath(filename)) time.sleep(cooldown_interval)
func = remote.load_module(filename) return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
ctx = remote.context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat) def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
else: """Request a remote session
ctx = context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
# set input
if ref_input:
args = [nd.array(x, ctx=ctx) for x in ref_input]
else:
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
for x in arg_bufs]
costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
costs.sort()
costs = tuple(costs[1:-1])
# check correctness of output
if ref_output:
for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
logger.warning("Wrong Answer!")
errno = MeasureErrorNo.WRONG_ANSWER
except TVMError as exc:
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
if "CUDA Source" in msg:
msg = msg[:msg.index("CUDA Source")]
costs = (RuntimeError(msg),)
errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time()
res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
return res_pack
def default_build_func(inp, tmp_dir=None, **kwargs):
"""Build function module. Exception will be raised when any error occurs
Parameters Parameters
---------- ----------
inp: MeasureInput device_key: string
The input of this measurement The device key of registered device in tracker
tmp_dir: tvm.contrib.util.TempDirectory, optional host: str, optional
The temporary directory for exporting built binary library. The host address of rpc tracker.
If is not None (in RPC mode), the library in this directory will be uploaded to If is none, will use environment variable "TVM_TRACKER_HOST"
remote devices. port: int, optional
kwargs: Dict, optional The port of rpc tracker.
Other extra arguments If is none, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this session (units: second)
Returns Returns
------- ------
func: Function session: RPCSession
TVM built function. Typically this is the return value of tvm.build.
args: Array of Buffer or Tensor
The argument list for the function. Typically this is the second argument of tvm.build.
filename: str
The filename of the output build library
""" """
# build function # connect to the tracker
with inp.target: host = host or os.environ['TVM_TRACKER_HOST']
s, args = inp.task.instantiate(inp.config) port = port or int(os.environ['TVM_TRACKER_PORT'])
# check invalidity of template and code hash consistency tracker = _rpc.connect_tracker(host, port)
if not inp.config.valid(): remote = tracker.request(device_key, priority=priority,
raise InstantiationError(inp.config.errors) session_timeout=timeout)
code_hash = getattr(s, 'code_hash', None) return remote
if inp.config.code_hash != code_hash:
raise HashMismatchError('got {0:s}, expected {1:s}'
.format(str(inp.config.code_hash), str(code_hash)))
opts = {}
if "check_gpu" in kwargs: # Add verify pass to filter out invalid configs in advance.
opts["add_lower_pass"] = [(2, gpu_verify_pass(**kwargs['check_gpu']))]
if 'cuda_arch' in kwargs:
set_cuda_target_arch(kwargs['cuda_arch'])
with build_config(**opts):
func = build(s, args, target_host=inp.task.target_host)
# export library to temp directory def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10):
if tmp_dir: """
if kwargs.get('use_ndk', False): # for Android NDK Check the availability of a remote device
filename = "tmp_func_%0x.so" % getrandbits(64)
func.export_library(tmp_dir.relpath(filename), ndk.create_shared)
else:
filename = "tmp_func_%0x.tar" % getrandbits(64)
func.export_library(tmp_dir.relpath(filename))
else:
filename = None
return func, args, filename
def add_gpu_target_info(target, device_key, rpc_tracker_addr, kwargs):
"""Add device info for gpu target.
The info will be used to check the validity of generated code."""
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context(str(target), 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
if 'cuda' in target.keys:
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
def set_cuda_target_arch(arch): Parameters
"""set target architecture of nvcc compiler""" ----------
AutotvmGlobalScope.current.cuda_target_arch = arch target: Target
The wanted compilation target
device_key: string
device key of registered device in tracker
host: str, optional
The host address of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_HOST"
port: int, optional
The port of the rpc tracker.
If is none, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this check (units: seconds).
Returns
-------
available: bool
True if can find available device
"""
def _check():
remote = request_remote(device_key, host, port, priority)
remote.context(str(target))
t = threading.Thread(target=_check,)
t.start()
t.join(timeout)
return not t.is_alive()
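# A usage sketch for request_remote, assuming TVM_TRACKER_HOST and
# TVM_TRACKER_PORT are set in the environment; the device key 'rk3399' is only
# a placeholder, and the helper name is for illustration.
def _example_request_session():
    """Illustrative only: grab a remote session and one of its contexts."""
    remote = request_remote('rk3399', priority=1, timeout=60)
    return remote.context('opencl', 0)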
@register_func @register_func
...@@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code): ...@@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code):
return ptx return ptx
def set_cuda_target_arch(arch):
"""set target architecture of nvcc compiler
Parameters
----------
arch: str
The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
"""
AutotvmGlobalScope.current.cuda_target_arch = arch
def gpu_verify_pass(**kwargs): def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel. """Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block. This pass will check memory usage and number of threads per block.
......
...@@ -22,7 +22,7 @@ class GATuner(Tuner): ...@@ -22,7 +22,7 @@ class GATuner(Tuner):
mutation_prob: float mutation_prob: float
probability of mutation of a knob in a gene probability of mutation of a knob in a gene
""" """
def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1): def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
super(GATuner, self).__init__(task) super(GATuner, self).__init__(task)
# algorithm configurations # algorithm configurations
......
...@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): ...@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
new_scores = model.predict(new_points) new_scores = model.predict(new_points)
ac_prob = np.exp((new_scores - scores) / (t + 1e-2)) ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
ac_index = np.random.random(len(ac_prob)) < ac_prob ac_index = np.random.random(len(ac_prob)) < ac_prob
points[ac_index] = new_points[ac_index] points[ac_index] = new_points[ac_index]
......
...@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None): ...@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
target=target, target_host=target_host) target=target, target_host=target_host)
return task, target return task, target
def test_tuning():
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
def custom_measure(input_pack, build_func, build_args, number, repeat,
ref_input, ref_output):
from tvm.autotvm import MeasureResult
results = []
for inp in input_pack:
tic = time.time()
# do nothing
time.sleep(0.001)
results.append(MeasureResult([time.time() - tic], 0,
time.time() - tic, time.time()))
return results
measure_option = autotvm.measure_option(custom_measure)
logging.info("%s", task.config_space)
# new tuner and recorder
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_tuning_with_measure():
def check(target, target_host): def check(target, target_host):
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
if not ctx.exist: if not ctx.exist:
...@@ -141,12 +114,12 @@ def test_tuning_with_measure(): ...@@ -141,12 +114,12 @@ def test_tuning_with_measure():
task, target = get_sample_task(target, target_host) task, target = get_sample_task(target, target_host)
logging.info("%s", task.config_space) logging.info("%s", task.config_space)
measure_option = autotvm.measure_option('local', measure_option = autotvm.measure_option(
timeout=4, autotvm.LocalBuilder(),
number=2) autotvm.LocalRunner())
tuner = RandomTuner(task) tuner = RandomTuner(task)
tuner.tune(n_trial=10, measure_option=measure_option) tuner.tune(n_trial=20, measure_option=measure_option)
check("cuda", None) check("cuda", None)
check("opencl", None) check("opencl", None)
...@@ -155,6 +128,4 @@ if __name__ == "__main__": ...@@ -155,6 +128,4 @@ if __name__ == "__main__":
# only print log when invoked from main # only print log when invoked from main
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
test_task_tuner_without_measurement() test_tuning()
test_tuning_with_measure()
...@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype): ...@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
return s, [A, B, C] return s, [A, B, C]
@autotvm.template
def bad_matmul(N, L, M, dtype):
if 'bad_device' in tvm.target.current_target().keys:
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L-1), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
return s, [A, B, C]
return matmul(N, L, M, dtype)
def get_sample_task(n=128): def get_sample_task(n=128):
"""return a sample task for testing""" """return a sample task for testing"""
target = tvm.target.create("llvm") target = tvm.target.create("llvm")
......
"""Test database""" """Test database"""
import copy import copy
import logging import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError from tvm.autotvm.record import encode, MeasureResult
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records from test_autotvm_common import get_sample_records
def test_save_load(): def test_save_load():
logging.info("test basic db load/save ...") logging.info("test basic db load/save ...")
...@@ -35,66 +29,6 @@ def test_save_load(): ...@@ -35,66 +29,6 @@ def test_save_load():
TRIAL_LIMIT = 2 TRIAL_LIMIT = 2
def test_db_filter():
logging.info("test db filter ...")
# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
batch_size = 2
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
del measure_batch
db = database.DummyDatabase()
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])
for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"
del measure_batch
def test_db_hash(): def test_db_hash():
logging.info("test db hash check ...") logging.info("test db hash check ...")
inp1, res1 = get_sample_records(1)[0] inp1, res1 = get_sample_records(1)[0]
...@@ -149,89 +83,8 @@ def test_db_latest_all(): ...@@ -149,89 +83,8 @@ def test_db_latest_all():
assert encode(inp1, load4[1]) == encode(inp1, res2) assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3) assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local',
do_fork=False,
timeout=2,
replay_db=_db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
callback = autotvm.callback.log_to_database(_db)
callback(None, all_inputs, all_results)
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)
for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
del measure_batch
def test_check_hashmismatch():
logging.info("test hash mismatch check")
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local', do_fork=False)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))
try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass
del measure_batch
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_save_load() test_save_load()
test_db_filter()
test_db_hash() test_db_hash()
test_db_latest_all() test_db_latest_all()
test_db_save_replay()
test_check_hashmismatch()
"""Test builder and runner"""
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)
def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]
def get_build_kwargs(self):
return {}
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=DummyRunner()
)
logging.info("%s", task.config_space)
for tuner_class in [autotvm.tuner.RandomTuner,
autotvm.tuner.GridSearchTuner,
autotvm.tuner.GATuner,
autotvm.tuner.XGBTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_check_correctness():
    task, target = get_sample_task()

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(check_correctness=True)
    )

    def _callback_correct(tuner, measure_inputs, measure_results):
        for inp, res in zip(measure_inputs, measure_results):
            assert res.error_no == 0

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=2, measure_option=measure_option,
               callbacks=[_callback_correct])

    # a bad template
    n = 128
    target = tvm.target.create("llvm -device=bad_device")
    task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)

    def _callback_wrong(tuner, measure_inputs, measure_results):
        for inp, res in zip(measure_inputs, measure_results):
            assert res.error_no == MeasureErrorNo.WRONG_ANSWER

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=2, measure_option=measure_option,
               callbacks=[_callback_wrong])
def test_min_repeat_ms():
    task, target = get_sample_task()

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
    )

    def _callback(tuner, measure_inputs, measure_results):
        for inp, res in zip(measure_inputs, measure_results):
            if res.error_no != 0:
                continue

            assert 1000 * np.mean(res.costs) * \
                   measure_option['runner'].cur_number >= 100

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=5, measure_option=measure_option,
               callbacks=[_callback])
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    test_task_tuner_without_measurement()
    test_check_correctness()
    test_min_repeat_ms()
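The assertion inside `_callback` in `test_min_repeat_ms` above encodes the contract of `min_repeat_ms`: starting from `number=1`, the runner is expected to grow its effective repeat count (exposed as `cur_number`) until one measurement lasts at least 100 ms of wall-clock time. A rough numeric sketch of that invariant, with a made-up per-run cost:

# Illustrative numbers only: if a single run of the kernel costs about 2.5 ms,
# the runner should have grown cur_number to at least 40, since
# 1000 * 0.0025 * 40 = 100 ms, which meets min_repeat_ms=100.
cost_in_seconds = 0.0025   # hypothetical mean of res.costs
cur_number = 40            # hypothetical runner state after auto-adjustment
assert 1000 * cost_in_seconds * cur_number >= 100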
@@ -137,12 +137,15 @@ if __name__ == '__main__':
     print(task.config_space)
     measure_option = autotvm.measure_option(
-        measure_func='local', number=10, n_parallel=8, timeout=20)
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+    )
 
     log_name = 'gemm_int8.log'
     if DO_TUNING:
         tuner = autotvm.tuner.XGBTuner(task)
         tuner.tune(n_trial=1000, measure_option=measure_option,
                    callbacks=[autotvm.callback.log_to_file(log_name)])
 
         dispatch_context = autotvm.apply_history_best(log_name)
         best_config = dispatch_context.query(task.target, task.workload)
...
@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
                            target='cuda')
 print(task.config_space)
 
-# use local gpu, measure 5 times for every config to reduce variance
-# run 8 parallel threads for compilation
-measure_option = autotvm.measure_option('local',
-                                        number=5,
-                                        n_parallel=8,
-                                        timeout=20)
+# use local gpu, measure 10 times for every config to reduce variance
+# The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds
+measure_option = autotvm.measure_option(
+    builder=autotvm.LocalBuilder(),
+    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+)
 
 # begin tuning, log records to file `conv2d.log`
 tuner = autotvm.tuner.XGBTuner(task)
...
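The updated comment above distinguishes a compile timeout from a run timeout. As a rough sketch, both limits can be set explicitly; this assumes `LocalBuilder` accepts a `timeout` argument and that the 10-second value quoted in the comment is its default, which is an assumption rather than something stated in this hunk:

# Sketch only: make both timeouts explicit instead of relying on defaults.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),   # per-config compile timeout (seconds)
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4))  # per-config run timeout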
@@ -65,15 +65,20 @@ def get_network(name, batch_size):
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)
 
-    if name =='resnet-18':
-        net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
-    elif name =='mobilenet':
+    if "resnet" in name:
+        n_layer = int(name.split('-')[1])
+        net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
+    elif "vgg" in name:
+        n_layer = int(name.split('-')[1])
+        net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
+    elif name == 'mobilenet':
         net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
-    elif name =='squeezenet v1.1':
+    elif name == 'squeezenet_v1.1':
         net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
-    elif name =='vgg-16':
-        net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
-    elif name =='custom':
+    elif name == 'inception_v3':
+        input_shape = (1, 3, 299, 299)
+        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
+    elif name == 'custom':
         # an example for custom network
         from nnvm.testing import utils
         net = nnvm.sym.Variable('data')
@@ -92,6 +97,7 @@ def get_network(name, batch_size):
     return net, params, input_shape, output_shape
 
+
 #################################################################
 # Start RPC Tracker
 # -----------------
@@ -158,6 +164,8 @@ def get_network(name, batch_size):
 # rk3399 2 2 0
 # rpi3b 11 11 0
 # ----------------------------------
+#
+# You can register multiple devices to the tracker to accelerate the measurement in tuning.
 
 ###########################################
 # Set Tuning Options
@@ -184,34 +192,30 @@ log_file = "%s.%s.log" % (device_key, network)
 dtype = 'float32'
 
 tuning_option = {
     'log_filename': log_file,
 
     'tuner': 'xgb',
     'n_trial': 1000,
-    'early_stopping': 250,
+    'early_stopping': 400,
 
     'measure_option': autotvm.measure_option(
-        autotvm.measure.rpc(device_key, host='localhost', port=9190),
-        number=4,
-        n_parallel=1,
-        timeout=10,
-        build_func='ndk' if use_android else 'default',
-    ),
+        builder=autotvm.LocalBuilder(
+            build_func='ndk' if use_android else 'default'),
+        runner=autotvm.RPCRunner(
+            device_key, host='localhost', port=9190,
+            number=5,
+            timeout=4,
+        ),
+    ),
 }
 
 ####################################################################
 #
 # .. note:: How to set tuning options
 #
-#   In general, the default value provided here works well. It is the same
-#   value that we used to generate pre-tuned parameters.
-#   If you have multiple devices, you can set :code:`n_parallel` to
-#   the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
-#   boards to the tracker).
+#   In general, the default value provided here works well.
 #   If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
 #   which makes the tuning run longer.
-#   If your device is very slow or a single conv2d operator in your network has large FLOPs,
-#   consider setting timeout larger.
 #
 ###################################################################
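The note above says a larger time budget is spent through `n_trial` and `early_stopping`. A minimal sketch of such a variant of `tuning_option` under the new builder/runner interface; the concrete numbers (2000 trials, 800-trial early stopping, 10-second run timeout) are illustrative and not values from the tutorial:

# Hypothetical larger-budget variant of tuning_option; the numbers are illustrative only.
tuning_option_large = {
    'log_filename': log_file,
    'tuner': 'xgb',
    'n_trial': 2000,           # search more configurations per task
    'early_stopping': 800,     # stop a task after 800 trials without improvement
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(
            build_func='ndk' if use_android else 'default'),
        runner=autotvm.RPCRunner(
            device_key, host='localhost', port=9190,
            number=5,
            timeout=10,        # allow slower devices or heavier conv2d workloads
        ),
    ),
}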
@@ -219,7 +223,7 @@ tuning_option = {
 # ------------
 # Now we can extract tuning tasks from the network and begin tuning.
 # Here we provide a simple utility function to tune a list of tasks.
-# This function is just an initial implementation which tune them in sequential order.
+# This function is just an initial implementation which tunes them in sequential order.
 # Later we will bring more sophisticated tuner scheduler.
 # You can skip the implementation of this function for this tutorial.
@@ -236,7 +240,9 @@ def tune_tasks(tasks,
             try:  # try winograd template
                 tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                           tasks[i].target, tasks[i].target_host, 'winograd')
-                tasks.append(tsk)
+                input_channel = tsk.workload[1][1]
+                if input_channel >= 64:
+                    tasks[i] = tsk
             except Exception:
                 pass
@@ -245,8 +251,8 @@ def tune_tasks(tasks,
     if os.path.exists(tmp_log_file):
         os.remove(tmp_log_file)
 
-    for i, tsk in enumerate(tasks):
-        prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
 
         # create tuner
         if tuner == 'xgb' or tuner == 'xgb-rank':
@@ -280,7 +286,7 @@ def tune_tasks(tasks,
 ########################################################################
 # Finally we launch tuning jobs and evaluate the end-to-end performance.
-def tune_and_evaluate():
+def tune_and_evaluate(tuning_opt):
     # extract workloads from nnvm graph
     print("Extract tasks...")
     net, params, input_shape, out_shape = get_network(network, batch_size=1)
@@ -290,19 +296,18 @@ def tune_and_evaluate():
     # run tuning tasks
     print("Tuning...")
-    tune_tasks(tasks, **tuning_option)
+    tune_tasks(tasks, **tuning_opt)
 
     # compile kernels with history best records
     with autotvm.apply_history_best(log_file):
         print("Compile...")
         with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
             graph, lib, params = nnvm.compiler.build(
-                net, target=target,
-                shape={'data': input_shape}, params=params, dtype=dtype)
+                net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
 
         # export library
         tmp = tempdir()
-        if tuning_option['measure_option']['build_func'] == 'ndk':  # for android
+        if use_android:
             from tvm.contrib import ndk
             filename = "net.so"
             lib.export_library(tmp.relpath(filename), ndk.create_shared)
@@ -312,8 +317,7 @@ def tune_and_evaluate():
     # upload module to device
     print("Upload...")
-    remote = autotvm.measure.request_remote(device_key,
-                                            tracker_addr=('localhost', 9190),
-                                            timeout=10000)
+    remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
+                                            timeout=10000)
     remote.upload(tmp.relpath(filename))
     rlib = remote.load_module(filename)
@@ -328,47 +332,44 @@ def tune_and_evaluate():
     # evaluate
     print("Evaluate inference time cost...")
-    ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
+    ftimer = module.module.time_evaluator("run", ctx, number=8, repeat=3)
     prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
     print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
           (np.mean(prof_res), np.std(prof_res)))
 
 # We do not run the tuning in our webpage server since it takes too long.
 # Uncomment the following line to run by yourself.
-# tune_and_evaluate()
+
+# tune_and_evaluate(tuning_option)
 
 ######################################################################
 # Sample Output
 # -------------
-# The tuning needs to train xgboost models and use them for prediction.
+# The tuning needs to compile many programs and extract feature from them.
 # So a high performance CPU is recommended.
-# It takes about 2 hours on a 32T AMD Ryzen CPU.
-# One sample output is
+# One sample output is listed below.
+# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
 #
 # .. code-block:: bash
 #
 #    Extract tasks...
 #    Tuning...
-#    [Task  1/16] Current/Best: 18.85/ 19.67 GFLOPS | Progress: (353/1000) | 387.05 s Done.
-#    [Task  2/16] Current/Best: 16.10/ 23.50 GFLOPS | Progress: (444/1000) | 379.99 s Done.
-#    [Task  3/16] Current/Best:  5.49/ 13.96 GFLOPS | Progress: (610/1000) | 485.87 s Done.
-#    [Task  4/16] Current/Best: 10.07/ 20.48 GFLOPS | Progress: (430/1000) | 391.66 s Done.
-#    [Task  5/16] Current/Best: 11.50/ 15.50 GFLOPS | Progress: (374/1000) | 356.03 s Done.
-#    [Task  6/16] Current/Best: 10.76/ 23.77 GFLOPS | Progress: (526/1000) | 526.42 s Done.
-#    [Task  7/16] Current/Best: 12.71/ 22.03 GFLOPS | Progress: (341/1000) | 322.96 s Done.
-#    [Task  8/16] Current/Best:  8.60/ 17.91 GFLOPS | Progress: (272/1000) | 236.08 s Done.
-#    [Task  9/16] Current/Best: 15.37/ 23.62 GFLOPS | Progress: (275/1000) | 275.18 s Done.
-#    [Task 10/16] Current/Best:  6.62/ 23.01 GFLOPS | Progress: (330/1000) | 315.02 s Done.
-#    [Task 11/16] Current/Best:  1.85/ 21.39 GFLOPS | Progress: (281/1000) | 239.19 s Done.
-#    [Task 12/16] Current/Best: 15.41/ 24.02 GFLOPS | Progress: (258/1000) | 270.82 s Done.
-#    [Task 13/16] Current/Best: 17.96/ 25.79 GFLOPS | Progress: (380/1000) | 738.29 s Done.
-#    [Task 14/16] Current/Best: 14.81/ 31.17 GFLOPS | Progress: (413/1000) | 799.21 s Done.
-#    [Task 15/16] Current/Best: 24.39/ 40.97 GFLOPS | Progress: (355/1000) | 700.25 s Done.
-#    [Task 16/16] Current/Best:  9.42/ 49.90 GFLOPS | Progress: (348/1000) | 603.84 s Done.
+#    [Task  1/12] Current/Best: 22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
+#    [Task  2/12] Current/Best:  6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
+#    [Task  3/12] Current/Best:  4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
+#    [Task  4/12] Current/Best: 11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
+#    [Task  5/12] Current/Best:  1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
+#    [Task  6/12] Current/Best:  2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
+#    [Task  7/12] Current/Best: 14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
+#    [Task  8/12] Current/Best:  1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
+#    [Task  9/12] Current/Best: 14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
+#    [Task 10/12] Current/Best:  3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
+#    [Task 11/12] Current/Best: 11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
+#    [Task 12/12] Current/Best:  8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
 #    Compile...
 #    Upload...
 #    Evaluate inference time cost...
-#    Mean inference time (std dev): 157.29 ms (1.74 ms)
+#    Mean inference time (std dev): 162.59 ms (0.06 ms)
 
 ######################################################################
 #
...
@@ -271,9 +271,12 @@ print(task.config_space)
 logging.getLogger('autotvm').setLevel(logging.DEBUG)
 logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
 
-# use local cpu, measure 5 times for every config to reduce variance
-measure_option = autotvm.measure_option('local',
-                                        number=5)
+# There are two steps for measuring a config: build and run.
+# By default, we use all cpu cores to compile program. Then measure them sequentially.
+# We measure 5 times and take average to reduce variance.
+measure_option = autotvm.measure_option(
+    builder='local',
+    runner=autotvm.LocalRunner(number=5))
 
 # begin tuning, log records to file `matmul.log`
 tuner = autotvm.tuner.RandomTuner(task)
...
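The new comment describes two independent steps, build and run, that can be configured separately. A rough sketch of doing so explicitly; the `n_parallel` and `timeout` arguments on `LocalBuilder` are assumptions about its options rather than values taken from this tutorial:

# Sketch only: configure the build step and the run step independently.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(n_parallel=4, timeout=10),  # compile 4 configs at a time
    runner=autotvm.LocalRunner(number=5, timeout=4))         # then run each config 5 times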