Commit 12839e6d by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Decouple build and run in measurement (#1661)

parent 38203a86
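This change replaces the single measure_func argument of measure_option with an explicit Builder (compilation) and Runner (execution). A minimal sketch of the new API, using only parameter values that appear in the tests and tutorials below, with defaults assumed for everything unspecified:

from tvm import autotvm

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),     # compiles candidate configs locally, on all cores
    runner=autotvm.LocalRunner(number=5, repeat=3, min_repeat_ms=100, timeout=4))

# tuners consume it exactly as before:
# tuner.tune(n_trial=20, measure_option=measure_option)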
......@@ -16,6 +16,11 @@ tvm.autotvm.measure
.. autofunction:: tvm.autotvm.measure.create_measure_batch
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner
tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~
......
......@@ -22,7 +22,8 @@ from . import env
from . import tophub
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
register_topi_compute, register_topi_schedule, \
......
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
create_measure_batch
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .executor import Executor
from .local_executor import LocalExecutor
from .executor import Future, Executor
......@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
res = exc
queue.put(res)
def timeout_monitor(queue, timeout, func, args, kwargs):
def call_with_timeout(queue, timeout, func, args, kwargs):
"""A wrapper to support timeout of a function call"""
# start a new process for timeout (cannot use a thread because we call C functions)
......@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
p.start()
p.join(timeout=timeout)
alive = p.is_alive()
queue.put(executor.TimeoutError())
kill_child_processes(p.pid)
p.terminate()
p.join()
if alive:
queue.put(executor.TimeoutError())
else:
if queue.empty():
queue.put(executor.ExecutionError("Fatal error in local executor"))
class LocalFuture(executor.Future):
"""Local wrapper for the future
......@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(2)
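# Queue(2): one slot for the worker's real result, one for the timeout sentinel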
process = Process(target=timeout_monitor,
process = Process(target=call_with_timeout,
args=(queue, self.timeout, func, args, kwargs))
process.start()
return LocalFuture(process, queue)
......@@ -22,7 +22,7 @@ class GATuner(Tuner):
mutation_prob: float
probability of mutation of a knob in a gene
"""
def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
super(GATuner, self).__init__(task)
# algorithm configurations
......
......@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
new_scores = model.predict(new_points)
ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
ac_index = np.random.random(len(ac_prob)) < ac_prob
points[ac_index] = new_points[ac_index]
......
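The one-line fix above clamps the simulated-annealing acceptance exponent: with the old expression, a strongly improved score at low temperature overflows np.exp to inf, while the new form caps the exponent at 1 (probability at most e, so improvements are still always accepted) and leaves the downhill case effectively unchanged. A quick check with made-up scores:

import numpy as np

t = 1e-4                                  # a low late-stage temperature
delta = np.array([-5.0, 0.0, 50.0])       # new_scores - scores

with np.errstate(over='ignore'):
    print(np.exp(delta / (t + 1e-2)))     # old form: last entry overflows to inf

print(np.exp(np.minimum(delta / (t + 1e-5), 1)))   # new form: finite, uphill moves capped at e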
......@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
target=target, target_host=target_host)
return task, target
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
def custom_measure(input_pack, build_func, build_args, number, repeat,
ref_input, ref_output):
from tvm.autotvm import MeasureResult
results = []
for inp in input_pack:
tic = time.time()
# do nothing
time.sleep(0.001)
results.append(MeasureResult([time.time() - tic], 0,
time.time() - tic, time.time()))
return results
measure_option = autotvm.measure_option(custom_measure)
logging.info("%s", task.config_space)
# new tuner and recorder
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_tuning_with_measure():
def test_tuning():
def check(target, target_host):
ctx = tvm.context(target, 0)
if not ctx.exist:
......@@ -141,12 +114,12 @@ def test_tuning_with_measure():
task, target = get_sample_task(target, target_host)
logging.info("%s", task.config_space)
measure_option = autotvm.measure_option('local',
timeout=4,
number=2)
measure_option = autotvm.measure_option(
autotvm.LocalBuilder(),
autotvm.LocalRunner())
tuner = RandomTuner(task)
tuner.tune(n_trial=10, measure_option=measure_option)
tuner.tune(n_trial=20, measure_option=measure_option)
check("cuda", None)
check("opencl", None)
......@@ -155,6 +128,4 @@ if __name__ == "__main__":
# only print log when invoked from main
logging.basicConfig(level=logging.DEBUG)
test_task_tuner_without_measurement()
test_tuning_with_measure()
test_tuning()
......@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
return s, [A, B, C]
@autotvm.template
def bad_matmul(N, L, M, dtype):
if 'bad_device' in tvm.target.current_target().keys:
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L-1), name='k')
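# intentional bug: k runs from 0 to L-2, so the last term of the reduction is dropped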
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
return s, [A, B, C]
return matmul(N, L, M, dtype)
def get_sample_task(n=128):
"""return a sample task for testing"""
target = tvm.target.create("llvm")
......
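The off-by-one planted in bad_matmul above is what test_check_correctness (in the new test_autotvm_measure.py below) relies on; a quick NumPy illustration of the same mistake:

import numpy as np

N = L = M = 8
a = np.random.rand(N, L)
b = np.random.rand(L, M)

full = a.dot(b)                    # correct: sums k = 0 .. L-1
bad = a[:, :L-1].dot(b[:L-1, :])   # like reduce_axis((0, L-1)): sums k = 0 .. L-2
assert not np.allclose(full, bad)  # wrong answers, as the WRONG_ANSWER test expects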
"""Test database"""
import copy
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from tvm.autotvm.record import encode, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records
from test_autotvm_common import get_sample_records
def test_save_load():
logging.info("test basic db load/save ...")
......@@ -35,66 +29,6 @@ def test_save_load():
TRIAL_LIMIT = 2
def test_db_filter():
logging.info("test db filter ...")
# Pick a GPU target because there are more likely to be failures/invalid configs
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
batch_size = 2
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
del measure_batch
db = database.DummyDatabase()
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
db.flush()
for j in range(i):
db.save(all_inputs[j], all_results[j])
for k in range(len(batches)):
batch = batches[k]
batch_result = measure_batch(batch)
for l in range(batch_size):
all_idx = k*batch_size + l
assert batch_result[l] is not None
if all_idx < i:
assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MATCH, GOT MISMATCH"
else:
assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
"(no retry) EXPECTED MISMATCH, GOT MATCH"
del measure_batch
def test_db_hash():
logging.info("test db hash check ...")
inp1, res1 = get_sample_records(1)[0]
......@@ -149,89 +83,8 @@ def test_db_latest_all():
assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_save_replay():
logging.info("test db save (from measure_batch) and replay ...")
_db = database.DummyDatabase()
_db.flush()
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local',
do_fork=False,
timeout=2,
replay_db=_db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
ct = 0
all_inputs = list()
all_results = list()
batches = list()
tuner = autotvm.tuner.RandomTuner(task)
while ct < TRIAL_LIMIT:
inputs = list()
for i in range(batch_size):
cfg = tuner.next_batch(1)[0]
inputs.append((MeasureInput(target, task, cfg)))
all_inputs.append(inputs[-1])
batches.append(inputs)
results = measure_batch(inputs)
all_results += results
ct += 1
callback = autotvm.callback.log_to_database(_db)
callback(None, all_inputs, all_results)
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
all_results_2 = measure_batch(all_inputs)
all_results_3 = measure_batch(all_inputs)
for i in range(len(all_results)):
encr1 = encode(all_inputs[i], all_results[i])
encr2 = encode(all_inputs[i], all_results_2[i])
encr3 = encode(all_inputs[i], all_results_3[i])
assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
del measure_batch
def test_check_hashmismatch():
logging.info("test hash mismatch check")
task, target = get_sample_task()
ctx = tvm.context(str(target))
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option('local', do_fork=False)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
cfg = task.config_space.get(np.random.randint(len(task.config_space)))
# notvalidh is not a valid CRC32 hash (not hex)
cfg.code_hash = 'notvalidh'
inputs.append((MeasureInput(target, task, cfg)))
try:
results = measure_batch(inputs)
assert False, "HashMismatchError should be raised"
except HashMismatchError:
pass
del measure_batch
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_save_load()
test_db_filter()
test_db_hash()
test_db_latest_all()
test_db_save_replay()
test_check_hashmismatch()
"""Test builder and runner"""
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
class DummyRunner(Runner):
def __init__(self):
super(DummyRunner, self).__init__(1, 1)
def run(self, measure_inputs, build_results):
return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
for _ in range(len(measure_inputs))]
def get_build_kwargs(self):
return {}
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=DummyRunner()
)
logging.info("%s", task.config_space)
for tuner_class in [autotvm.tuner.RandomTuner,
autotvm.tuner.GridSearchTuner,
autotvm.tuner.GATuner,
autotvm.tuner.XGBTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_check_correctness():
task, target = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(check_correctness=True)
)
def _callback_correct(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == 0
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_correct])
# a bad template
n = 128
target = tvm.target.create("llvm -device=bad_device")
task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
def _callback_wrong(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
assert res.error_no == MeasureErrorNo.WRONG_ANSWER
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=2, measure_option=measure_option,
callbacks=[_callback_wrong])
def test_min_repeat_ms():
task, target = get_sample_task()
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
)
def _callback(tuner, measure_inputs, measure_results):
for inp, res in zip(measure_inputs, measure_results):
if res.error_no != 0:
continue
assert 1000 * np.mean(res.costs) * \
measure_option['runner'].cur_number >= 100
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=5, measure_option=measure_option,
callbacks=[_callback])
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
test_task_tuner_without_measurement()
test_check_correctness()
test_min_repeat_ms()
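The assertion in the _callback above encodes the contract of min_repeat_ms: the runner raises number (exposed as cur_number) until the runs of one timed measurement last at least min_repeat_ms in total. A worked instance with hypothetical figures:

min_repeat_ms = 100
mean_cost = 0.0005        # hypothetical: one kernel run costs 0.5 ms
cur_number = 200          # hypothetical number the runner settles on

# mirrors the assert above: 1000 * 0.0005 * 200 = 100.0 ms >= 100 ms
assert 1000 * mean_cost * cur_number >= min_repeat_ms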
......@@ -137,12 +137,15 @@ if __name__ == '__main__':
print(task.config_space)
measure_option = autotvm.measure_option(
measure_func='local', number=10, n_parallel=8, timeout=20)
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)
log_name = 'gemm_int8.log'
if DO_TUNING:
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=1000, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file(log_name)])
callbacks=[autotvm.callback.log_to_file(log_name)])
dispatch_context = autotvm.apply_history_best(log_name)
best_config = dispatch_context.query(task.target, task.workload)
......
......@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
target='cuda')
print(task.config_space)
# use local gpu, measure 5 times for every config to reduce variance
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option('local',
number=5,
n_parallel=8,
timeout=20)
# use local gpu, measure 10 times for every config to reduce variance
# The timeout for compiling a program is 10 seconds; the timeout for running is 4 seconds
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)
# begin tuning, log records to file `conv2d.log`
tuner = autotvm.tuner.XGBTuner(task)
......
......@@ -271,9 +271,12 @@ print(task.config_space)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option('local',
number=5)
# There are two steps in measuring a config: build and run.
# By default, we use all CPU cores to compile the programs, then measure them sequentially.
# We measure 5 times and take the average to reduce variance.
measure_option = autotvm.measure_option(
builder='local',
runner=autotvm.LocalRunner(number=5))
# begin tuning, log records to file `matmul.log`
tuner = autotvm.tuner.RandomTuner(task)
......
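Note that the tutorials spell the builder two ways: the string 'local' (above) and an explicit autotvm.LocalBuilder() (earlier). Assuming the string is shorthand for a default-configured LocalBuilder, the two forms below should be equivalent:

measure_option = autotvm.measure_option(builder='local',
                                        runner=autotvm.LocalRunner(number=5))
measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(),
                                        runner=autotvm.LocalRunner(number=5))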