Commit 136061dc by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Improve tutorial and logging (#1544)

parent 33606741
"""Distributed executor infrastructure to scale up the tuning""" """Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, create_measure_batch, use_rpc from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc
from .local_executor import LocalExecutor from .local_executor import LocalExecutor
from .executor import Future, Executor from .executor import Future, Executor
@@ -9,6 +9,7 @@ import logging
 import os
 import time
 from random import getrandbits
+import threading

 import numpy as np
@@ -23,6 +24,7 @@ from ..task.space import InstantiationError
 from .measure import MeasureResult, MeasureErrorNo
 from .local_executor import LocalExecutor

+logger = logging.getLogger('autotvm')

 class HashMismatchError(ValueError):
     """Raised when the code hash of a submitted config doesn't match that on the
@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
         If is none, will use environment variable "TVM_TRACKER_HOST"
         and "TVM_TRACKER_PORT"
     priority: int, optional
-        priority of this request, larger is more prior
+        The priority of this request, larger is more prior
     timeout: float, optional
-        timeout of this session (units: seconds)
+        The timeout of this session (units: seconds)

     Returns
     -------
@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
                                          session_timeout=timeout)
     return remote
+
+def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
+    """
+    Check the availability of a remote device
+
+    Parameters
+    ----------
+    target: Target
+        The wanted compilation target
+    device_key: string
+        device key of registered device in tracker
+    tracker_addr: Tuple(string, int), optional
+        The address of rpc tracker in (host, port) format.
+        If is none, will use environment variable "TVM_TRACKER_HOST"
+        and "TVM_TRACKER_PORT"
+    priority: int, optional
+        The priority of this request, larger is more prior
+    timeout: float, optional
+        The timeout of this check (units: seconds).
+        If time is out, a RuntimeError will be raised.
+    """
+    def _check():
+        remote = request_remote(device_key, tracker_addr, priority)
+        remote.context(str(target))
+    t = threading.Thread(target=_check,)
+    t.start()
+    t.join(timeout)
+    return not t.is_alive()
 def create_measure_batch(task, option):
     """Get a standard measure_batch function.
@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
         build_func = default_build_func
         build_kwargs['use_ndk'] = True

+    # check the availability of remote devices
+    if hasattr(measure_func, 'rpc_info'):
+        rpc_info = measure_func.rpc_info
+        if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
+            logger.info("Get devices for measurement successfully!")
+        else:
+            raise RuntimeError("Cannot get remote devices from the tracker. "
+                               "Please check the status of tracker by "
+                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
+                               "and make sure you have free devices on the queue status.")
+
     # add device info of cuda and opencl target
     if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
             and hasattr(measure_func, 'rpc_info'):
@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
                 continue
             except InstantiationError as e:
                 tstamp = time.time()
-                res_pack.append(MeasureResult((e,),
+                res_pack.append(MeasureResult((InstantiationError(str(e)),),
                                               MeasureErrorNo.INSTANTIATION_ERROR,
                                               tstamp - tic, tstamp))
                 continue
@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
             if ref_output:
                 for expected, real in zip(ref_output, args):
                     if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
-                        logging.warning("Wrong Answer!")
+                        logger.warning("Wrong Answer!")
                         errno = MeasureErrorNo.WRONG_ANSWER
         except TVMError as exc:
             msg = str(exc)
...
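A minimal sketch of how the new check_remote helper can be called before a tuning run; the import path, device key, and tracker address below are assumptions, adjust them to your setup:

    import tvm
    from tvm.autotvm.measure import check_remote  # assumed import path

    # 'rk3399' and ('localhost', 9190) are placeholders for your device key and tracker
    target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
    if not check_remote(target, 'rk3399', ('localhost', 9190), timeout=10):
        raise RuntimeError("no free device on the tracker queue")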
@@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
 from .measure import MeasureInput, MeasureResult

 AUTOTVM_LOG_VERSION = 0.1
+logger = logging.getLogger('autotvm')

 try:  # convert unicode to str for python2
     _unicode = unicode
@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
     tic = time.time()
     lines = list(open(in_file).readlines())

-    logging.info("start converting...")
+    logger.info("start converting...")
     pool = multiprocessing.Pool()
     lines = pool.map(decode, lines)
-    logging.info("map done %.2f", time.time() - tic)
+    logger.info("map done %.2f", time.time() - tic)

     wkl_dict = OrderedDict()
     for inp, res in lines:
@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
             cleaned.append([inp, res])

         # write to file
-        logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
+        logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
         with open(args.i + ".%03d.wkl" % i, 'w') as fout:
             for inp, res in cleaned:
                 fout.write(encode(inp, res) + '\n')
     else:
         for i, (k, v) in enumerate(wkl_dict.items()):
-            logging.info("Key: %s\tNum: %d", k, len(v))
+            logger.info("Key: %s\tNum: %d", k, len(v))
             with open(args.i + ".%03d.wkl" % i, 'w') as fout:
                 for inp, res in v:
                     fout.write(encode(inp, res) + '\n')
@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
     for v in best_context.best_by_targetkey.values():
         best_set.add(measure_str_key(v[0]))

-    logging.info("Extract %d best records from the %s", len(best_set), in_file)
+    logger.info("Extract %d best records from the %s", len(best_set), in_file)
     fout = open(out_file, 'w') if isinstance(out_file, str) else out_file

     for inp, res in load_from_file(in_file):
@@ -270,7 +271,7 @@ if __name__ == '__main__':
     parser.add_argument("--code", action='store_true')
     args = parser.parse_args()

-    logging.basicConfig(level=logging.INFO)
+    logger.setLevel(logging.INFO)

     if args.mode == 'pick':
         args.o = args.o or args.i + ".best.log"
...
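For context, pick_best distills a full tuning log into the single best record per workload; a brief sketch (the file names are placeholders):

    from tvm import autotvm

    # keep only the best entry for each (target, workload) pair
    autotvm.record.pick_best('rk3399.resnet-18.log', 'rk3399.resnet-18.best.log')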
@@ -10,6 +10,8 @@ of the DispatchContext base class.
 - During search, we can use it to pass the current proposal from tuner.
 - During evaluation, we can use it to pick the best policy.
 """
+# pylint: disable=invalid-name
+
 from __future__ import absolute_import as _abs

 import logging
@@ -19,6 +21,8 @@ import numpy as np
 from tvm import target as _target

+logger = logging.getLogger('autotvm')
 class DispatchContext(object):
     """
     Base class of dispatch context.
@@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
                     best_by_model[key] = (inp, res)
                     break

-        logging.debug("Finish loading %d records", counter)
+        logger.debug("Finish loading %d records", counter)

     def query(self, target, workload):
         if target is None:
...
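ApplyHistoryBest is usually entered through the autotvm.apply_history_best helper when compiling; a hedged sketch (the log file name is a placeholder):

    from tvm import autotvm

    with autotvm.apply_history_best('rk3399.resnet-18.best.log'):
        # build the network here; operators are dispatched to their
        # best measured configs from the log
        ...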
@@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific device.
 TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
 TVM will download these parameters for you when you create the target for the first time.
 """
+# pylint: disable=invalid-name

 import logging
 import os
@@ -16,6 +17,7 @@ from ..contrib.download import download

 AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
+logger = logging.getLogger('autotvm')

 def _alias(name):
     """convert alias for some packages"""
@@ -79,7 +81,7 @@ def download_package(backend):
         os.mkdir(path)

     backend = _alias(backend)
-    logging.info("Download pre-tuned parameters for %s", backend)
+    logger.info("Download pre-tuned parameters for %s", backend)
     download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
              os.path.join(rootpath, backend + ".log"), True, verbose=0)
@@ -110,7 +112,7 @@ def list_packages():
     """
     path = tempdir()
     filename = path.relpath("info.json")
-    logging.info("Download meta info for pre-tuned parameters")
+    logger.info("Download meta info for pre-tuned parameters")
     download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
              filename, True, verbose=0)
...
@@ -2,11 +2,13 @@
 """Namespace of callback utilities of AutoTVM"""
 import sys
 import time
+import logging

 import numpy as np

 from .. import record

+logger = logging.getLogger('autotvm')

 def log_to_file(file_out, protocol='json'):
     """Log the tuning records into file.
@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
     prefix: str
         The prefix of output message
     """
-    class _Context:
+    class _Context(object):
         """Context to store local variables"""
         def __init__(self):
             self.best_flops = 0
@@ -112,13 +114,14 @@ def progress_bar(total, prefix=''):
             if res.error_no == 0:
                 flops = inp.task.flop / np.mean(res.costs)
-                ctx.cur_flops = flops
-                ctx.best_flops = tuner.best_flops
-            sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
-                             '| %.2f s' %
-                             (prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
-                              time.time() - tic))
-            sys.stdout.flush()
+
+        if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
+            ctx.cur_flops = flops
+            ctx.best_flops = tuner.best_flops
+            sys.stdout.write('%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
+                             '| %.2f s\r' %
+                             (prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
+                              time.time() - tic))
+            sys.stdout.flush()

     return _callback
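A sketch of how these callbacks are typically attached to a tuning run; tuner, n_trial, and measure_option are assumed to be defined as in the tutorials further below:

    from tvm import autotvm

    tuner.tune(n_trial=n_trial,
               measure_option=measure_option,
               callbacks=[autotvm.callback.progress_bar(n_trial, prefix='[Task 1/16]'),
                          autotvm.callback.log_to_file('tune.log')])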
-# pylint: disable=consider-using-enumerate
+# pylint: disable=consider-using-enumerate, invalid-name
 """
 Cost model optimizer based on simulated annealing
 """
@@ -12,6 +12,8 @@ import numpy as np
 from ..util import sample_ints
 from .model_based_tuner import ModelOptimizer, knob2point, point2knob

+logger = logging.getLogger('autotvm')

 class SimulatedAnnealingOptimizer(ModelOptimizer):
     """parallel simulated annealing optimization algorithm
@@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
             if log_interval and k % log_interval == 0:
                 t_str = "%.2f" % t
-                logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
-                              "elapsed: %.2f",
-                              k, k_last_modify, heap_items[0][0],
-                              np.max([v for v, _ in heap_items]), t_str,
-                              time.time() - tic)
+                logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
+                             "elapsed: %.2f",
+                             k, k_last_modify, heap_items[0][0],
+                             np.max([v for v, _ in heap_items]), t_str,
+                             time.time() - tic)

         heap_items.sort(key=lambda item: -item[0])
-        logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
-                      k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
-        logging.debug("SA Maximums: %s", heap_items)
+        logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
+                     k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
+        logger.debug("SA Maximums: %s", heap_items)

         if self.persistent:
             self.points = points
...
@@ -4,11 +4,12 @@ import logging
 import numpy as np

-from ..measure import MeasureInput
-from ..measure import create_measure_batch
+from ..measure import MeasureInput, create_measure_batch
 from ..env import GLOBAL_SCOPE

+logger = logging.getLogger('autotvm')

 class Tuner(object):
     """Base class for tuners
@@ -86,9 +87,10 @@ class Tuner(object):
         measure_batch = create_measure_batch(self.task, measure_option)
         parallel_num = getattr(measure_batch, 'parallel_num', 1)
         early_stopping = early_stopping or 1e9
+        old_level = logger.level

         GLOBAL_SCOPE.in_tuning = True
-        i = 0
+        i = error_ct = 0
         while i < n_trial:
             if not self.has_next():
                 break
@@ -103,17 +105,20 @@ class Tuner(object):
                 config = inp.config
                 if res.error_no == 0:
                     flops = inp.task.flop / np.mean(res.costs)
+                    error_ct = 0
                 else:
                     flops = 0
+                    error_ct += 1

                 if flops > self.best_flops:
                     self.best_flops = flops
                     self.best_config = config
                     self.best_measure_pair = (inp, res)
                     self.best_iter = i + k

-                logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
-                              i + k + 1, flops / 1e9, self.best_flops / 1e9,
-                              res, config)
+                logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
+                             i + k + 1, flops / 1e9, self.best_flops / 1e9,
+                             res, config)

             i += len(results)
@@ -123,11 +128,16 @@ class Tuner(object):
                 callback(self, inputs, results)

             if i > self.best_iter + early_stopping:
-                logging.debug("Early stopped. Best iter: %d.", self.best_iter)
+                logger.debug("Early stopped. Best iter: %d.", self.best_iter)
                 break

-        GLOBAL_SCOPE.in_tuning = False
+            if error_ct > 50:
+                logger.warning("Too many errors happen in the tuning. Now is in debug mode")
+                logger.setLevel(logging.DEBUG)
+            else:
+                logger.setLevel(old_level)
+
+        GLOBAL_SCOPE.in_tuning = False
         del measure_batch

     def reset(self):
...
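For reference, a sketch of driving this loop from user code; with this change, more than 50 consecutive failed measurements switch the 'autotvm' logger to DEBUG so the error messages become visible (task and measure_option are assumed to exist):

    from tvm.autotvm.tuner import XGBTuner

    tuner = XGBTuner(task, loss_type='rank')
    tuner.tune(n_trial=1000, early_stopping=250, measure_option=measure_option)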
@@ -16,6 +16,8 @@ from ..util import get_rank
 from .metric import max_curve, recall_curve, cover_curve
 from .model_based_tuner import CostModel, FeatureCache

+logger = logging.getLogger('autotvm')

 class XGBoostCostModel(CostModel):
     """XGBoost as cost model
@@ -163,17 +165,17 @@ class XGBoostCostModel(CostModel):
                          ],
                          verbose_eval=self.log_interval)])

-        logging.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
-                      time.time() - tic, len(xs),
-                      len(xs) - np.sum(valid_index),
-                      self.feature_cache.size(self.fea_type))
+        logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
+                     time.time() - tic, len(xs),
+                     len(xs) - np.sum(valid_index),
+                     self.feature_cache.size(self.fea_type))

     def fit_log(self, records, plan_size):
         tic = time.time()
         self._reset_pool()

         args = list(records)
-        logging.debug("XGB load %d entries from history log file", len(args))
+        logger.debug("XGB load %d entries from history log file", len(args))

         if self.fea_type == 'itervar':
             feature_extract_func = _extract_itervar_feature_log
@@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
                          ],
                          verbose_eval=self.log_interval)])

-        logging.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
+        logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))

     def predict(self, xs, output_margin=False):
         feas = self._get_feature(xs)
@@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
             infos.append("%s: %.6f" % (item[0], item[1]))

         if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            logging.debug("\t".join(infos))
+            logger.debug("\t".join(infos))
         if log_file:
             with open(log_file, "a") as fout:
                 fout.write("\t".join(infos) + '\n')
@@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
         elif env.iteration - best_iteration >= stopping_rounds:
             best_msg = state['best_msg']
             if verbose_eval and env.rank == 0:
-                logging.debug("XGB stopped. Best iteration: %s ", best_msg)
+                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
             raise EarlyStopException(best_iteration)

     return callback
...
@@ -8,6 +8,7 @@ import numpy as np

 from .. import expr, ir_pass

+logger = logging.getLogger('autotvm')

 class EmptyContext(object):
     """An empty context"""
@@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
     tic = time.time()
     local_pool = pool or multiprocessing.Pool()
     if verbose:
-        logging.info("mapping begin")
+        logger.info("mapping begin")
     for i in range(0, len(args), batch_size):
         if verbose:
-            logging.info("mapping %d/%d elapsed %.2f", i, len(args),
-                         time.time() - tic)
+            logger.info("mapping %d/%d elapsed %.2f", i, len(args),
+                        time.time() - tic)
         tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
         ret = tmp if ret is None else np.concatenate((ret, tmp))
     if verbose:
-        logging.info("mapping done")
+        logger.info("mapping done")
     if not pool:
         local_pool.close()
     return ret
...
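A minimal usage sketch for pool_map; the import path is an assumption, and func must be picklable (defined at module top level) since it runs in a multiprocessing pool:

    from tvm.autotvm.util import pool_map  # assumed import path

    def square(x):
        return x * x

    # maps in batches of 32; verbose=True reports progress via the 'autotvm' logger
    result = pool_map(square, list(range(100)), 32, verbose=True)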
"""Base definitions for RPC.""" """Base definitions for RPC."""
# pylint: disable=invalid-name
from __future__ import absolute_import from __future__ import absolute_import
import socket import socket
...@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1 ...@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot found matched key in server # cannot found matched key in server
RPC_CODE_MISMATCH = RPC_MAGIC + 2 RPC_CODE_MISMATCH = RPC_MAGIC + 2
logger = logging.getLogger('RPCServer')
class TrackerCode(object): class TrackerCode(object):
"""Enumeration code for the RPC tracker""" """Enumeration code for the RPC tracker"""
...@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None): ...@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random()) return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): def connect_with_retry(addr, timeout=60, retry_period=5):
"""Connect to a TPC address with retry """Connect to a TPC address with retry
This function is only reliable to short period of server restart. This function is only reliable to short period of server restart.
...@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): ...@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
retry_period : float retry_period : float
Number of seconds before we retry again. Number of seconds before we retry again.
silent: bool
whether run in silent mode
""" """
tstart = time.time() tstart = time.time()
while True: while True:
...@@ -152,9 +152,8 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): ...@@ -152,9 +152,8 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
if period > timeout: if period > timeout:
raise RuntimeError( raise RuntimeError(
"Failed to connect to server %s" % str(addr)) "Failed to connect to server %s" % str(addr))
if not silent: logger.warning("Cannot connect to tracker %s, retry in %g secs...",
logging.info("Cannot connect to tracker%s, retry in %g secs...", str(addr), retry_period)
str(addr), retry_period)
time.sleep(retry_period) time.sleep(retry_period)
......
@@ -23,7 +23,8 @@ try:
     from tornado import ioloop
     from . import tornado_util
 except ImportError as error_msg:
-    raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
+    raise ImportError(
+        "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)

 from . import base
 from .base import TrackerCode
@@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
     def _connect(key):
         conn = yield websocket.websocket_connect(url)
         on_message = create_on_message(conn)
-        temp = _server_env(None, None)
+        temp = _server_env(None)
         # Start connection
         conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
         key = "server:" + key
...
@@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
 - The key is in format
 - {server|client}:device-type[:random-key] [-timeout=timeout]
 """
+# pylint: disable=invalid-name
+
 from __future__ import absolute_import

 import os
@@ -30,11 +32,11 @@ from ..contrib import util
 from . import base
 from . base import TrackerCode

-def _server_env(load_library, logger):
+logger = logging.getLogger('RPCServer')
+
+def _server_env(load_library):
     """Server environment function return temp dir"""
     temp = util.tempdir()
-    if logger is None:
-        logger = logging.getLogger()

     # pylint: disable=unused-variable
     @register_func("tvm.rpc.server.workpath")
@@ -59,13 +61,10 @@ def _server_env(load_library, logger):
     return temp

-def _serve_loop(sock, addr, load_library, silent):
+def _serve_loop(sock, addr, load_library):
     """Server loop"""
-    logger = logging.getLogger("RPCServer")
-    if silent:
-        logger.disabled = True
     sockfd = sock.fileno()
-    temp = _server_env(load_library, logger)
+    temp = _server_env(load_library)
     base._ServerLoop(sockfd)
     temp.remove()
     logger.info("Finish serving %s", addr)
@@ -79,12 +78,8 @@ def _parse_server_opt(opts):
             ret["timeout"] = float(kv[9:])
     return ret
-def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
+def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
     """Listening loop of the server master."""
-    logger = logging.getLogger("RPCServer")
-    if silent:
-        logger.disabled = True

     def _accept_conn(listen_sock, tracker_conn, ping_period=2):
         """Accept connection from the other places.
@@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
             if arr[0] != expect_header:
                 conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
                 conn.close()
-                logger.info("mismatch key from %s", addr)
+                logger.warning("mismatch key from %s", addr)
                 continue
             else:
                 conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
@@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
         try:
             # step 1: setup tracker and report to tracker
             if tracker_addr and tracker_conn is None:
-                tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
+                tracker_conn = base.connect_with_retry(tracker_addr)
                 tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                 magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                 if magic != base.RPC_TRACKER_MAGIC:
@@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
                 tracker_conn = None
                 continue
         except RuntimeError as exc:
-            if silent:
-                return
-            else:
-                raise exc
+            raise exc

         # step 3: serving
         logger.info("connection from %s", addr)
         server_proc = multiprocessing.Process(target=_serve_loop,
-                                              args=(conn, addr, load_library, silent))
+                                              args=(conn, addr, load_library))
         server_proc.deamon = True
         server_proc.start()
         # close from our side.
@@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
             server_proc.terminate()

-def _connect_proxy_loop(addr, key, load_library, silent):
-    logger = logging.getLogger("RPCProxy")
-    if silent:
-        logger.disabled = True
+def _connect_proxy_loop(addr, key, load_library):
     key = "server:" + key
     retry_count = 0
     max_retry = 5
@@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
             if magic == base.RPC_CODE_DUPLICATE:
                 raise RuntimeError("key: %s has already been used in proxy" % key)
             elif magic == base.RPC_CODE_MISMATCH:
-                logger.info("RPCProxy do not have matching client key %s", key)
+                logger.warning("RPCProxy do not have matching client key %s", key)
             elif magic != base.RPC_CODE_SUCCESS:
                 raise RuntimeError("%s is not RPC Proxy" % str(addr))
             keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
@@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
             opts = _parse_server_opt(remote_key.split()[1:])
             logger.info("connected to %s", str(addr))
             process = multiprocessing.Process(
-                target=_serve_loop, args=(sock, addr, load_library, silent))
+                target=_serve_loop, args=(sock, addr, load_library))
             process.deamon = True
             process.start()
             sock.close()
@@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
             retry_count = 0
         except (socket.error, IOError) as err:
             retry_count += 1
-            logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
+            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
             if retry_count > max_retry:
                 raise RuntimeError("Maximum retry error: last error: %s" % str(err))
             time.sleep(retry_period)
@@ -323,9 +312,8 @@ class Server(object):
         self.custom_addr = custom_addr
         self.use_popen = use_popen

-        self.logger = logging.getLogger("RPCServer")
         if silent:
-            self.logger.disabled = True
+            logger.setLevel(logging.WARN)

         if use_popen:
             cmd = [sys.executable,
@@ -360,18 +348,18 @@ class Server(object):
                     raise sock_err
             if not self.port:
                 raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
-            self.logger.info("bind to %s:%d", host, self.port)
+            logger.info("bind to %s:%d", host, self.port)
             sock.listen(1)
             self.sock = sock
             self.proc = multiprocessing.Process(
                 target=_listen_loop, args=(
                     self.sock, self.port, key, tracker_addr, load_library,
-                    self.custom_addr, silent))
+                    self.custom_addr))
             self.proc.deamon = True
             self.proc.start()
         else:
             self.proc = multiprocessing.Process(
-                target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
+                target=_connect_proxy_loop, args=((host, port), key, load_library))
             self.proc.deamon = True
             self.proc.start()
...
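With the shared module-level logger, silent=True now raises the 'RPCServer' logger to WARN instead of disabling a per-instance logger. A hedged sketch of starting a server that registers itself with a tracker (the host, port, and key are placeholders, and the import path is an assumption):

    from tvm.rpc.server import Server  # assumed import path

    server = Server('0.0.0.0', port=9090, key='rk3399',
                    tracker_addr=('localhost', 9190), silent=False)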
@@ -23,6 +23,8 @@ List of available APIs:
 - input: [TrackerCode.REQUEST, [key, user, priority]]
 - return: [TrackerCode.SUCCESS, [url, port, match-key]]
 """
+# pylint: disable=invalid-name
+
 import heapq
 import time
 import logging
@@ -37,12 +39,13 @@ try:
     from . import tornado_util
 except ImportError as error_msg:
     raise ImportError(
-        "RPCTracker module requires tornado package %s" % error_msg)
+        "RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)

 from .._ffi.base import py_str
 from . import base
 from .base import RPC_TRACKER_MAGIC, TrackerCode

+logger = logging.getLogger("RPCTracker")

 class Scheduler(object):
     """Abstract interface of scheduler."""
@@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
     def _init_conn(self, message):
         """Initialize the connection"""
         if len(message) != 4:
-            logging.info("Invalid connection from %s", self.name())
+            logger.warning("Invalid connection from %s", self.name())
             self.close()
         magic = struct.unpack('<i', message)[0]
         if magic != RPC_TRACKER_MAGIC:
-            logging.info("Invalid magic from %s", self.name())
+            logger.warning("Invalid magic from %s", self.name())
             self.close()
         self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
         self._init_req_nbytes = 0
@@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
             status = self._tracker.summary()
             self.ret_value([TrackerCode.SUCCESS, status])
         else:
-            logging.info("Unknown tracker code %d", code)
+            logger.warning("Unknown tracker code %d", code)
             self.close()

     def on_close(self):
         self._tracker._connections.remove(self)

     def on_error(self, err):
-        logging.info("%s: Error in RPC Tracker: %s", self.name(), err)
+        logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
         self.close()
@@ -335,9 +338,8 @@ class Tracker(object):
                  port=9190,
                  port_end=9199,
                  silent=False):
-        self.logger = logging.getLogger("RPCTracker")
         if silent:
-            self.logger.disabled = True
+            logger.setLevel(logging.WARN)

         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self.port = None
@@ -354,7 +356,7 @@ class Tracker(object):
                 raise sock_err
         if not self.port:
             raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
-        self.logger.info("bind to %s:%d", host, self.port)
+        logger.info("bind to %s:%d", host, self.port)
         sock.listen(1)
         self.proc = multiprocessing.Process(
             target=_tracker_server, args=(sock, self.stop_key))
@@ -380,7 +382,7 @@ class Tracker(object):
             self._stop_tracker()
             self.proc.join(1)
             if self.proc.is_alive():
-                self.logger.info("Terminating Tracker Server...")
+                logger.info("Terminating Tracker Server...")
                 self.proc.terminate()
             self.proc = None
...
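The tracker follows the same pattern; a sketch of starting one and inspecting its queue (the port is a placeholder, the import path is an assumption, and the query command mirrors the one quoted in the measure_methods error message above):

    from tvm.rpc.tracker import Tracker  # assumed import path

    tracker = Tracker('0.0.0.0', port=9190)
    # then, from a shell, check the queue status:
    #   python -m tvm.exec.query_rpc_tracker --port 9190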
@@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
 # for this template

 # logging config (for printing tuning log to screen)
-logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
+logging.getLogger('autotvm').setLevel(logging.DEBUG)
+logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

 # the last layer in resnet
 N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
...
@@ -163,8 +163,10 @@ def get_network(name, batch_size):
 # Set Tuning Options
 # ------------------
 # Before tuning, we should do some configurations. Here I use an RK3399 board
-# in our environment as example. In your setting, you should modify the target
-# and device_key accordingly.
+# as example. In your setting, you should modify the target and device_key accordingly.
+# Set :code:`use_android` to True if you use an android phone.
+
+#### DEVICE CONFIG ####

 # Replace "aarch64-linux-gnu" with the correct target of your board.
 # This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
@@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
 # Also replace this with the device key in your tracker
 device_key = 'rk3399'

-# tuning option
+# Set this to True if you use an android phone
+use_android = False
+
+#### TUNING OPTION ####
 network = 'resnet-18'
 log_file = "%s.%s.log" % (device_key, network)
 dtype = 'float32'
@@ -181,17 +186,17 @@ dtype = 'float32'
 tuning_option = {
     'log_filename': log_file,

-    'tuner':'xgb',
+    'tuner': 'xgb',
     'n_trial': 1000,
-    'early_stopping': 200,
+    'early_stopping': 250,

     'measure_option': autotvm.measure_option(
         autotvm.use_rpc(device_key, host='localhost', port=9190),
         number=4,
         parallel_num=1,
-        timeout=10),
-
+        timeout=10,
+        build_func='ndk' if use_android else 'default',
+    ),
     'use_transfer_learning': True,
 }
 ####################################################################
@@ -208,9 +213,6 @@ tuning_option = {
 # If your device is very slow or a single conv2d operator in your network has large FLOPs,
 # consider setting timeout larger.
 #
-# **For android phone**, add :code:`build_func='ndk'` to the argument list of
-# :code:`autotvm.measure_option` to use Android NDK for creating shared library.
-#

 ###################################################################
 # Begin Tuning
# Begin Tuning # Begin Tuning
...@@ -280,12 +282,14 @@ def tune_tasks(tasks, ...@@ -280,12 +282,14 @@ def tune_tasks(tasks,
def tune_and_evaluate(): def tune_and_evaluate():
# extract workloads from nnvm graph # extract workloads from nnvm graph
print("Extract tasks...")
net, params, shape, out_shape = get_network(network, batch_size=1) net, params, shape, out_shape = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_graph(net, shape=shape, dtype=dtype, tasks = autotvm.task.extract_from_graph(net, shape=shape, dtype=dtype,
symbols=(nnvm.sym.conv2d,), symbols=(nnvm.sym.conv2d,),
target=target) target=target)
# run tuning tasks # run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_option) tune_tasks(tasks, **tuning_option)
# compile kernels with history best records # compile kernels with history best records
...@@ -325,10 +329,11 @@ def tune_and_evaluate(): ...@@ -325,10 +329,11 @@ def tune_and_evaluate():
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10) ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" % print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res))) (np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long. # We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself. # Uncomment the following line to run by yourself.
# tune_and_evaluate() # tune_and_evaluate()
 ######################################################################
@@ -341,6 +346,8 @@ def tune_and_evaluate():
 #
 # .. code-block:: bash
 #
+#    Extract tasks...
+#    Tuning...
 #    [Task  1/16]  Current/Best:   13.15/  20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
 #    [Task  2/16]  Current/Best:   16.66/  22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
 #    [Task  3/16]  Current/Best:   10.33/  14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
@@ -362,3 +369,23 @@ def tune_and_evaluate():
 #    Evaluate inference time cost...
 #    Mean inference time (std dev): 156.51 ms (0.89 ms)
 #
+######################################################################
+#
+# .. note:: **Meet some problems?**
+#
+#    The auto tuning module is error prone. If you always see " 0.00/ 0.00 GFLOPS",
+#    then there must be something wrong.
+#
+#    First, make sure you set the correct configuration of your device.
+#    Then, you can print debug information by adding these lines in the beginning
+#    of the script. It will print every measurement result, where you can find useful
+#    error messages.
+#
+#    .. code-block:: python
+#
+#       import logging
+#       logging.getLogger('autotvm').setLevel(logging.DEBUG)
+#
+#    Finally, always feel free to ask our community for help on https://discuss.tvm.ai
@@ -267,8 +267,9 @@ print(task.config_space)
 # We will log the tuning results into a log file. This file can be
 # used to get the best config later.

-# logging config (for printing tuning log to screen)
-logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
+# logging config (for printing tuning log to the screen)
+logging.getLogger('autotvm').setLevel(logging.DEBUG)
+logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

 # use local cpu, measure 5 times for every config to reduce variance
 measure_option = autotvm.measure_option('local',
...