Commit d0c406e6 by Thierry Moreau Committed by Jared Roesch

fix dense tuning (#3768)

parent ae1ba36d
......@@ -33,9 +33,9 @@ env = vta.get_env()
Workload = namedtuple("DenseWorkload",
['batch', 'in_filter', 'out_filter'])
resnet_wkls = [
# Workloads of resnet18 on imagenet
('resnet-18.dense', Workload(16, 512, 1024)),
dense_wkls = [
('lstm.dense.1', Workload(1, 256, 128)),
('lstm.dense.4', Workload(4, 256, 128)),
]
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
......@@ -71,7 +71,14 @@ if __name__ == '__main__':
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Tuning log files
log_file = "%s.dense.log" % (env.TARGET)
# create tmp log file
tmp_log_file = log_file + ".tmp"
if os.path.exists(log_file):
os.remove(log_file)
# Get tracker info from env
tracket_host = os.environ.get("TVM_TRACKER_HOST", None)
......@@ -80,7 +87,9 @@ if __name__ == '__main__':
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
for wl_name, wl in resnet_wkls:
for idx, (wl_name, wl) in enumerate(dense_wkls):
prefix = "[Task %2d/%2d] " % (idx, len(dense_wkls))
# Workload parameters
N = wl.batch
......@@ -91,15 +100,24 @@ if __name__ == '__main__':
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)
# Tune
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracket_host, int(tracket_port), number=4, repeat=3, timeout=10000,
builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(
env.TARGET, host=tracket_host, port=int(tracket_port),
number=5, timeout=60,
check_correctness=True))
# Run Tuner
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
tuner.tune(
n_trial=len(task.config_space),
early_stopping=None,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('dense.log')])
callbacks=[
autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
print("\nBest tuner config:")
print(tuner.best_config)
# Pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_file)
os.remove(tmp_log_file)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment