Commit 136061dc by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] Improve tutorial and logging (#1544)

parent 33606741
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, create_measure_batch, use_rpc
from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc
from .local_executor import LocalExecutor
from .executor import Future, Executor
......@@ -9,6 +9,7 @@ import logging
import os
import time
from random import getrandbits
import threading
import numpy as np
......@@ -23,6 +24,7 @@ from ..task.space import InstantiationError
from .measure import MeasureResult, MeasureErrorNo
from .local_executor import LocalExecutor
logger = logging.getLogger('autotvm')
class HashMismatchError(ValueError):
"""Raised when the code hash of a submitted config doesn't match that on the
......@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
If None, the environment variables "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT" are used.
priority: int, optional
priority of this request, larger is more prior
The priority of this request; a larger value means higher priority
timeout: float, optional
timeout of this session (units: seconds)
The timeout of this session (units: seconds)
Returns
-------
......@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
session_timeout=timeout)
return remote
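# A usage sketch (hypothetical device key and tracker address; assumes a
# tracker is running locally):
#
#   remote = request_remote('rk3399', ('localhost', 9190), priority=1, timeout=60)
#   ctx = remote.context('llvm')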
def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
"""
Check the availability of a remote device
Parameters
----------
target: Target
The wanted compilation target
device_key: string
The device key of a registered device in the tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format.
If None, the environment variables "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT" are used.
priority: int, optional
The priority of this request; a larger value means higher priority
timeout: float, optional
The timeout of this check (units: seconds).
If the check does not finish within the timeout, the device is treated as unavailable.
Returns
-------
available: bool
True if a remote device is available, False otherwise.
"""
def _check():
remote = request_remote(device_key, tracker_addr, priority)
remote.context(str(target))
t = threading.Thread(target=_check)
t.start()
t.join(timeout)
return not t.is_alive()
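# A usage sketch of the check above (hypothetical names; assumes the same
# local tracker). check_remote returns False if no device answers in time:
#
#   import tvm
#   tgt = tvm.target.create('llvm')
#   if not check_remote(tgt, 'rk3399', ('localhost', 9190), timeout=10):
#       print("no free device registered under key 'rk3399'")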
def create_measure_batch(task, option):
"""Get a standard measure_batch function.
......@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
build_func = default_build_func
build_kwargs['use_ndk'] = True
# check the availability of remote devices
if hasattr(measure_func, 'rpc_info'):
rpc_info = measure_func.rpc_info
if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
logger.info("Get devices for measurement successfully!")
else:
raise RuntimeError("Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.")
# add device info of cuda and opencl target
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
and hasattr(measure_func, 'rpc_info'):
......@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((e,),
res_pack.append(MeasureResult((InstantiationError(str(e)),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
......@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if ref_output:
for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
logging.warning("Wrong Answer!")
logger.warning("Wrong Answer!")
errno = MeasureErrorNo.WRONG_ANSWER
except TVMError as exc:
msg = str(exc)
......
......@@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
from .measure import MeasureInput, MeasureResult
AUTOTVM_LOG_VERSION = 0.1
logger = logging.getLogger('autotvm')
try: # convert unicode to str for python2
_unicode = unicode
......@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
tic = time.time()
lines = list(open(in_file).readlines())
logging.info("start converting...")
logger.info("start converting...")
pool = multiprocessing.Pool()
lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic)
logger.info("map done %.2f", time.time() - tic)
wkl_dict = OrderedDict()
for inp, res in lines:
......@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
cleaned.append([inp, res])
# write to file
logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in cleaned:
fout.write(encode(inp, res) + '\n')
else:
for i, (k, v) in enumerate(wkl_dict.items()):
logging.info("Key: %s\tNum: %d", k, len(v))
logger.info("Key: %s\tNum: %d", k, len(v))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in v:
fout.write(encode(inp, res) + '\n')
......@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
for v in best_context.best_by_targetkey.values():
best_set.add(measure_str_key(v[0]))
logging.info("Extract %d best records from the %s", len(best_set), in_file)
logger.info("Extract %d best records from the %s", len(best_set), in_file)
fout = open(out_file, 'w') if isinstance(out_file, str) else out_file
for inp, res in load_from_file(in_file):
......@@ -270,7 +271,7 @@ if __name__ == '__main__':
parser.add_argument("--code", action='store_true')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.mode == 'pick':
args.o = args.o or args.i + ".best.log"
......
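The same pick step is also available as a library call (a sketch; the file
names are hypothetical):

    from tvm.autotvm.record import pick_best
    pick_best("tune.log", "tune.best.log")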
......@@ -10,6 +10,8 @@ of the DispatchContext base class.
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to pick the best policy.
"""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import logging
......@@ -19,6 +21,8 @@ import numpy as np
from tvm import target as _target
logger = logging.getLogger('autotvm')
class DispatchContext(object):
"""
Base class of dispatch context.
......@@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
best_by_model[key] = (inp, res)
break
logging.debug("Finish loading %d records", counter)
logger.debug("Finish loading %d records", counter)
def query(self, target, workload):
if target is None:
......
......@@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific devi
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time.
"""
# pylint: disable=invalid-name
import logging
import os
......@@ -16,6 +17,7 @@ from ..contrib.download import download
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
logger = logging.getLogger('autotvm')
def _alias(name):
"""convert alias for some packages"""
......@@ -79,7 +81,7 @@ def download_package(backend):
os.mkdir(path)
backend = _alias(backend)
logging.info("Download pre-tuned parameters for %s", backend)
logger.info("Download pre-tuned parameters for %s", backend)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
os.path.join(rootpath, backend + ".log"), True, verbose=0)
......@@ -110,7 +112,7 @@ def list_packages():
"""
path = tempdir()
filename = path.relpath("info.json")
logging.info("Download meta info for pre-tuned parameters")
logger.info("Download meta info for pre-tuned parameters")
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
filename, True, verbose=0)
......
......@@ -2,11 +2,13 @@
"""Namespace of callback utilities of AutoTVM"""
import sys
import time
import logging
import numpy as np
from .. import record
logger = logging.getLogger('autotvm')
def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
......@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
prefix: str
The prefix of output message
"""
class _Context:
class _Context(object):
"""Context to store local variables"""
def __init__(self):
self.best_flops = 0
......@@ -112,11 +114,12 @@ def progress_bar(total, prefix=''):
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' %
sys.stdout.write('%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s\r' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
time.time() - tic))
sys.stdout.flush()
......
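A usage sketch of this callback (hypothetical trial count and file name;
assumes a tuner and measure_option are defined as in the tutorials below).
The progress bar is passed to a tuner together with a file logger:

    from tvm import autotvm
    tuner.tune(n_trial=1000,
               measure_option=measure_option,
               callbacks=[autotvm.callback.progress_bar(1000),
                          autotvm.callback.log_to_file('tune.log')])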
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate, invalid-name
"""
Cost model optimizer based on simulated annealing
"""
......@@ -12,6 +12,8 @@ import numpy as np
from ..util import sample_ints
from .model_based_tuner import ModelOptimizer, knob2point, point2knob
logger = logging.getLogger('autotvm')
class SimulatedAnnealingOptimizer(ModelOptimizer):
"""parallel simulated annealing optimization algorithm
......@@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
if log_interval and k % log_interval == 0:
t_str = "%.2f" % t
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
heap_items.sort(key=lambda item: -item[0])
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.debug("SA Maximums: %s", heap_items)
logger.debug("SA Maximums: %s", heap_items)
if self.persistent:
self.points = points
......
......@@ -4,11 +4,12 @@ import logging
import numpy as np
from ..measure import MeasureInput
from ..measure import create_measure_batch
from ..measure import MeasureInput, create_measure_batch
from ..env import GLOBAL_SCOPE
logger = logging.getLogger('autotvm')
class Tuner(object):
"""Base class for tuners
......@@ -86,9 +87,10 @@ class Tuner(object):
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stopping = early_stopping or 1e9
old_level = logger.level
GLOBAL_SCOPE.in_tuning = True
i = 0
i = error_ct = 0
while i < n_trial:
if not self.has_next():
break
......@@ -103,15 +105,18 @@ class Tuner(object):
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
error_ct = 0
else:
flops = 0
error_ct += 1
if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
......@@ -123,11 +128,16 @@ class Tuner(object):
callback(self, inputs, results)
if i > self.best_iter + early_stopping:
logging.debug("Early stopped. Best iter: %d.", self.best_iter)
logger.debug("Early stopped. Best iter: %d.", self.best_iter)
break
GLOBAL_SCOPE.in_tuning = False
if error_ct > 50:
logger.warning("Too many errors happen in the tuning. Now is in debug mode")
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(old_level)
GLOBAL_SCOPE.in_tuning = False
del measure_batch
def reset(self):
......
......@@ -16,6 +16,8 @@ from ..util import get_rank
from .metric import max_curve, recall_curve, cover_curve
from .model_based_tuner import CostModel, FeatureCache
logger = logging.getLogger('autotvm')
class XGBoostCostModel(CostModel):
"""XGBoost as cost model
......@@ -163,7 +165,7 @@ class XGBoostCostModel(CostModel):
],
verbose_eval=self.log_interval)])
logging.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
time.time() - tic, len(xs),
len(xs) - np.sum(valid_index),
self.feature_cache.size(self.fea_type))
......@@ -173,7 +175,7 @@ class XGBoostCostModel(CostModel):
self._reset_pool()
args = list(records)
logging.debug("XGB load %d entries from history log file", len(args))
logger.debug("XGB load %d entries from history log file", len(args))
if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log
......@@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
],
verbose_eval=self.log_interval)])
logging.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
def predict(self, xs, output_margin=False):
feas = self._get_feature(xs)
......@@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
infos.append("%s: %.6f" % (item[0], item[1]))
if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
logging.debug("\t".join(infos))
logger.debug("\t".join(infos))
if log_file:
with open(log_file, "a") as fout:
fout.write("\t".join(infos) + '\n')
......@@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
elif env.iteration - best_iteration >= stopping_rounds:
best_msg = state['best_msg']
if verbose_eval and env.rank == 0:
logging.debug("XGB stopped. Best iteration: %s ", best_msg)
logger.debug("XGB stopped. Best iteration: %s ", best_msg)
raise EarlyStopException(best_iteration)
return callback
......
......@@ -8,6 +8,7 @@ import numpy as np
from .. import expr, ir_pass
logger = logging.getLogger('autotvm')
class EmptyContext(object):
"""An empty context"""
......@@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
tic = time.time()
local_pool = pool or multiprocessing.Pool()
if verbose:
logging.info("mapping begin")
logger.info("mapping begin")
for i in range(0, len(args), batch_size):
if verbose:
logging.info("mapping %d/%d elapsed %.2f", i, len(args),
logger.info("mapping %d/%d elapsed %.2f", i, len(args),
time.time() - tic)
tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
ret = tmp if ret is None else np.concatenate((ret, tmp))
if verbose:
logging.info("mapping done")
logger.info("mapping done")
if not pool:
local_pool.close()
return ret
......
"""Base definitions for RPC."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import socket
......@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot found matched key in server
RPC_CODE_MISMATCH = RPC_MAGIC + 2
logger = logging.getLogger('RPCServer')
class TrackerCode(object):
"""Enumeration code for the RPC tracker"""
......@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
def connect_with_retry(addr, timeout=60, retry_period=5):
"""Connect to a TPC address with retry
This function is only reliable to short period of server restart.
......@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
retry_period : float
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
tstart = time.time()
while True:
......@@ -152,8 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
if period > timeout:
raise RuntimeError(
"Failed to connect to server %s" % str(addr))
if not silent:
logging.info("Cannot connect to tracker%s, retry in %g secs...",
logger.warning("Cannot connect to tracker %s, retry in %g secs...",
str(addr), retry_period)
time.sleep(retry_period)
......
......@@ -23,7 +23,8 @@ try:
from tornado import ioloop
from . import tornado_util
except ImportError as error_msg:
raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
raise ImportError(
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)
from . import base
from .base import TrackerCode
......@@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env(None, None)
temp = _server_env(None)
# Start connection
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key
......
......@@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
- The key is in the format
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
from __future__ import absolute_import
import os
......@@ -30,11 +32,11 @@ from ..contrib import util
from . import base
from . base import TrackerCode
def _server_env(load_library, logger):
logger = logging.getLogger('RPCServer')
def _server_env(load_library):
"""Server environment function return temp dir"""
temp = util.tempdir()
if logger is None:
logger = logging.getLogger()
# pylint: disable=unused-variable
@register_func("tvm.rpc.server.workpath")
......@@ -59,13 +61,10 @@ def _server_env(load_library, logger):
return temp
def _serve_loop(sock, addr, load_library, silent):
def _serve_loop(sock, addr, load_library):
"""Server loop"""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
sockfd = sock.fileno()
temp = _server_env(load_library, logger)
temp = _server_env(load_library)
base._ServerLoop(sockfd)
temp.remove()
logger.info("Finish serving %s", addr)
......@@ -79,12 +78,8 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:])
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Listening loop of the server master."""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
......@@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
if arr[0] != expect_header:
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close()
logger.info("mismatch key from %s", addr)
logger.warning("mismatch key from %s", addr)
continue
else:
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
......@@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
try:
# step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
......@@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
tracker_conn = None
continue
except RuntimeError as exc:
if silent:
return
else:
raise exc
# step 3: serving
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library, silent))
args=(conn, addr, load_library))
server_proc.daemon = True
server_proc.start()
# close from our side.
......@@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
server_proc.terminate()
def _connect_proxy_loop(addr, key, load_library, silent):
logger = logging.getLogger("RPCProxy")
if silent:
logger.disabled = True
def _connect_proxy_loop(addr, key, load_library):
key = "server:" + key
retry_count = 0
max_retry = 5
......@@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
logger.info("RPCProxy do not have matching client key %s", key)
logger.warning("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
......@@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
opts = _parse_server_opt(remote_key.split()[1:])
logger.info("connected to %s", str(addr))
process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr, load_library, silent))
target=_serve_loop, args=(sock, addr, load_library))
process.daemon = True
process.start()
sock.close()
......@@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
if retry_count > max_retry:
raise RuntimeError("Maximum retry error: last error: %s" % str(err))
time.sleep(retry_period)
......@@ -323,9 +312,8 @@ class Server(object):
self.custom_addr = custom_addr
self.use_popen = use_popen
self.logger = logging.getLogger("RPCServer")
if silent:
self.logger.disabled = True
logger.setLevel(logging.WARN)
if use_popen:
cmd = [sys.executable,
......@@ -360,18 +348,18 @@ class Server(object):
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
self.logger.info("bind to %s:%d", host, self.port)
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library,
self.custom_addr, silent))
self.custom_addr))
self.proc.daemon = True
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
target=_connect_proxy_loop, args=((host, port), key, load_library))
self.proc.daemon = True
self.proc.start()
......
......@@ -23,6 +23,8 @@ List of available APIs:
- input: [TrackerCode.REQUEST, [key, user, priority]]
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
# pylint: disable=invalid-name
import heapq
import time
import logging
......@@ -37,12 +39,13 @@ try:
from . import tornado_util
except ImportError as error_msg:
raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg)
"RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)
from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode
logger = logging.getLogger("RPCTracker")
class Scheduler(object):
"""Abstratc interface of scheduler."""
......@@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
def _init_conn(self, message):
"""Initialie the connection"""
if len(message) != 4:
logging.info("Invalid connection from %s", self.name())
logger.warning("Invalid connection from %s", self.name())
self.close()
magic = struct.unpack('<i', message)[0]
if magic != RPC_TRACKER_MAGIC:
logging.info("Invalid magic from %s", self.name())
logger.warning("Invalid magic from %s", self.name())
self.close()
self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
self._init_req_nbytes = 0
......@@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
status = self._tracker.summary()
self.ret_value([TrackerCode.SUCCESS, status])
else:
logging.info("Unknown tracker code %d", code)
logger.warning("Unknown tracker code %d", code)
self.close()
def on_close(self):
self._tracker._connections.remove(self)
def on_error(self, err):
logging.info("%s: Error in RPC Tracker: %s", self.name(), err)
logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
self.close()
......@@ -335,9 +338,8 @@ class Tracker(object):
port=9190,
port_end=9199,
silent=False):
self.logger = logging.getLogger("RPCTracker")
if silent:
self.logger.disabled = True
logger.setLevel(logging.WARN)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
......@@ -354,7 +356,7 @@ class Tracker(object):
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
self.logger.info("bind to %s:%d", host, self.port)
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_tracker_server, args=(sock, self.stop_key))
......@@ -380,7 +382,7 @@ class Tracker(object):
self._stop_tracker()
self.proc.join(1)
if self.proc.is_alive():
self.logger.info("Terminating Tracker Server...")
logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc = None
......
......@@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# for this template
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
......
......@@ -163,8 +163,10 @@ def get_network(name, batch_size):
# Set Tuning Options
# ------------------
# Before tuning, we should do some configurations. Here I use an RK3399 board
# in our environment as example. In your setting, you should modify the target
# and device_key accordingly.
# as an example. In your setting, you should modify the target and device_key accordingly.
# Set :code:`use_android` to True if you use an Android phone.
#### DEVICE CONFIG ####
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
......@@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
# Also replace this with the device key in your tracker
device_key = 'rk3399'
# tuning option
# Set this to True if you use an Android phone
use_android = False
#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'
......@@ -181,17 +186,17 @@ dtype = 'float32'
tuning_option = {
'log_filename': log_file,
'tuner':'xgb',
'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 200,
'early_stopping': 250,
'measure_option': autotvm.measure_option(
autotvm.use_rpc(device_key, host='localhost', port=9190),
number=4,
parallel_num=1,
timeout=10),
'use_transfer_learning': True,
timeout=10,
build_func='ndk' if use_android else 'default',
),
}
####################################################################
......@@ -208,9 +213,6 @@ tuning_option = {
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# consider setting a larger timeout, as sketched below.
#
# **For android phone**, add :code:`build_func='ndk'` to the argument list of
# :code:`autotvm.measure_option` to use Android NDK for creating shared library.
#
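# For example (a sketch reusing the options above; only the timeout changes),
# a more tolerant configuration could be:
#
# .. code-block:: python
#
#     'measure_option': autotvm.measure_option(
#         autotvm.use_rpc(device_key, host='localhost', port=9190),
#         number=4,
#         timeout=40,
#         build_func='ndk' if use_android else 'default',
#     ),
#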
###################################################################
# Begin Tuning
......@@ -280,12 +282,14 @@ def tune_tasks(tasks,
def tune_and_evaluate():
# extract workloads from nnvm graph
print("Extract tasks...")
net, params, shape, out_shape = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_graph(net, shape=shape, dtype=dtype,
symbols=(nnvm.sym.conv2d,),
target=target)
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_option)
# compile kernels with history best records
......@@ -329,6 +333,7 @@ def tune_and_evaluate():
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.
# tune_and_evaluate()
######################################################################
......@@ -341,6 +346,8 @@ def tune_and_evaluate():
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/16] Current/Best: 13.15/ 20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
# [Task 2/16] Current/Best: 16.66/ 22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
# [Task 3/16] Current/Best: 10.33/ 14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
......@@ -362,3 +369,23 @@ def tune_and_evaluate():
# Evaluate inference time cost...
# Mean inference time (std dev): 156.51 ms (0.89 ms)
#
######################################################################
#
# .. note:: **Meet some problems?**
#
# The auto-tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines at the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
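# Note that :code:`setLevel` alone may not print anything if the logger has no
# handler attached; as in the snippet at the top of this script, you can add
# one explicitly:
#
# .. code-block:: python
#
#   import sys
#   logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
#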
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
......@@ -267,8 +267,9 @@ print(task.config_space)
# We will log the tuning results into a log file. This file can be
# used to get the best config later.
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
# logging config (for printing tuning log to the screen)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option('local',
......