Commit 12839e6d authored by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Decouple build and run in measurement (#1661)

parent 38203a86
@@ -16,6 +16,11 @@ tvm.autotvm.measure
.. autofunction:: tvm.autotvm.measure.create_measure_batch

.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner

tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~
...
@@ -22,7 +22,8 @@ from . import env
from . import tophub

# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
    LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
    register_topi_compute, register_topi_schedule, \
...
"""Distributed executor infrastructure to scale up the tuning""" """Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc create_measure_batch
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .executor import Executor
from .local_executor import LocalExecutor from .local_executor import LocalExecutor
from .executor import Future, Executor
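These shortcuts cover the whole decoupled flow. A minimal sketch of how they fit together, assuming a `task` object created with `autotvm.task.create` elsewhere:

from tvm import autotvm

# Builder compiles configs, Runner times them; both are configured independently.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),
    runner=autotvm.LocalRunner(number=4, repeat=3))

tuner = autotvm.tuner.RandomTuner(task)   # `task` is assumed to exist
tuner.tune(n_trial=10, measure_option=measure_option)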
@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
        res = exc
    queue.put(res)


def call_with_timeout(queue, timeout, func, args, kwargs):
    """A wrapper to support timeout of a function call"""

    # start a new process for timeout (cannot use thread because we have c function)
@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
    p.start()
    p.join(timeout=timeout)

    queue.put(executor.TimeoutError())

    kill_child_processes(p.pid)
    p.terminate()
    p.join()


class LocalFuture(executor.Future):
    """Local wrapper for the future

@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
            return LocalFutureNoFork(func(*args, **kwargs))

        queue = Queue(2)
        process = Process(target=call_with_timeout,
                          args=(queue, self.timeout, func, args, kwargs))
        process.start()
        return LocalFuture(process, queue)
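The wrapper delivers a result through the queue whether the child finishes or hangs: a TimeoutError is enqueued after the join, and a completed call's result simply sits ahead of it. A standalone sketch of the same subprocess-plus-queue pattern, with hypothetical names and no autotvm dependencies:

from multiprocessing import Process, Queue

def _worker(queue, func, args):
    queue.put(func(*args))                  # the child posts its result when done

def run_with_timeout(func, args, timeout):
    queue = Queue(2)                        # room for the result and a sentinel
    proc = Process(target=_worker, args=(queue, func, args))
    proc.start()
    proc.join(timeout=timeout)              # wait at most `timeout` seconds
    proc.terminate()                        # kill the child if it is still running
    proc.join()
    return queue.get() if not queue.empty() else TimeoutError("timed out")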
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import multiprocessing
from collections import namedtuple


class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
@@ -16,6 +17,7 @@ class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
        Specific configuration.
    """


class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
    """
    Stores all the results of a measurement
@@ -23,8 +25,8 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"

    Parameters
    ----------
    costs: Array of float or Array of Exception
        If no error occurs during measurement, it is an array of measured running times.
        If an error occurs during measurement, it is an array of the exception objects.
    error_no: int
        Denotes the error type, defined by MeasureErrorNo
    all_cost: float
@@ -37,92 +39,185 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"


class MeasureErrorNo(object):
    """Error type for MeasureResult"""
    NO_ERROR = 0              # no error
    INSTANTIATION_ERROR = 1   # actively detected error in instantiating a template with a config
    COMPILE_HOST = 2          # error when compiling code on host (e.g. tvm.build)
    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. OpenCL JIT on the device)
    RUNTIME_DEVICE = 4        # error when running the program on the device
    WRONG_ANSWER = 5          # answer is wrong when compared to a golden output
    BUILD_TIMEOUT = 6         # timeout during compilation
    RUN_TIMEOUT = 7           # timeout during run
    UNKNOWN_ERROR = 8         # unknown error
class Builder(object):
    """Builder that builds programs in tuning

    Parameters
    ----------
    timeout: float, optional
        The timeout of a build task
    n_parallel: int, optional
        The number of tasks submitted in parallel.
        By default it will use all cpu cores.
    """
    def __init__(self, timeout=10, n_parallel=None):
        self.timeout = timeout
        self.n_parallel = n_parallel or multiprocessing.cpu_count()
        self.build_kwargs = {}
        self.task = None

    def set_task(self, task, build_kwargs=None):
        """
        Initialize for a new tuning task

        Parameters
        ----------
        task: Task
            The tuning task
        build_kwargs: dict, optional
            The additional kwargs for the build function
        """
        self.task = task
        self.build_kwargs = build_kwargs

    def build(self, measure_inputs):
        """Build programs

        Parameters
        ----------
        measure_inputs: List of MeasureInput
            The measure input

        Returns
        -------
        build_results: List of BuildResult
            The build result.
        """
        raise NotImplementedError()


class Runner(object):
    """Runner that runs and measures the time cost of a generated program in tuning

    Parameters
    ----------
    timeout: float, optional
        The timeout of a run task
    n_parallel: int, optional
        The number of tasks submitted in parallel.
        By default it will use all cpu cores.
    """
    def __init__(self, timeout=5, n_parallel=None):
        self.timeout = timeout
        self.n_parallel = n_parallel or multiprocessing.cpu_count()
        self.task = None

    def set_task(self, task):
        """
        Initialize for a new tuning task

        Parameters
        ----------
        task: Task
            The tuning task
        """
        self.task = task

    def get_build_kwargs(self):
        """
        Get device-specific build arguments (e.g. maximum shared memory size)

        Returns
        -------
        kwargs: dict
            The additional keyword arguments
        """
        raise NotImplementedError()

    def run(self, measure_inputs, build_results):
        """Run and measure built programs

        Parameters
        ----------
        measure_inputs: List of MeasureInput
            The raw measure input
        build_results: List of BuildResult
            The build results

        Returns
        -------
        measure_results: List of MeasureResult
            The final results of measurement
        """
        raise NotImplementedError()


def measure_option(builder, runner):
    """
    Set options for measure. To measure a config, we will build it and run it.
    So we have to set options for these two steps.
    They have their own options on timeout, parallelism, etc.

    Parameters
    ----------
    builder: Builder
        Specify how to build programs
    runner: Runner
        Specify how to run programs
    """
    from .measure_methods import LocalBuilder, LocalRunner

    if isinstance(builder, str):
        if builder == 'local':
            builder = LocalBuilder()
        else:
            raise ValueError("Invalid builder: " + builder)

    if isinstance(runner, str):
        if runner == 'local':
            runner = LocalRunner()
        else:
            raise ValueError("Invalid runner: " + runner)

    opt = {
        'builder': builder,
        'runner': runner,
    }
    return opt


def create_measure_batch(task, option):
    """Get a standard measure_batch function.

    Parameters
    ----------
    task: tvm.autotvm.task.Task
        The tuning task
    option: dict
        The option for measuring generated code.
        You should use the return value of function :any:`measure_option` for this argument.

    Returns
    -------
    measure_batch: callable
        a callback function to measure a batch of configs
    """
    builder = option['builder']
    runner = option['runner']

    attach_objects = runner.set_task(task)

    # feed device-related information from the runner to the builder
    # (e.g. max shared memory size, for validity checking)
    build_kwargs = runner.get_build_kwargs()
    builder.set_task(task, build_kwargs)

    def measure_batch(measure_inputs):
        build_results = builder.build(measure_inputs)
        results = runner.run(measure_inputs, build_results)
        return results

    measure_batch.n_parallel = builder.n_parallel
    measure_batch.attach_objects = attach_objects
    return measure_batch
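Because create_measure_batch only relies on the Builder/Runner interfaces above, custom components plug in directly. A hedged sketch with a hypothetical EchoRunner that fabricates constant timings (the `task` object is assumed to exist):

import time
from tvm import autotvm
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo

class EchoRunner(Runner):
    """Pretends every config runs in 1 ms; no device is touched."""
    def get_build_kwargs(self):
        return {}

    def run(self, measure_inputs, build_results):
        return [MeasureResult((1e-3,), MeasureErrorNo.NO_ERROR, 0.0, time.time())
                for _ in measure_inputs]

option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=EchoRunner())
measure_batch = autotvm.measure.create_measure_batch(task, option)   # `task` assumed defined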
# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
"""
Functions that run on executor for measurement.

These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.
"""

import logging
import shutil
import os
import threading
import time
from random import getrandbits
from collections import namedtuple
import tempfile

import numpy as np

from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
    rpc as _rpc, target as _target
from ...contrib import nvcc, ndk

from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError

from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
from .local_executor import LocalExecutor

logger = logging.getLogger('autotvm')


class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))):
    """
    Stores all the necessary inputs for a measurement.

    Parameters
    ----------
    filename : str
        The filename of the generated library
    arg_info : Tuple
        The shape and dtype information of the tvm tensor arguments
    error : Exception
        The error that happened during compilation
    time_cost : float
        The time cost of building
    """


class LocalBuilder(Builder):
    """Run compilation on the local machine

    Parameters
    ----------
    timeout: float
        The timeout of a compilation
    n_parallel: int
        The number of tasks run in parallel. "None" will use all cpu cores
    build_func: callable or str
        If is 'default', use the default build function.
        If is 'ndk', use the function for android ndk.
        If is callable, use it as a custom build function.
    """
    def __init__(self, timeout=10, n_parallel=None, build_func='default'):
        super(LocalBuilder, self).__init__(timeout, n_parallel)

        if isinstance(build_func, str):
            if build_func == 'default':
                build_func = default_build_func
            elif build_func == 'ndk':
                build_func = android_ndk_build_func
            else:
                raise ValueError("Invalid build_func: " + build_func)

        self.build_func = build_func
        self.tmp_dir = tempfile.mkdtemp()
        self.executor = LocalExecutor(timeout=timeout)

    def build(self, measure_inputs):
        results = []

        for i in range(0, len(measure_inputs), self.n_parallel):
            futures = []
            for inp in measure_inputs[i:i + self.n_parallel]:
                ret = self.executor.submit(self.build_func,
                                           inp,
                                           self.tmp_dir,
                                           **self.build_kwargs)
                futures.append(ret)

            for future in futures:
                res = future.get()

                if isinstance(res, Exception):
                    # timeout or fleet error, return MeasureResult directly
                    results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
                                                 self.timeout, time.time()))
                elif res.error is not None:
                    # instantiation error
                    if isinstance(res.error, InstantiationError):
                        results.append(MeasureResult((res.error,),
                                                     MeasureErrorNo.INSTANTIATION_ERROR,
                                                     res.time_cost, time.time()))
                    else:
                        if "InstantiationError" in str(res.error):
                            msg = str(res.error)
                            try:
                                msg = msg.split('\n')[-2].split(": ")[1]
                            except Exception:  # pylint: disable=broad-except
                                pass
                            results.append(MeasureResult((InstantiationError(msg),),
                                                         MeasureErrorNo.INSTANTIATION_ERROR,
                                                         res.time_cost, time.time()))
                        else:  # tvm error
                            results.append(MeasureResult((res.error,),
                                                         MeasureErrorNo.COMPILE_HOST,
                                                         res.time_cost, time.time()))
                else:
                    # return BuildResult
                    results.append(res)

        return results

    def __del__(self):
        shutil.rmtree(self.tmp_dir)
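Since `build_func` may be any callable, backends not covered by the default or NDK paths can supply their own, following the same `(measure_input, tmp_dir, **kwargs) -> BuildResult` contract as `default_build_func` below. A hedged sketch; `my_cross_compiler` is a hypothetical toolchain callable:

import os
import time
from random import getrandbits
from tvm.autotvm.measure.measure_methods import BuildResult, _build_func_common

def my_cross_build_func(measure_input, tmp_dir, **kwargs):
    tic = time.time()
    try:
        filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
        func, arg_info = _build_func_common(measure_input, **kwargs)
        func.export_library(filename, fcompile=my_cross_compiler)   # hypothetical cross toolchain
    except Exception as e:  # pylint: disable=broad-except
        return BuildResult(None, None, e, time.time() - tic)
    return BuildResult(filename, arg_info, None, time.time() - tic)

# builder = LocalBuilder(build_func=my_cross_build_func)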
class RPCRunner(Runner):
"""Run generated code on remote devices.
This function will ask a RPC Tracker to get device for measurement.

Parameters
----------
timeout: float
The timeout of a run
n_parallel: int
The number of tasks run in parallel. "None" will use all cpu cores
key: str
The key of the device registered in the tracker
host: str
The host address of RPC Tracker
port: int
The port of RPC Tracker
number : int, optional
Number of times to do measurement for taking average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
min_repeat_ms : float, optional
Minimum duration of a timer measurement in milliseconds.
When the run time of a measurement trial falls below this time, the
`number` parameter will be automatically increased.
Set this to improve the accuracy of perf measurement, e.g., when timers
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
cooldown_interval: float, optional
The cool down interval between two measurements.
check_correctness: bool, optional
Whether to check correctness after measurement. This will use the llvm cpu target to
call your template and get the reference output.
This can work for TOPI templates, but may not work for your custom template.
"""
def __init__(self,
key, host, port, priority=1,
timeout=10, n_parallel=None,
number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
check_correctness=False):
super(RPCRunner, self).__init__(timeout, n_parallel)
self.key = key
self.host = host
self.port = port
self.priority = priority
self.timeout = timeout
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms
self.cur_number = number
self.ref_input = None
self.ref_output = None
self.check_correctness = check_correctness
self.cooldown_interval = cooldown_interval
self.executor = LocalExecutor()
def set_task(self, task):
self.task = task
self.cur_number = self.number
if check_remote(task.target, self.key, self.host, self.port):
logger.info("Get devices for measurement successfully!") logger.info("Get devices for measurement successfully!")
else: else:
raise RuntimeError("Cannot get remote devices from the tracker. " raise RuntimeError("Cannot get remote devices from the tracker. "
...@@ -155,225 +198,261 @@ def create_measure_batch(task, option): ...@@ -155,225 +198,261 @@ def create_measure_batch(task, option):
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.") "and make sure you have free devices on the queue status.")
# add device info of cuda and opencl target if self.check_correctness:
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
and hasattr(measure_func, 'rpc_info'):
rpc_info = measure_func.rpc_info
add_gpu_target_info(task.target, rpc_info["key"], (rpc_info["host"], rpc_info["port"]),
build_kwargs)
if check_correctness:
# use llvm cpu to generate a reference input/output
# this option works for tuning topi, but might not work for your custom op
with _target.create("llvm"): with _target.create("llvm"):
s, arg_bufs = task.instantiate(task.config_space.get(0)) s, arg_bufs = task.instantiate(task.config_space.get(0))
ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
for x in arg_bufs] for x in arg_bufs]
func = build(s, arg_bufs, "llvm") func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in ref_input] tvm_buf = [nd.array(x) for x in self.ref_input]
func(*tvm_buf) func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf] self.ref_output = [x.asnumpy() for x in tvm_buf]
else:
ref_input = ref_output = None def get_build_kwargs(self):
kwargs = {}
if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys:
remote = request_remote(self.key, self.host, self.port)
ctx = remote.context(str(self.task.target), 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
if 'cuda' in self.task.target.keys:
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
def measure_batch(measure_inputs): return kwargs
"""measure the time cost for a batch of configs in real machines"""
if replay_db is not None:
partial_results, measure_inputs = \
filter_inputs(replay_db, measure_inputs, retry=False)
# launch measure jobs in parallel def run(self, measure_inputs, build_results):
pack_size = getattr(measure_func, "pack_size", 1) # measure `pack_size` inputs in one job results = []
remote_args = (self.key, self.host, self.port, self.priority, self.timeout)
for i in range(0, len(measure_inputs), self.n_parallel):
futures = [] futures = []
for i in range(0, len(measure_inputs), pack_size): for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel],
input_pack = measure_inputs[i:i + pack_size] build_results[i:i+self.n_parallel]):
ret = executor.submit( ret = self.executor.submit(run_through_rpc,
measure_func, measure_inp,
input_pack, build_res,
build_func, self.cur_number,
build_kwargs, self.repeat,
number, self.cooldown_interval,
repeat, remote_args,
ref_input, self.ref_input,
ref_output) self.ref_output)
futures.append(ret) futures.append(ret)
# transform results
results = []
for future in futures: for future in futures:
result = future.get() res = future.get()
if isinstance(result, Exception): if isinstance(res, Exception): # executor error or timeout
tstamp = time.time() results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR, self.timeout, time.time()))
timeout, tstamp)] * pack_size)
else: else:
results.extend(result) results.append(res)
if replay_db is not None: # If some runs were too fast, do remeasure for them
result_idx = 0 # to meet the requirement of `min_repeat_ms`
for i in range(len(partial_results)): remeasure = np.zeros((len(measure_inputs),), dtype=np.bool)
if partial_results[i] is None: pre_number = next_number = self.cur_number
partial_results[i] = results[result_idx] min_repeat_duration = self.min_repeat_ms / 1000.0
result_idx += 1 for i, res in enumerate(results):
return partial_results if res.error_no == MeasureErrorNo.NO_ERROR:
return results if np.mean(res.costs) * pre_number <= min_repeat_duration:
next_number = max(next_number,
measure_batch.n_parallel = n_parallel int(np.ceil(min_repeat_duration / np.mean(res.costs))))
# attach server and tracker object to avoid them of being garbage-collected remeasure[i] = True
measure_batch.attach_objects = attach_objects
return measure_batch if pre_number != next_number:
self.cur_number = next_number
msg = "increasing number to %d" % self.cur_number
logger.info(msg)
re_measure_inputs = [x for i, x in enumerate(measure_inputs) if remeasure[i]]
re_build_results = [x for i, x in enumerate(build_results) if remeasure[i]]
re_res = self.run(re_measure_inputs, re_build_results)
ct = 0
for i, rerun in enumerate(remeasure):
if rerun:
results[i] = re_res[ct]
ct += 1
def rpc(key, return results
host=None,
port=None,
priority=1,
session_timeout=60,
pack_size=1):
"""
Create a standard measure_func which uses RPC Tracker for measurement.
This measure_func will request a device from the RPC Tracker and
upload the built binary library to that device for measurement.
Parameters class LocalRunner(RPCRunner):
---------- """Run generated code on local devices.
key: str
The registered key of the device in tracker. The tuner will request devices for
measurement by this key.
host: str, optional
The hostname of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_HOST"
port: int, optional
The port of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
Priority of this task, used by scheduler in tracker
session_timeout: int, optional
Timeout of rpc session
pack_size: int, optional
The number of configs measure in one RPC session.
Usually this can be set to 1. If your device has high overhead to establish a
rpc connection, set this higher.
"""
def fmeasure(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
"""Do measurement for a list of inputs inside a same RPC session.
Parameters Parameters
---------- ----------
input_pack: List of MeasureInput timeout: float
The inputs of measurement The timeout of a compilation
build_func: callable
Function for building the code. see :any:`default_build_func` for example
build_kwargs: dict
Extra arguments for build_func
number : int, optional
Number of times to do measurement for taking average
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test runs.
ref_input: List of numpy array min_repeat_ms : float, optional
Reference input for correctness check Minimum duration of a timer measurement in milliseconds.
ref_output: List of numpy array When the run time of a measurement trial falls below this time, the
Reference output for correctness check `number` parameter will be automatically increased.
Set this to improve the accuracy of perf measurement, e.g., when timers
are not precise enough to capture short-running tasks. This parameter is
also critical when devices need a certain minimum running time to "warm
up," such as GPUs that need time to reach a performance power state.
cooldown_interval: float, optional
The cool down interval between two measurements.
check_correctness: bool, optional
Whether to check correctness after measurement. This will use the llvm cpu target to
call your template and get the reference output.
This can work for TOPI templates, but may not work for your custom template.
Note
----
This is a "fake" local mode. We start a silent rpc tracker and rpc server
for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure.
"""
def __init__(self,
timeout=10,
number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
check_correctness=False):
super(LocalRunner, self).__init__('', None, None, 0,
timeout=timeout, n_parallel=1,
number=number, repeat=repeat,
min_repeat_ms=min_repeat_ms,
cooldown_interval=cooldown_interval,
check_correctness=check_correctness)
self.tracker = None
self.server = None
def set_task(self, task):
self.task = task
Returns from ...rpc.tracker import Tracker
------- from ...rpc.server import Server
results: List of MeasureResult
The results for input_pack tracker = Tracker('localhost', port=9000, port_end=10000, silent=True)
device_key = '$local$device$%d' % tracker.port
server = Server('localhost', port=9000, port_end=10000,
key=device_key,
use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port))
self.key = device_key
self.host = tracker.host
self.port = tracker.port
super(LocalRunner, self).set_task(task)
return server, tracker
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
with target:
s, args = task.instantiate(config)
# check invalidity of template and code hash consistency
if not config.valid():
raise InstantiationError(config.errors)
opts = build_option or {}
if check_gpu: # Add verify pass to filter out invalid configs in advance.
opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
if cuda_arch:
set_cuda_target_arch(cuda_arch)
with build_config(**opts):
func = build(s, args, target_host=task.target_host)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
def default_build_func(measure_input, tmp_dir, **kwargs):
""" """
remote_args = (key, (host, port), priority, session_timeout) Default build func. This can work for cuda, opencl, llvm backend
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input, ref_output,
remote_args)
return res
fmeasure.pack_size = pack_size def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
fmeasure.rpc_info = {"key": key, "host": host, "port": port} """
return fmeasure Build function for android device using ndk.
Parameters
----------
measure_input: MeasureInput
The input of measurement
tmp_dir: str
The path of temporary directory to export generated library
"""
tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, ndk.create_shared)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
ref_input=None, ref_output=None, remote_args=None):
"""Measure the time cost for a pack of inputs.
(Note: A pack is a list of inputs which will be measured inside a same RPC session) def run_through_rpc(measure_input, build_result,
number, repeat, cooldown_interval,
remote_args, ref_input=None, ref_output=None):
"""Run a generated library through rpc
Parameters Parameters
---------- ----------
input_pack : list of MeasureInput measure_input: MeasureInput
The inputs we need to evaluate The raw measure input
build_func : function takes MeasureInput returns tuple of (time_func, ctx, args) build_result: BuildResult
The build function used to build each input. The result returned from Builder. This contains the path to the generated library.
build_kwargs: Dict
The extra keyword arguments to build_func
number : int, optional number : int, optional
Number of times to do the measurement for average Number of times to do measurement for tasking average
repeat : int, optional repeat : int, optional
Number of times to repeat the measurement. Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times, In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs, where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test run. each of which is the average of `number` test run.
ref_input: Array of np.ndarray, optional cooldown_interval: float
Reference input for checking correctness The cool down interval between two measurements
ref_output: Array of np.ndarray, optional remote_args: Tuple
Reference output for checking correctness The argument for request_remote
remote_args: Tuple, optional ref_input: List of np.ndarray
The arguments to request_remote. If is not None, will use remote rpc devices. The reference input used for checking correctness
ref_output: List of np.ndarray
Returns The reference output used for checking correctness
-------
res_pack : Array of MeasureResult
The list of results of measurement.
""" """
res_pack = [] if isinstance(build_result, MeasureResult):
tmp_dir = util.tempdir() if remote_args else None return build_result
assert len(input_pack) == 1, "Only supports input_pack == 1 for now"
for inp in input_pack:
tic = time.time() tic = time.time()
# build function
try:
func, arg_bufs, filename = build_func(inp, tmp_dir, **build_kwargs)
except TVMError as exc:
tstamp = time.time()
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
if "InstantiationError" in msg:
try:
msg = msg.split('\n')[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
res_pack.append(MeasureResult((InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
else:
res_pack.append(MeasureResult((RuntimeError(msg),),
MeasureErrorNo.COMPILE_HOST,
tstamp - tic, tstamp))
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((InstantiationError(str(e)),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
# measure time
errno = MeasureErrorNo.NO_ERROR errno = MeasureErrorNo.NO_ERROR
try: try:
# upload built module # upload built module
if remote_args:
remote = request_remote(*remote_args) remote = request_remote(*remote_args)
remote.upload(tmp_dir.relpath(filename)) remote.upload(build_result.filename)
func = remote.load_module(filename) func = remote.load_module(os.path.split(build_result.filename)[1])
ctx = remote.context(str(inp.target), 0) ctx = remote.context(str(measure_input.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
else:
ctx = context(str(inp.target), 0)
time_f = func.time_evaluator( time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat) func.entry_name, ctx, number=number, repeat=repeat)
...@@ -381,8 +460,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, ...@@ -381,8 +460,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if ref_input: if ref_input:
args = [nd.array(x, ctx=ctx) for x in ref_input] args = [nd.array(x, ctx=ctx) for x in ref_input]
else: else:
args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx) args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
for x in arg_bufs]
costs = time_f(*args).results costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance if len(costs) > 2: # remove largest and smallest value to reduce variance
...@@ -402,91 +480,78 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, ...@@ -402,91 +480,78 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
msg = msg[:msg.index("Stack trace returned")] msg = msg[:msg.index("Stack trace returned")]
if "CUDA Source" in msg: if "CUDA Source" in msg:
msg = msg[:msg.index("CUDA Source")] msg = msg[:msg.index("CUDA Source")]
costs = (RuntimeError(msg),) costs = (RuntimeError(msg[:1024]),)
errno = MeasureErrorNo.RUNTIME_DEVICE errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time() tstamp = time.time()
res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp)) time.sleep(cooldown_interval)
return res_pack return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
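The comment about removing the largest and smallest value refers to a simple trim of the repeat costs; in isolation the step looks roughly like this:

costs = [3.1e-4, 2.9e-4, 9.8e-4]      # example repeat costs in seconds
if len(costs) > 2:                    # drop the single largest and smallest sample
    costs = sorted(costs)[1:-1]       # -> [3.1e-4]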
def default_build_func(inp, tmp_dir=None, **kwargs): def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
"""Build function module. Exception will be raised when any error occurs """Request a remote session
Parameters Parameters
---------- ----------
inp: MeasureInput device_key: string
The input of this measurement The device key of registered device in tracker
tmp_dir: tvm.contrib.util.TempDirectory, optional host: host, optional
The temporary directory for exporting built binary library. The host address of rpc tracker.
If is not None (in RPC mode), the library in this directory will be uploaded to If is none, will use environment variable "TVM_TRACKER_HOST"
remote devices. port: int, optional
kwargs: Dict, optional The port of rpc tracker.
Other extra arguments If is none, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this session (units: second)
Returns Returns
------- ------
func: Function session: RPCSession
TVM built function. Typically this is the return value of tvm.build.
args: Array of Buffer or Tensor
The argument list for the function. Typically this is the second argument of tvm.build.
filename: str
The filename of the output build library
""" """
# build function # connect to the tracker
with inp.target: host = host or os.environ['TVM_TRACKER_HOST']
s, args = inp.task.instantiate(inp.config) port = port or int(os.environ['TVM_TRACKER_PORT'])
# check invalidity of template and code hash consistency
if not inp.config.valid():
raise InstantiationError(inp.config.errors)
code_hash = getattr(s, 'code_hash', None)
if inp.config.code_hash != code_hash:
raise HashMismatchError('got {0:s}, expected {1:s}'
.format(str(inp.config.code_hash), str(code_hash)))
opts = {}
if "check_gpu" in kwargs: # Add verify pass to filter out invalid configs in advance.
opts["add_lower_pass"] = [(2, gpu_verify_pass(**kwargs['check_gpu']))]
if 'cuda_arch' in kwargs:
set_cuda_target_arch(kwargs['cuda_arch'])
with build_config(**opts):
func = build(s, args, target_host=inp.task.target_host)
# export library to temp directory
if tmp_dir:
if kwargs.get('use_ndk', False): # for Android NDK
filename = "tmp_func_%0x.so" % getrandbits(64)
func.export_library(tmp_dir.relpath(filename), ndk.create_shared)
else:
filename = "tmp_func_%0x.tar" % getrandbits(64)
func.export_library(tmp_dir.relpath(filename))
else:
filename = None
return func, args, filename tracker = _rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority,
session_timeout=timeout)
return remote
def add_gpu_target_info(target, device_key, rpc_tracker_addr, kwargs): def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10):
"""Add device info for gpu target. """
The info will be used to check the validity of generated code.""" Check the availability of a remote device
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context(str(target), 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
if 'cuda' in target.keys: Parameters
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) ----------
target: Target
The wanted compilation target
device_key: string
device key of registered device in tracker
host: host, optional
The host address of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_HOST"
port: int, optional
The port address of rpc tracker.
If is none, will use environment variable "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this check (units: seconds).
def set_cuda_target_arch(arch): Returns
"""set target architecture of nvcc compiler""" -------
AutotvmGlobalScope.current.cuda_target_arch = arch available: bool
True if can find available device
"""
def _check():
remote = request_remote(device_key, host, port, priority)
remote.context(str(target))
t = threading.Thread(target=_check,)
t.start()
t.join(timeout)
return not t.is_alive()
@register_func
@@ -496,6 +561,17 @@ def tvm_callback_cuda_compile(code):
    return ptx


def set_cuda_target_arch(arch):
    """set target architecture of nvcc compiler

    Parameters
    ----------
    arch: str
        The argument of nvcc -arch (e.g. "sm_51", "sm_62")
    """
    AutotvmGlobalScope.current.cuda_target_arch = arch


def gpu_verify_pass(**kwargs):
    """Verify the validity of a gpu kernel.
    This pass will check memory usage and number of threads per block.
...
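The adaptive `number` logic in RPCRunner.run boils down to one inequality: if the mean cost times the current `number` is below `min_repeat_ms`, scale `number` up so that one timing call lasts at least that long. A worked sketch of that rule with illustrative values:

import numpy as np

min_repeat_ms = 100                     # want each timing call to last >= 100 ms
cur_number = 4                          # current `number` passed to time_evaluator
mean_cost = 0.002                       # a config ran in 2 ms on average

min_repeat_duration = min_repeat_ms / 1000.0
if mean_cost * cur_number <= min_repeat_duration:
    cur_number = max(cur_number, int(np.ceil(min_repeat_duration / mean_cost)))
# 0.1 / 0.002 = 50, so the config is re-measured with number = 50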
@@ -22,7 +22,7 @@ class GATuner(Tuner):
    mutation_prob: float
        probability of mutation of a knob in a gene
    """
    def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
        super(GATuner, self).__init__(task)
        # algorithm configurations
...
@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
            new_scores = model.predict(new_points)

            ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
            ac_index = np.random.random(len(ac_prob)) < ac_prob
            points[ac_index] = new_points[ac_index]
...
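The change above clips the exponent of the acceptance ratio, so a much better point is always accepted (probability capped at e > 1) and a large score gap at low temperature no longer overflows np.exp. A small numeric sketch with illustrative values:

import numpy as np

t = 0.05                                    # current temperature
scores = np.array([10.0, 10.0])
new_scores = np.array([9.0, 500.0])         # one worse point, one much better point

ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
# worse point:  exp(-1.0 / 0.05) ~= 2e-9 -> almost never accepted
# better point: exp(1) ~= 2.72           -> always accepted (prob >= 1)
accept = np.random.random(len(ac_prob)) < ac_prob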
@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
                               target=target, target_host=target_host)
    return task, target


def test_tuning():
    def check(target, target_host):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
@@ -141,12 +114,12 @@ def test_tuning_with_measure():
        task, target = get_sample_task(target, target_host)
        logging.info("%s", task.config_space)

        measure_option = autotvm.measure_option(
            autotvm.LocalBuilder(),
            autotvm.LocalRunner())

        tuner = RandomTuner(task)
        tuner.tune(n_trial=20, measure_option=measure_option)

    check("cuda", None)
    check("opencl", None)

@@ -155,6 +128,4 @@ if __name__ == "__main__":
    # only print log when invoked from main
    logging.basicConfig(level=logging.DEBUG)

    test_tuning()
@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
    return s, [A, B, C]


@autotvm.template
def bad_matmul(N, L, M, dtype):
    if 'bad_device' in tvm.target.current_target().keys:
        A = tvm.placeholder((N, L), name='A', dtype=dtype)
        B = tvm.placeholder((L, M), name='B', dtype=dtype)
        k = tvm.reduce_axis((0, L-1), name='k')   # deliberately drops the last reduction step
        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
        s = tvm.create_schedule(C.op)

        # schedule
        y, x = s[C].op.axis
        cfg = autotvm.get_config()
        cfg.define_split("tile_y", y, num_outputs=2)
        cfg.define_split("tile_x", x, num_outputs=2)
        return s, [A, B, C]

    return matmul(N, L, M, dtype)


def get_sample_task(n=128):
    """return a sample task for testing"""
    target = tvm.target.create("llvm")
...
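bad_matmul sums only L-1 terms of the reduction on the 'bad_device' target, so its output differs from a full matmul and the WRONG_ANSWER path of the correctness check is exercised. A quick numpy sketch of the discrepancy:

import numpy as np

N = L = M = 4
A = np.random.rand(N, L).astype("float32")
B = np.random.rand(L, M).astype("float32")

full = A.dot(B)                       # reference result
truncated = A[:, :L-1].dot(B[:L-1])   # what the truncated reduce_axis computes
print(np.allclose(full, truncated))   # False (almost surely)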
"""Test database""" """Test database"""
import copy import copy
import logging import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError from tvm.autotvm.record import encode, MeasureResult
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records from test_autotvm_common import get_sample_records
def test_save_load(): def test_save_load():
logging.info("test basic db load/save ...") logging.info("test basic db load/save ...")
...@@ -35,66 +29,6 @@ def test_save_load(): ...@@ -35,66 +29,6 @@ def test_save_load():
TRIAL_LIMIT = 2 TRIAL_LIMIT = 2
def test_db_filter():
logging.info("test db filter ...")
# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
batch_size = 2
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
del measure_batch
db = database.DummyDatabase()
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])
for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"
del measure_batch
def test_db_hash():
    logging.info("test db hash check ...")
    inp1, res1 = get_sample_records(1)[0]
@@ -149,89 +83,8 @@ def test_db_latest_all():
    assert encode(inp1, load4[1]) == encode(inp1, res2)
    assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local',
do_fork=False,
timeout=2,
replay_db=_db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
callback = autotvm.callback.log_to_database(_db)
callback(None, all_inputs, all_results)
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)
for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
del measure_batch
def test_check_hashmismatch():
logging.info("test hash mismatch check")
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local', do_fork=False)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))
try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass
del measure_batch
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    test_save_load()
    test_db_hash()
    test_db_latest_all()
"""Test builder and runner"""
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)
def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]
def get_build_kwargs(self):
return {}
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=DummyRunner()
)
logging.info("%s", task.config_space)
for tuner_class in [autotvm.tuner.RandomTuner,
autotvm.tuner.GridSearchTuner,
autotvm.tuner.GATuner,
autotvm.tuner.XGBTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_check_correctness():
task, target = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(check_correctness=True)
)
def _callback_correct(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == 0
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_correct])
# a bad template
n = 128
target = tvm.target.create("llvm -device=bad_device")
task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
def _callback_wrong(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == MeasureErrorNo.WRONG_ANSWER
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_wrong])
def test_min_repeat_ms():
task, target = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
)
def _callback(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
if res.error_no != 0:
continue
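            # min_repeat_ms=100 makes the runner increase its `number` of runs
            # until one repeat lasts at least 100 ms, so mean cost (in seconds)
            # times the adapted cur_number must amount to at least 100 ms.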
assert 1000 * np.mean(res.costs) * \
measure_option['runner'].cur_number >= 100
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=5, measure_option=measure_option,
callbacks=[_callback])
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_task_tuner_without_measurement()
test_check_correctness()
test_min_repeat_ms()
...
@@ -137,7 +137,10 @@ if __name__ == '__main__':
    print(task.config_space)

    measure_option = autotvm.measure_option(
-        measure_func='local', number=10, n_parallel=8, timeout=20)
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+    )

    log_name = 'gemm_int8.log'
    if DO_TUNING:
        tuner = autotvm.tuner.XGBTuner(task)
...
...
@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
                           target='cuda')
print(task.config_space)

-# use local gpu, measure 5 times for every config to reduce variance
-# run 8 parallel threads for compilation
-measure_option = autotvm.measure_option('local',
-                                        number=5,
-                                        n_parallel=8,
-                                        timeout=20)
+# use local gpu, measure 10 times for every config to reduce variance
+# The timeout for compiling a program is 10 seconds; the timeout for running is 4 seconds
+measure_option = autotvm.measure_option(
+    builder=autotvm.LocalBuilder(),
+    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+)

# begin tuning, log records to file `conv2d.log`
tuner = autotvm.tuner.XGBTuner(task)
...
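In the updated option above, the compile-side and run-side timeouts mentioned in the comment belong to the builder and the runner respectively. A minimal sketch spelling both out (the 10-second builder timeout is taken from the comment, not from this hunk, and `from tvm import autotvm` is assumed as in the tutorial):

    # build with a 10 s compile timeout; run each config with a 4 s timeout
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4))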
...
@@ -65,15 +65,20 @@ def get_network(name, batch_size):
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)
-    if name =='resnet-18':
-        net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
-    elif name =='mobilenet':
+    if "resnet" in name:
+        n_layer = int(name.split('-')[1])
+        net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
+    elif "vgg" in name:
+        n_layer = int(name.split('-')[1])
+        net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
+    elif name == 'mobilenet':
        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
-    elif name =='squeezenet v1.1':
+    elif name == 'squeezenet_v1.1':
        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
-    elif name =='vgg-16':
-        net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
-    elif name =='custom':
+    elif name == 'inception_v3':
+        input_shape = (1, 3, 299, 299)
+        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
+    elif name == 'custom':
        # an example for custom network
        from nnvm.testing import utils
        net = nnvm.sym.Variable('data')

...
@@ -92,6 +97,7 @@ def get_network(name, batch_size):
    return net, params, input_shape, output_shape


#################################################################
# Start RPC Tracker
# -----------------
...
@@ -158,6 +164,8 @@ def get_network(name, batch_size):
#    rk3399       2      2     0
#    rpi3b        11     11    0
#    ----------------------------------
+#
+# You can register multiple devices to the tracker to accelerate the measurement in tuning.
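For instance, with three rk3399 boards registered under the same key, measurements can be issued to all of them concurrently. This is only a sketch: it assumes `RPCRunner` exposes the `n_parallel` option of the `Runner` base class and that the tracker runs on localhost:9190.

    from tvm import autotvm

    # request up to 3 remote sessions for the key 'rk3399' and measure in parallel
    runner = autotvm.RPCRunner('rk3399', host='localhost', port=9190,
                               n_parallel=3, number=5, timeout=4)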
###########################################
# Set Tuning Options

...
@@ -188,14 +196,16 @@ tuning_option = {
    'tuner': 'xgb',
    'n_trial': 1000,
-    'early_stopping': 250,
+    'early_stopping': 400,

    'measure_option': autotvm.measure_option(
-        autotvm.measure.rpc(device_key, host='localhost', port=9190),
-        number=4,
-        n_parallel=1,
-        timeout=10,
-        build_func='ndk' if use_android else 'default',
+        builder=autotvm.LocalBuilder(
+            build_func='ndk' if use_android else 'default'),
+        runner=autotvm.RPCRunner(
+            device_key, host='localhost', port=9190,
+            number=5,
+            timeout=4,
+        ),
    ),
}
...
@@ -203,15 +213,9 @@ tuning_option = {
#
# .. note:: How to set tuning options
#
-#   In general, the default value provided here works well. It is the same
-#   value that we used to generate pre-tuned parameters.
-#   If you have multiple devices, you can set :code:`n_parallel` to
-#   the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
-#   boards to the tracker).
+#   In general, the default value provided here works well.
#   If you have a large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
#   which makes the tuning run longer.
-#   If your device is very slow or a single conv2d operator in your network has large FLOPs,
-#   consider setting timeout larger.
#
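As a concrete illustration of the note above, a larger time budget could simply scale both knobs before tuning (the numbers here are arbitrary, not taken from this change):

    tuning_option['n_trial'] = 2000
    tuning_option['early_stopping'] = 800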
###################################################################

...
@@ -219,7 +223,7 @@ tuning_option = {
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here we provide a simple utility function to tune a list of tasks.
-# This function is just an initial implementation which tune them in sequential order.
+# This function is just an initial implementation which tunes them in sequential order.
# Later we will bring a more sophisticated tuner scheduler.
# You can skip the implementation of this function for this tutorial.
...
@@ -236,7 +240,9 @@ def tune_tasks(tasks,
        try:  # try winograd template
            tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                      tasks[i].target, tasks[i].target_host, 'winograd')
-            tasks.append(tsk)
+            input_channel = tsk.workload[1][1]
+            if input_channel >= 64:
+                tasks[i] = tsk
        except Exception:
            pass
...
@@ -245,8 +251,8 @@ def tune_tasks(tasks,
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

-    for i, tsk in enumerate(tasks):
-        prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
...
@@ -280,7 +286,7 @@ def tune_tasks(tasks,
########################################################################
# Finally we launch tuning jobs and evaluate the end-to-end performance.
-def tune_and_evaluate():
+def tune_and_evaluate(tuning_opt):
    # extract workloads from nnvm graph
    print("Extract tasks...")
    net, params, input_shape, out_shape = get_network(network, batch_size=1)

...
@@ -290,19 +296,18 @@ def tune_and_evaluate():
    # run tuning tasks
    print("Tuning...")
-    tune_tasks(tasks, **tuning_option)
+    tune_tasks(tasks, **tuning_opt)
    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
            graph, lib, params = nnvm.compiler.build(
-                net, target=target,
-                shape={'data': input_shape}, params=params, dtype=dtype)
+                net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)

    # export library
    tmp = tempdir()
-    if tuning_option['measure_option']['build_func'] == 'ndk':  # for android
+    if use_android:
        from tvm.contrib import ndk
        filename = "net.so"
        lib.export_library(tmp.relpath(filename), ndk.create_shared)
...
@@ -312,8 +317,7 @@ def tune_and_evaluate():
    # upload module to device
    print("Upload...")
-    remote = autotvm.measure.request_remote(device_key,
-                                            tracker_addr=('localhost', 9190),
+    remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
                                             timeout=10000)
    remote.upload(tmp.relpath(filename))
    rlib = remote.load_module(filename)
...
@@ -328,47 +332,44 @@ def tune_and_evaluate():
    # evaluate
    print("Evaluate inference time cost...")
-    ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
+    ftimer = module.module.time_evaluator("run", ctx, number=8, repeat=3)
    prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
    print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
          (np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.

-# tune_and_evaluate()
+# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
-# The tuning needs to train xgboost models and use them for prediction.
+# The tuning needs to compile many programs and extract features from them.
# So a high performance CPU is recommended.
-# It takes about 2 hours on a 32T AMD Ryzen CPU.
-# One sample output is
+# One sample output is listed below.
+# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
#
# .. code-block:: bash
#
#    Extract tasks...
#    Tuning...
-#    [Task 1/16] Current/Best: 18.85/ 19.67 GFLOPS | Progress: (353/1000) | 387.05 s Done.
-#    [Task 2/16] Current/Best: 16.10/ 23.50 GFLOPS | Progress: (444/1000) | 379.99 s Done.
-#    [Task 3/16] Current/Best: 5.49/ 13.96 GFLOPS | Progress: (610/1000) | 485.87 s Done.
-#    [Task 4/16] Current/Best: 10.07/ 20.48 GFLOPS | Progress: (430/1000) | 391.66 s Done.
-#    [Task 5/16] Current/Best: 11.50/ 15.50 GFLOPS | Progress: (374/1000) | 356.03 s Done.
-#    [Task 6/16] Current/Best: 10.76/ 23.77 GFLOPS | Progress: (526/1000) | 526.42 s Done.
-#    [Task 7/16] Current/Best: 12.71/ 22.03 GFLOPS | Progress: (341/1000) | 322.96 s Done.
-#    [Task 8/16] Current/Best: 8.60/ 17.91 GFLOPS | Progress: (272/1000) | 236.08 s Done.
-#    [Task 9/16] Current/Best: 15.37/ 23.62 GFLOPS | Progress: (275/1000) | 275.18 s Done.
-#    [Task 10/16] Current/Best: 6.62/ 23.01 GFLOPS | Progress: (330/1000) | 315.02 s Done.
-#    [Task 11/16] Current/Best: 1.85/ 21.39 GFLOPS | Progress: (281/1000) | 239.19 s Done.
-#    [Task 12/16] Current/Best: 15.41/ 24.02 GFLOPS | Progress: (258/1000) | 270.82 s Done.
-#    [Task 13/16] Current/Best: 17.96/ 25.79 GFLOPS | Progress: (380/1000) | 738.29 s Done.
-#    [Task 14/16] Current/Best: 14.81/ 31.17 GFLOPS | Progress: (413/1000) | 799.21 s Done.
-#    [Task 15/16] Current/Best: 24.39/ 40.97 GFLOPS | Progress: (355/1000) | 700.25 s Done.
-#    [Task 16/16] Current/Best: 9.42/ 49.90 GFLOPS | Progress: (348/1000) | 603.84 s Done.
+#    [Task 1/12] Current/Best: 22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
+#    [Task 2/12] Current/Best: 6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
+#    [Task 3/12] Current/Best: 4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
+#    [Task 4/12] Current/Best: 11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
+#    [Task 5/12] Current/Best: 1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
+#    [Task 6/12] Current/Best: 2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
+#    [Task 7/12] Current/Best: 14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
+#    [Task 8/12] Current/Best: 1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
+#    [Task 9/12] Current/Best: 14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
+#    [Task 10/12] Current/Best: 3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
+#    [Task 11/12] Current/Best: 11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
+#    [Task 12/12] Current/Best: 8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
#    Compile...
#    Upload...
#    Evaluate inference time cost...
-#    Mean inference time (std dev): 157.29 ms (1.74 ms)
+#    Mean inference time (std dev): 162.59 ms (0.06 ms)
######################################################################
#
...
...
@@ -271,9 +271,12 @@ print(task.config_space)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

-# use local cpu, measure 5 times for every config to reduce variance
-measure_option = autotvm.measure_option('local',
-                                        number=5)
+# There are two steps for measuring a config: build and run.
+# By default, we use all CPU cores to compile the program, then measure the configs sequentially.
+# We measure each config 5 times and take the average to reduce variance.
+measure_option = autotvm.measure_option(
+    builder='local',
+    runner=autotvm.LocalRunner(number=5))
# begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task)
...
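In the option above, `builder='local'` is shorthand for a default `LocalBuilder`. An equivalent explicit form would look roughly like the sketch below (the timeout values are illustrative assumptions, and `from tvm import autotvm` is assumed as in the tutorial):

    # compile with the local CPU, then run each config 5 times on the local machine
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(number=5, timeout=4))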