Commit 48ff777a by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] API change (#1583)

parent 48fc410e
......@@ -22,7 +22,7 @@ from . import env
from . import tophub
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, use_rpc
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
ApplyHistoryBest as apply_history_best
......
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc
from .local_executor import LocalExecutor
from .executor import Future, Executor
......@@ -49,7 +49,7 @@ def measure_option(measure_func,
number=1,
repeat=1,
timeout=60,
parallel_num=1,
n_parallel=1,
do_fork=True,
build_func='default',
check_correctness=False,
......@@ -63,7 +63,7 @@ def measure_option(measure_func,
and a RPC server silently for the user.
callable: It is a callable function for measurement.
See the return value of measure/measure_methods.py::use_rpc for example.
See the return value of measure/measure_methods.py::rpc for example.
number : int, optional
Number of times to do the measurement for average
repeat : int, optional
......@@ -74,7 +74,7 @@ def measure_option(measure_func,
timeout: int, optional
Timeout for a whole batch. TimeoutError will be returned as the result if a
task timeouts.
parallel_num: int, optional
n_parallel: int, optional
The number of measurement task that can run in parallel.
Set this according to the number of cpu cores (for compilation) and
the number of devices you have (for measuring generate code).
......@@ -106,7 +106,7 @@ def measure_option(measure_func,
and handle the logic of measurement.
Signature:
* measure_func (see the return value of measure/measure_methods.py::use_rpc for example)
* measure_func (see the return value of measure/measure_methods.py::rpc for example)
def measure_func(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
return measure_results
......@@ -119,7 +119,7 @@ def measure_option(measure_func,
'number': number,
'repeat': repeat,
'timeout': timeout,
'parallel_num': parallel_num,
'n_parallel': n_parallel,
'do_fork': do_fork,
'build_func': build_func,
'check_correctness': check_correctness,
......
......@@ -13,8 +13,8 @@ import threading
import numpy as np
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func, \
target as _target
from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \
target as _target, rpc as _rpc
from ...contrib import nvcc, util, ndk
from ..util import get_const_tuple
......@@ -60,7 +60,7 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
host = os.environ['TVM_TRACKER_HOST']
port = int(os.environ['TVM_TRACKER_PORT'])
tracker = rpc.connect_tracker(host, port)
tracker = _rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority,
session_timeout=timeout)
return remote
......@@ -113,7 +113,7 @@ def create_measure_batch(task, option):
measure_func = option['measure_func']
number, repeat = option['number'], option['repeat']
timeout, parallel_num, do_fork = option['timeout'], option['parallel_num'], option['do_fork']
timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork']
build_func = option['build_func']
check_correctness = option['check_correctness']
replay_db = option['replay_db']
......@@ -134,7 +134,7 @@ def create_measure_batch(task, option):
use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port))
measure_func = use_rpc(device_key, tracker.host, tracker.port)
measure_func = rpc(device_key, tracker.host, tracker.port)
attach_objects = (server, tracker)
build_kwargs = {}
......@@ -218,18 +218,18 @@ def create_measure_batch(task, option):
return partial_results
return results
measure_batch.parallel_num = parallel_num
measure_batch.n_parallel = n_parallel
# attach server and tracker object to avoid them of being garbage-collected
measure_batch.attach_objects = attach_objects
return measure_batch
def use_rpc(key,
host=None,
port=None,
priority=1,
session_timeout=60,
pack_size=1):
def rpc(key,
host=None,
port=None,
priority=1,
session_timeout=60,
pack_size=1):
"""
Create a standard measure_func which uses RPC Tracker for measurement.
This measure_func will request a device from the RPC Tracker and
......
......@@ -85,7 +85,7 @@ class Tuner(object):
every measurement pair. See autotvm/tuner/callback.py for some examples.
"""
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
n_parallel = getattr(measure_batch, 'n_parallel', 1)
early_stopping = early_stopping or 1e9
old_level = logger.level
......@@ -95,7 +95,7 @@ class Tuner(object):
if not self.has_next():
break
configs = self.next_batch(min(parallel_num, n_trial - i))
configs = self.next_batch(min(n_parallel, n_trial - i))
inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
results = measure_batch(inputs)
......
......@@ -168,7 +168,7 @@ print(task.config_space)
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option('local',
number=5,
parallel_num=8,
n_parallel=8,
timeout=20)
# begin tuning, log records to file `conv2d.log`
......
......@@ -191,9 +191,9 @@ tuning_option = {
'early_stopping': 250,
'measure_option': autotvm.measure_option(
autotvm.use_rpc(device_key, host='localhost', port=9190),
autotvm.measure.rpc(device_key, host='localhost', port=9190),
number=4,
parallel_num=1,
n_parallel=1,
timeout=10,
build_func='ndk' if use_android else 'default',
),
......@@ -205,7 +205,7 @@ tuning_option = {
#
# In general, the default value provided here works well. It is the same
# value that we used to generate pre-tuned parameters.
# If you have multiple devices, you can set :code:`parallel_num` to
# If you have multiple devices, you can set :code:`n_parallel` to
# the number of devices you have. (e.g. set it to 3 if you register 3 rk3399
# boards to the tracker).
# If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment