Commit 48ff777a by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] API change (#1583)

parent 48fc410e
...@@ -22,7 +22,7 @@ from . import env ...@@ -22,7 +22,7 @@ from . import env
from . import tophub from . import tophub
# some shortcuts # some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, use_rpc from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \ from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
ApplyHistoryBest as apply_history_best ApplyHistoryBest as apply_history_best
......
"""Distributed executor infrastructure to scale up the tuning""" """Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc from .measure_methods import request_remote, check_remote, create_measure_batch, rpc
from .local_executor import LocalExecutor from .local_executor import LocalExecutor
from .executor import Future, Executor from .executor import Future, Executor
...@@ -49,7 +49,7 @@ def measure_option(measure_func, ...@@ -49,7 +49,7 @@ def measure_option(measure_func,
number=1, number=1,
repeat=1, repeat=1,
timeout=60, timeout=60,
parallel_num=1, n_parallel=1,
do_fork=True, do_fork=True,
build_func='default', build_func='default',
check_correctness=False, check_correctness=False,
...@@ -63,7 +63,7 @@ def measure_option(measure_func, ...@@ -63,7 +63,7 @@ def measure_option(measure_func,
and a RPC server silently for the user. and a RPC server silently for the user.
callable: It is a callable function for measurement. callable: It is a callable function for measurement.
See the return value of measure/measure_methods.py::use_rpc for example. See the return value of measure/measure_methods.py::rpc for example.
number : int, optional number : int, optional
Number of times to do the measurement for average Number of times to do the measurement for average
repeat : int, optional repeat : int, optional
...@@ -74,7 +74,7 @@ def measure_option(measure_func, ...@@ -74,7 +74,7 @@ def measure_option(measure_func,
timeout: int, optional timeout: int, optional
Timeout for a whole batch. TimeoutError will be returned as the result if a Timeout for a whole batch. TimeoutError will be returned as the result if a
task timeouts. task timeouts.
parallel_num: int, optional n_parallel: int, optional
The number of measurement task that can run in parallel. The number of measurement task that can run in parallel.
Set this according to the number of cpu cores (for compilation) and Set this according to the number of cpu cores (for compilation) and
the number of devices you have (for measuring generate code). the number of devices you have (for measuring generate code).
...@@ -106,7 +106,7 @@ def measure_option(measure_func, ...@@ -106,7 +106,7 @@ def measure_option(measure_func,
and handle the logic of measurement. and handle the logic of measurement.
Signature: Signature:
* measure_func (see the return value of measure/measure_methods.py::use_rpc for example) * measure_func (see the return value of measure/measure_methods.py::rpc for example)
def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output): def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
return measure_results return measure_results
...@@ -119,7 +119,7 @@ def measure_option(measure_func, ...@@ -119,7 +119,7 @@ def measure_option(measure_func,
'number': number, 'number': number,
'repeat': repeat, 'repeat': repeat,
'timeout': timeout, 'timeout': timeout,
'parallel_num': parallel_num, 'n_parallel': n_parallel,
'do_fork': do_fork, 'do_fork': do_fork,
'build_func': build_func, 'build_func': build_func,
'check_correctness': check_correctness, 'check_correctness': check_correctness,
......
...@@ -13,8 +13,8 @@ import threading ...@@ -13,8 +13,8 @@ import threading
import numpy as np import numpy as np
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func, \ from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \
target as _target target as _target, rpc as _rpc
from ...contrib import nvcc, util, ndk from ...contrib import nvcc, util, ndk
from ..util import get_const_tuple from ..util import get_const_tuple
...@@ -60,7 +60,7 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60): ...@@ -60,7 +60,7 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
host = os.environ['TVM_TRACKER_HOST'] host = os.environ['TVM_TRACKER_HOST']
port = int(os.environ['TVM_TRACKER_PORT']) port = int(os.environ['TVM_TRACKER_PORT'])
tracker = rpc.connect_tracker(host, port) tracker = _rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority, remote = tracker.request(device_key, priority=priority,
session_timeout=timeout) session_timeout=timeout)
return remote return remote
...@@ -113,7 +113,7 @@ def create_measure_batch(task, option): ...@@ -113,7 +113,7 @@ def create_measure_batch(task, option):
measure_func = option['measure_func'] measure_func = option['measure_func']
number, repeat = option['number'], option['repeat'] number, repeat = option['number'], option['repeat']
timeout, parallel_num, do_fork = option['timeout'], option['parallel_num'], option['do_fork'] timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork']
build_func = option['build_func'] build_func = option['build_func']
check_correctness = option['check_correctness'] check_correctness = option['check_correctness']
replay_db = option['replay_db'] replay_db = option['replay_db']
...@@ -134,7 +134,7 @@ def create_measure_batch(task, option): ...@@ -134,7 +134,7 @@ def create_measure_batch(task, option):
use_popen=True, silent=True, use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port)) tracker_addr=(tracker.host, tracker.port))
measure_func = use_rpc(device_key, tracker.host, tracker.port) measure_func = rpc(device_key, tracker.host, tracker.port)
attach_objects = (server, tracker) attach_objects = (server, tracker)
build_kwargs = {} build_kwargs = {}
...@@ -218,13 +218,13 @@ def create_measure_batch(task, option): ...@@ -218,13 +218,13 @@ def create_measure_batch(task, option):
return partial_results return partial_results
return results return results
measure_batch.parallel_num = parallel_num measure_batch.n_parallel = n_parallel
# attach server and tracker object to avoid them of being garbage-collected # attach server and tracker object to avoid them of being garbage-collected
measure_batch.attach_objects = attach_objects measure_batch.attach_objects = attach_objects
return measure_batch return measure_batch
def use_rpc(key, def rpc(key,
host=None, host=None,
port=None, port=None,
priority=1, priority=1,
......
...@@ -85,7 +85,7 @@ class Tuner(object): ...@@ -85,7 +85,7 @@ class Tuner(object):
every measurement pair. See autotvm/tuner/callback.py for some examples. every measurement pair. See autotvm/tuner/callback.py for some examples.
""" """
measure_batch = create_measure_batch(self.task, measure_option) measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1) n_parallel = getattr(measure_batch, 'n_parallel', 1)
early_stopping = early_stopping or 1e9 early_stopping = early_stopping or 1e9
old_level = logger.level old_level = logger.level
...@@ -95,7 +95,7 @@ class Tuner(object): ...@@ -95,7 +95,7 @@ class Tuner(object):
if not self.has_next(): if not self.has_next():
break break
configs = self.next_batch(min(parallel_num, n_trial - i)) configs = self.next_batch(min(n_parallel, n_trial - i))
inputs = [MeasureInput(self.task.target, self.task, config) for config in configs] inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
results = measure_batch(inputs) results = measure_batch(inputs)
......
...@@ -168,7 +168,7 @@ print(task.config_space) ...@@ -168,7 +168,7 @@ print(task.config_space)
# run 8 parallel threads for compilation # run 8 parallel threads for compilation
measure_option = autotvm.measure_option('local', measure_option = autotvm.measure_option('local',
number=5, number=5,
parallel_num=8, n_parallel=8,
timeout=20) timeout=20)
# begin tuning, log records to file `conv2d.log` # begin tuning, log records to file `conv2d.log`
......
...@@ -191,9 +191,9 @@ tuning_option = { ...@@ -191,9 +191,9 @@ tuning_option = {
'early_stopping': 250, 'early_stopping': 250,
'measure_option': autotvm.measure_option( 'measure_option': autotvm.measure_option(
autotvm.use_rpc(device_key, host='localhost', port=9190), autotvm.measure.rpc(device_key, host='localhost', port=9190),
number=4, number=4,
parallel_num=1, n_parallel=1,
timeout=10, timeout=10,
build_func='ndk' if use_android else 'default', build_func='ndk' if use_android else 'default',
), ),
...@@ -205,7 +205,7 @@ tuning_option = { ...@@ -205,7 +205,7 @@ tuning_option = {
# #
# In general, the default value provided here works well. It is the same # In general, the default value provided here works well. It is the same
# value that we used to generate pre-tuned parameters. # value that we used to generate pre-tuned parameters.
# If you have multiple devices, you can set :code:`parallel_num` to # If you have multiple devices, you can set :code:`n_parallel` to
# the number of devices you have. (e.g. set it to 3 if you register 3 rk3399 # the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
# boards to the tracker). # boards to the tracker).
# If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger, # If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment