Commit 12839e6d by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Decouple build and run in measurement (#1661)
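Summary: this commit splits AutoTVM measurement into two explicit stages, a builder that compiles candidate configurations and a runner that executes and times the compiled code. autotvm.measure_option() now takes builder/runner objects (LocalBuilder plus LocalRunner or RPCRunner) instead of a single measure_func string. A minimal usage sketch assembled only from the API as it appears in this diff; the task construction is elided and marked as an assumption:

    from tvm import autotvm

    # `task` is assumed to come from autotvm.task.create(...), as in the tutorials changed below.
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),                           # compiles candidate kernels
        runner=autotvm.LocalRunner(number=5, min_repeat_ms=100))  # runs and times them locally

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=10, measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('tune.log')])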

parent 38203a86
@@ -16,6 +16,11 @@ tvm.autotvm.measure
 .. autofunction:: tvm.autotvm.measure.create_measure_batch
+.. autoclass:: tvm.autotvm.measure.measure_methods.LocalBuilder
+.. autoclass:: tvm.autotvm.measure.measure_methods.RPCRunner
+.. autoclass:: tvm.autotvm.measure.measure_methods.LocalRunner
 tvm.autotvm.tuner
 ~~~~~~~~~~~~~~~~~
...
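The RPCRunner documented above measures on a remote device through the RPC infrastructure while LocalBuilder still compiles on the host. Its constructor arguments are not shown in this diff, so the device key, tracker address, and parameter names below are assumptions for illustration only:

    from tvm import autotvm

    # Assumed arguments: device key 'rk3399' registered with an RPC tracker at localhost:9190.
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.RPCRunner('rk3399', 'localhost', 9190,
                                 number=4, repeat=3, timeout=10))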
@@ -22,7 +22,8 @@ from . import env
 from . import tophub
 # some shortcuts
-from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
+from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
+    LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
 from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
     register_topi_compute, register_topi_schedule, \
...
"""Distributed executor infrastructure to scale up the tuning""" """Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option, \
from .measure_methods import request_remote, check_remote, create_measure_batch, rpc create_measure_batch
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .executor import Executor
from .local_executor import LocalExecutor from .local_executor import LocalExecutor
from .executor import Future, Executor
@@ -37,7 +37,8 @@ def _execute_func(func, queue, args, kwargs):
         res = exc
     queue.put(res)
-def timeout_monitor(queue, timeout, func, args, kwargs):
+def call_with_timeout(queue, timeout, func, args, kwargs):
     """A wrapper to support timeout of a function call"""
     # start a new process for timeout (cannot use thread because we have c function)
@@ -45,17 +46,12 @@ def timeout_monitor(queue, timeout, func, args, kwargs):
     p.start()
     p.join(timeout=timeout)
-    alive = p.is_alive()
+    queue.put(executor.TimeoutError())
     kill_child_processes(p.pid)
     p.terminate()
     p.join()
-    if alive:
-        queue.put(executor.TimeoutError())
-    else:
-        if queue.empty():
-            queue.put(executor.ExecutionError("Fatal error in local executor"))
 class LocalFuture(executor.Future):
     """Local wrapper for the future
@@ -134,7 +130,7 @@ class LocalExecutor(executor.Executor):
             return LocalFutureNoFork(func(*args, **kwargs))
         queue = Queue(2)
-        process = Process(target=timeout_monitor,
+        process = Process(target=call_with_timeout,
                           args=(queue, self.timeout, func, args, kwargs))
         process.start()
         return LocalFuture(process, queue)
@@ -22,7 +22,7 @@ class GATuner(Tuner):
     mutation_prob: float
         probability of mutation of a knob in a gene
     """
-    def __init__(self, task, pop_size, elite_num=3, mutation_prob=0.1):
+    def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1):
         super(GATuner, self).__init__(task)
         # algorithm configurations
...
@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
             new_scores = model.predict(new_points)
-            ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
+            ac_prob = np.exp(np.minimum((new_scores - scores) / (t + 1e-5), 1))
             ac_index = np.random.random(len(ac_prob)) < ac_prob
             points[ac_index] = new_points[ac_index]
...
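The simulated-annealing change above clamps the exponent before calling np.exp. With the old expression, a candidate whose predicted score is far above the current one at low temperature overflows to inf (NumPy emits a RuntimeWarning), while any exponent of 1 or more already gives an acceptance probability above 1, so clipping at 1 loses nothing. A small illustration with made-up values:

    import numpy as np

    delta = np.array([50.0, -0.5])   # predicted score differences (illustrative)
    t = 1e-3                         # a late, low temperature

    old = np.exp(delta / (t + 1e-2))                   # first entry overflows to inf
    new = np.exp(np.minimum(delta / (t + 1e-5), 1))    # capped at e, still >= 1, always accepted
    print(old, new)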
@@ -103,34 +103,7 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
                                target=target, target_host=target_host)
     return task, target
-def test_task_tuner_without_measurement():
-    """test task and tuner without measurement"""
-    task, target = get_sample_task()
-    def custom_measure(input_pack, build_func, build_args, number, repeat,
-                       ref_input, ref_output):
-        from tvm.autotvm import MeasureResult
-        results = []
-        for inp in input_pack:
-            tic = time.time()
-            # do nothing
-            time.sleep(0.001)
-            results.append(MeasureResult([time.time() - tic], 0,
-                                         time.time() - tic, time.time()))
-        return results
-    measure_option = autotvm.measure_option(custom_measure)
-    logging.info("%s", task.config_space)
-    # new tuner and recorder
-    for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
-        tuner = tuner_class(task)
-        tuner.tune(n_trial=10, measure_option=measure_option)
-        assert tuner.best_flops > 1
-def test_tuning_with_measure():
+def test_tuning():
     def check(target, target_host):
         ctx = tvm.context(target, 0)
         if not ctx.exist:
@@ -141,12 +114,12 @@ def test_tuning_with_measure():
         task, target = get_sample_task(target, target_host)
         logging.info("%s", task.config_space)
-        measure_option = autotvm.measure_option('local',
-                                                timeout=4,
-                                                number=2)
+        measure_option = autotvm.measure_option(
+            autotvm.LocalBuilder(),
+            autotvm.LocalRunner())
         tuner = RandomTuner(task)
-        tuner.tune(n_trial=10, measure_option=measure_option)
+        tuner.tune(n_trial=20, measure_option=measure_option)
     check("cuda", None)
     check("opencl", None)
@@ -155,6 +128,4 @@ if __name__ == "__main__":
     # only print log when invoked from main
     logging.basicConfig(level=logging.DEBUG)
-    test_task_tuner_without_measurement()
-    test_tuning_with_measure()
+    test_tuning()
@@ -32,6 +32,25 @@ def matmul(N, L, M, dtype):
     return s, [A, B, C]
+@autotvm.template
+def bad_matmul(N, L, M, dtype):
+    if 'bad_device' in tvm.target.current_target().keys:
+        A = tvm.placeholder((N, L), name='A', dtype=dtype)
+        B = tvm.placeholder((L, M), name='B', dtype=dtype)
+        k = tvm.reduce_axis((0, L-1), name='k')
+        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
+        s = tvm.create_schedule(C.op)
+        # schedule
+        y, x = s[C].op.axis
+        cfg = autotvm.get_config()
+        cfg.define_split("tile_y", y, num_outputs=2)
+        cfg.define_split("tile_x", x, num_outputs=2)
+        return s, [A, B, C]
+    return matmul(N, L, M, dtype)
 def get_sample_task(n=128):
     """return a sample task for testing"""
     target = tvm.target.create("llvm")
...
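Note: the bad_matmul template added above deliberately reduces over (0, L-1) instead of (0, L), so its result is numerically wrong whenever the 'bad_device' target is selected; the new measurement tests below rely on this to verify that LocalRunner(check_correctness=True) reports MeasureErrorNo.WRONG_ANSWER.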
"""Test database""" """Test database"""
import copy import copy
import logging import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm import database from tvm.autotvm import database
from tvm.autotvm.measure.measure_methods import HashMismatchError from tvm.autotvm.record import encode, MeasureResult
from tvm.autotvm.record import encode, MeasureInput, MeasureResult
from test_autotvm_common import get_sample_task, get_sample_records from test_autotvm_common import get_sample_records
def test_save_load(): def test_save_load():
logging.info("test basic db load/save ...") logging.info("test basic db load/save ...")
...@@ -35,66 +29,6 @@ def test_save_load(): ...@@ -35,66 +29,6 @@ def test_save_load():
TRIAL_LIMIT = 2 TRIAL_LIMIT = 2
-def test_db_filter():
-    logging.info("test db filter ...")
-    # Pick a GPU target because there are more likely to be failures/invalid configs
-    task, target = get_sample_task()
-    ctx = tvm.context(str(target))
-    if not ctx.exist:
-        logging.warning("Skip this test because there is no supported device for test")
-    batch_size = 2
-    measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-    ct = 0
-    all_inputs = list()
-    all_results = list()
-    batches = list()
-    tuner = autotvm.tuner.RandomTuner(task)
-    while ct < TRIAL_LIMIT:
-        inputs = list()
-        for i in range(batch_size):
-            cfg = tuner.next_batch(1)[0]
-            inputs.append((MeasureInput(target, task, cfg)))
-            all_inputs.append(inputs[-1])
-        batches.append(inputs)
-        results = measure_batch(inputs)
-        all_results += results
-        ct += 1
-    del measure_batch
-    db = database.DummyDatabase()
-    db.flush()
-    # First setting, memoize one input at a time, check that each is saved and replayed
-    measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-    for i in range(len(all_inputs)+1):
-        db.flush()
-        for j in range(i):
-            db.save(all_inputs[j], all_results[j])
-        for k in range(len(batches)):
-            batch = batches[k]
-            batch_result = measure_batch(batch)
-            for l in range(batch_size):
-                all_idx = k*batch_size + l
-                assert batch_result[l] is not None
-                if all_idx < i:
-                    assert encode(batch[l], batch_result[l]) == encode(batch[l], all_results[all_idx]), \
-                        "(no retry) EXPECTED MATCH, GOT MISMATCH"
-                else:
-                    assert encode(batch[l], batch_result[l]) != encode(batch[l], all_results[all_idx]), \
-                        "(no retry) EXPECTED MISMATCH, GOT MATCH"
-    del measure_batch
 def test_db_hash():
     logging.info("test db hash check ...")
     inp1, res1 = get_sample_records(1)[0]
@@ -149,89 +83,8 @@ def test_db_latest_all():
     assert encode(inp1, load4[1]) == encode(inp1, res2)
     assert encode(inp1, load4[2]) == encode(inp1, res3)
-def test_db_save_replay():
-    logging.info("test db save (from measure_batch) and replay ...")
-    _db = database.DummyDatabase()
-    _db.flush()
-    task, target = get_sample_task()
-    ctx = tvm.context(str(target))
-    if not ctx.exist:
-        logging.warning("Skip this test because there is no supported device for test")
-    measure_option = autotvm.measure_option('local',
-                                            do_fork=False,
-                                            timeout=2,
-                                            replay_db=_db)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-    batch_size = 2
-    ct = 0
-    all_inputs = list()
-    all_results = list()
-    batches = list()
-    tuner = autotvm.tuner.RandomTuner(task)
-    while ct < TRIAL_LIMIT:
-        inputs = list()
-        for i in range(batch_size):
-            cfg = tuner.next_batch(1)[0]
-            inputs.append((MeasureInput(target, task, cfg)))
-            all_inputs.append(inputs[-1])
-        batches.append(inputs)
-        results = measure_batch(inputs)
-        all_results += results
-        ct += 1
-    callback = autotvm.callback.log_to_database(_db)
-    callback(None, all_inputs, all_results)
-    assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
-        "%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
-    all_results_2 = measure_batch(all_inputs)
-    all_results_3 = measure_batch(all_inputs)
-    for i in range(len(all_results)):
-        encr1 = encode(all_inputs[i], all_results[i])
-        encr2 = encode(all_inputs[i], all_results_2[i])
-        encr3 = encode(all_inputs[i], all_results_3[i])
-        assert encr1 == encr2, "EXPECTED MATCH WITH SAVE REPLAY (first replay), got MISMATCH"
-        assert encr2 == encr3, "EXPECTED MATCH WITH SAVE REPLAY (second replay), got MISMATCH"
-    del measure_batch
-def test_check_hashmismatch():
-    logging.info("test hash mismatch check")
-    task, target = get_sample_task()
-    ctx = tvm.context(str(target))
-    if not ctx.exist:
-        logging.warning("Skip this test because there is no supported device for test")
-    measure_option = autotvm.measure_option('local', do_fork=False)
-    measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
-    inputs = list()
-    cfg = task.config_space.get(np.random.randint(len(task.config_space)))
-    # notvalidh is not a valid CRC32 hash (not hex)
-    cfg.code_hash = 'notvalidh'
-    inputs.append((MeasureInput(target, task, cfg)))
-    try:
-        results = measure_batch(inputs)
-        assert False, "HashMismatchError should be raised"
-    except HashMismatchError:
-        pass
-    del measure_batch
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO)
     test_save_load()
-    test_db_filter()
     test_db_hash()
     test_db_latest_all()
-    test_db_save_replay()
-    test_check_hashmismatch()
"""Test builder and runner"""
import logging
import time
import numpy as np
import tvm
from tvm import autotvm
from test_autotvm_common import get_sample_task, bad_matmul
from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
+def test_task_tuner_without_measurement():
+    """test task and tuner without measurement"""
+    task, target = get_sample_task()
+    class DummyRunner(Runner):
+        def __init__(self):
+            super(DummyRunner, self).__init__(1, 1)
+        def run(self, measure_inputs, build_results):
+            return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
+                    for _ in range(len(measure_inputs))]
+        def get_build_kwargs(self):
+            return {}
+    measure_option = autotvm.measure_option(
+        builder=autotvm.LocalBuilder(),
+        runner=DummyRunner()
+    )
+    logging.info("%s", task.config_space)
+    for tuner_class in [autotvm.tuner.RandomTuner,
+                        autotvm.tuner.GridSearchTuner,
+                        autotvm.tuner.GATuner,
+                        autotvm.tuner.XGBTuner]:
+        tuner = tuner_class(task)
+        tuner.tune(n_trial=10, measure_option=measure_option)
+        assert tuner.best_flops > 1
+def test_check_correctness():
+    task, target = get_sample_task()
+    measure_option = autotvm.measure_option(
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.LocalRunner(check_correctness=True)
+    )
+    def _callback_correct(tuner, measure_inputs, measure_results):
+        for inp, res in zip(measure_inputs, measure_results):
+            assert res.error_no == 0
+    tuner = autotvm.tuner.RandomTuner(task)
+    tuner.tune(n_trial=2, measure_option=measure_option,
+               callbacks=[_callback_correct])
+    # a bad template
+    n = 128
+    target = tvm.target.create("llvm -device=bad_device")
+    task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
+    def _callback_wrong(tuner, measure_inputs, measure_results):
+        for inp, res in zip(measure_inputs, measure_results):
+            assert res.error_no == MeasureErrorNo.WRONG_ANSWER
+    tuner = autotvm.tuner.RandomTuner(task)
+    tuner.tune(n_trial=2, measure_option=measure_option,
+               callbacks=[_callback_wrong])
+def test_min_repeat_ms():
+    task, target = get_sample_task()
+    measure_option = autotvm.measure_option(
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.LocalRunner(number=1, min_repeat_ms=100)
+    )
+    def _callback(tuner, measure_inputs, measure_results):
+        for inp, res in zip(measure_inputs, measure_results):
+            if res.error_no != 0:
+                continue
+            assert 1000 * np.mean(res.costs) * \
+                   measure_option['runner'].cur_number >= 100
+    tuner = autotvm.tuner.RandomTuner(task)
+    tuner.tune(n_trial=5, measure_option=measure_option,
+               callbacks=[_callback])
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO)
+    test_task_tuner_without_measurement()
+    test_check_correctness()
+    test_min_repeat_ms()
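The assertion in test_min_repeat_ms encodes what min_repeat_ms appears intended to guarantee: the runner raises its effective run count until one measurement lasts at least 100 ms, so the mean per-run cost in seconds times runner.cur_number should be at least 0.1 s, which is exactly the check 1000 * np.mean(res.costs) * cur_number >= 100.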
@@ -137,12 +137,15 @@ if __name__ == '__main__':
     print(task.config_space)
     measure_option = autotvm.measure_option(
-        measure_func='local', number=10, n_parallel=8, timeout=20)
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+    )
     log_name = 'gemm_int8.log'
     if DO_TUNING:
         tuner = autotvm.tuner.XGBTuner(task)
         tuner.tune(n_trial=1000, measure_option=measure_option,
                    callbacks=[autotvm.callback.log_to_file(log_name)])
     dispatch_context = autotvm.apply_history_best(log_name)
     best_config = dispatch_context.query(task.target, task.workload)
...
@@ -164,12 +164,12 @@ task = autotvm.task.create(conv2d_no_batching,
                            target='cuda')
 print(task.config_space)
-# use local gpu, measure 5 times for every config to reduce variance
-# run 8 parallel threads for compilation
-measure_option = autotvm.measure_option('local',
-                                        number=5,
-                                        n_parallel=8,
-                                        timeout=20)
+# use local gpu, measure 10 times for every config to reduce variance
+# The timeout of compiling a program is 10 seconds, the timeout for running is 4 seconds
+measure_option = autotvm.measure_option(
+    builder=autotvm.LocalBuilder(),
+    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
+)
 # begin tuning, log records to file `conv2d.log`
 tuner = autotvm.tuner.XGBTuner(task)
...
@@ -271,9 +271,12 @@ print(task.config_space)
 logging.getLogger('autotvm').setLevel(logging.DEBUG)
 logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
-# use local cpu, measure 5 times for every config to reduce variance
-measure_option = autotvm.measure_option('local',
-                                        number=5)
+# There are two steps for measuring a config: build and run.
+# By default, we use all cpu cores to compile program. Then measure them sequentially.
+# We measure 5 times and take average to reduce variance.
+measure_option = autotvm.measure_option(
+    builder='local',
+    runner=autotvm.LocalRunner(number=5))
 # begin tuning, log records to file `matmul.log`
 tuner = autotvm.tuner.RandomTuner(task)
...