Commit d0c406e6 by Thierry Moreau Committed by Jared Roesch

fix dense tuning (#3768)

parent ae1ba36d
...@@ -33,9 +33,9 @@ env = vta.get_env() ...@@ -33,9 +33,9 @@ env = vta.get_env()
Workload = namedtuple("DenseWorkload", Workload = namedtuple("DenseWorkload",
['batch', 'in_filter', 'out_filter']) ['batch', 'in_filter', 'out_filter'])
resnet_wkls = [ dense_wkls = [
# Workloads of resnet18 on imagenet ('lstm.dense.1', Workload(1, 256, 128)),
('resnet-18.dense', Workload(16, 512, 1024)), ('lstm.dense.4', Workload(4, 256, 128)),
] ]
@tvm.tag_scope(tag=topi.tag.ELEMWISE) @tvm.tag_scope(tag=topi.tag.ELEMWISE)
...@@ -71,7 +71,14 @@ if __name__ == '__main__': ...@@ -71,7 +71,14 @@ if __name__ == '__main__':
# Logging config (for printing tuning log to the screen) # Logging config (for printing tuning log to the screen)
logging.basicConfig() logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG) # logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Tuning log files
log_file = "%s.dense.log" % (env.TARGET)
# create tmp log file
tmp_log_file = log_file + ".tmp"
if os.path.exists(log_file):
os.remove(log_file)
# Get tracker info from env # Get tracker info from env
tracket_host = os.environ.get("TVM_TRACKER_HOST", None) tracket_host = os.environ.get("TVM_TRACKER_HOST", None)
...@@ -80,7 +87,9 @@ if __name__ == '__main__': ...@@ -80,7 +87,9 @@ if __name__ == '__main__':
print("Set your AutoTVM tracker node host and port variables to run the autotuner") print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit() exit()
for wl_name, wl in resnet_wkls: for idx, (wl_name, wl) in enumerate(dense_wkls):
prefix = "[Task %2d/%2d] " % (idx, len(dense_wkls))
# Workload parameters # Workload parameters
N = wl.batch N = wl.batch
...@@ -91,15 +100,24 @@ if __name__ == '__main__': ...@@ -91,15 +100,24 @@ if __name__ == '__main__':
target=tvm.target.vta(), target_host=env.target_host, template_key='direct') target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space) print(task.config_space)
# Tune
measure_option = autotvm.measure_option( measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(env.TARGET, tracket_host, int(tracket_port), number=4, repeat=3, timeout=10000, runner=autotvm.RPCRunner(
check_correctness=True)) env.TARGET, host=tracket_host, port=int(tracket_port),
number=5, timeout=60,
check_correctness=True))
# Run Tuner
tuner = autotvm.tuner.RandomTuner(task) tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space), tuner.tune(
n_trial=len(task.config_space),
early_stopping=None,
measure_option=measure_option, measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('dense.log')]) callbacks=[
autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
print("\nBest tuner config:") # Pick best records to a cache file
print(tuner.best_config) autotvm.record.pick_best(tmp_log_file, log_file)
os.remove(tmp_log_file)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment