Commit 6c7f0c4d by Thierry Moreau Committed by Jared Roesch

[VTA] Support for batched inference (#3661)

* fix in IR pass to support padding on 6-d tensors

* support for both N>1 and N==1 for padding

* batch size > 1 tuning and base config

* output formatting

* batch conv2d

* print all category results

* revert to single-batch config

* pick record best

* fix conv test

* improving reporting

* address batching bug in fast simulator

* fix
parent 9b355fc3
......@@ -524,22 +524,29 @@ def inject_dma_intrin(stmt_in):
if pad_before:
assert pad_after
ndim = len(pad_before)
if ndim <= 2 or ndim > 4:
if ndim <= 2 or ndim > 5:
raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
if ndim > 2:
if not util.equal_const_int(pad_before[ndim - 1], 0):
if ndim == 5:
# This case occurs when batch size N > 1
y_pad_before = pad_before[1]
x_pad_before = pad_before[2]
y_pad_after = pad_after[1]
x_pad_after = pad_after[2]
for dim in range(3, ndim):
if not util.equal_const_int(pad_before[dim], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[ndim - 1], 0):
raise ValueError("Do not support pad on the innermost block")
if ndim > 3:
if not util.equal_const_int(pad_before[ndim - 2], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[ndim - 2], 0):
if not util.equal_const_int(pad_after[dim], 0):
raise ValueError("Do not support pad on the innermost block")
else:
y_pad_before = pad_before[0]
x_pad_before = pad_before[1]
y_pad_after = pad_after[0]
x_pad_after = pad_after[1]
for dim in range(2, ndim):
if not util.equal_const_int(pad_before[dim], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[dim], 0):
raise ValueError("Do not support pad on the innermost block")
allow_fold = False
else:
x_pad_before = 0
......
......@@ -36,18 +36,17 @@ Workload = namedtuple("Conv2DWorkload",
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C1', Workload(env.BATCH, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
('resnet-18.C3', Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C4', Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C5', Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C6', Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C7', Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C8', Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C9', Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C10', Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C11', Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
......@@ -87,16 +86,25 @@ if __name__ == '__main__':
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Tuning log files
log_file = "%s.conv2d.log" % (env.TARGET)
# create tmp log file
tmp_log_file = log_file + ".tmp"
if os.path.exists(log_file):
os.remove(log_file)
# Get tracker info from env
tracket_host = os.environ.get("TVM_TRACKER_HOST", None)
tracket_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracket_host or not tracket_port:
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
for wl_name, wl in resnet_wkls:
for idx, (wl_name, wl) in enumerate(resnet_wkls):
prefix = "[Task %2d/%2d] " % (idx, len(resnet_wkls))
# Workload parameters
N = wl.batch
......@@ -116,15 +124,24 @@ if __name__ == '__main__':
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)
# Tune
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracket_host, int(tracket_port), number=4, repeat=3, timeout=10000,
builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(
env.TARGET, host=tracker_host, port=int(tracker_port),
number=5, timeout=60,
check_correctness=True))
# Run Tuner
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
tuner.tune(
n_trial=len(task.config_space),
early_stopping=None,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('conv2d.log')])
callbacks=[
autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
print("\nBest tuner config:")
print(tuner.best_config)
# Pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_file)
os.remove(tmp_log_file)
......@@ -553,7 +553,7 @@ class Device {
src_index += y * op->src_factor_out + x * op->src_factor_in;
BitPacker<VTA_ACC_WIDTH> dst(acc_.BeginPtr(dst_index));
BitPacker<VTA_ACC_WIDTH> src(acc_.BeginPtr(src_index));
for (int k = 0; k < VTA_BLOCK_OUT; ++k) {
for (int k = 0; k < VTA_BATCH * VTA_BLOCK_OUT; ++k) {
if (use_imm) {
dst.SetSigned(k, func(dst.GetSigned(k), op->imm));
} else {
......
......@@ -38,21 +38,23 @@ Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
# Get batch info from env
env = vta.get_env()
# ResNet18 workloads
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C1', Workload(env.BATCH, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
('resnet-18.C3', Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C4', Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C5', Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C6', Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C7', Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C8', Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C9', Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C10', Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C11', Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]
# FIXME: we need a custom clip operator to circumvent a pattern detection limitation
......@@ -143,7 +145,7 @@ def run_conv2d(env, remote, wl, target,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_np = bias_np.reshape(
wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT,
wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
1, 1, env.BATCH, env.BLOCK_OUT)
# Build
......@@ -201,8 +203,10 @@ def run_conv2d(env, remote, wl, target,
if data_pack:
res_orig = res_orig.transpose(
(0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
bias_np = bias_np.transpose(
(0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
res_ref = res_ref >> 8
res_ref += bias_np.reshape(wl.out_filter, 1, 1)
res_ref += bias_np
res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
res_ref = res_ref.astype(env.out_dtype)
correct = np.allclose(res_orig, res_ref)
......
......@@ -355,7 +355,19 @@ def tune_and_evaluate(tuning_opt):
assert len(tasks) == 10
print("Extracted {} conv2d tasks:".format(len(tasks)))
for tsk in tasks:
print("\t{}".format(tsk))
inp = tsk.args[0][1]
wgt = tsk.args[1][1]
batch = inp[0]*inp[4]
in_filter = inp[1]*inp[5]
out_filter = wgt[0]*wgt[4]
height, width = inp[2], inp[3]
hkernel, wkernel = wgt[2], wgt[3]
hstride, wstride = tsk.args[2][0], tsk.args[2][1]
hpad, wpad = tsk.args[3][0], tsk.args[3][1]
print("({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format(
batch, height, width, in_filter, out_filter,
hkernel, wkernel, hpad, wpad, hstride, wstride
))
# We do not run the tuning in our webpage server since it takes too long.
# Comment the following line to run it by yourself.
......
......@@ -247,29 +247,31 @@ if env.TARGET in ["sim", "tsim"]:
print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
else:
tcost = timer()
std = np.std(tcost.results) * 1000 / env.BATCH
mean = tcost.mean * 1000 / env.BATCH
print("\nPerformed inference in %.2fms/sample (std = %.2f)" % (mean, std))
std = np.std(tcost.results) * 1000
mean = tcost.mean * 1000
print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
print("Average per sample inference time: %.2fms" % (mean/env.BATCH))
# Get classification results
tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
top_categories = np.argsort(tvm_output.asnumpy()[0])
# Report top-5 classification results
print("\n%s prediction" % model)
print("\t#1:", synset[top_categories[-1]])
print("\t#2:", synset[top_categories[-2]])
print("\t#3:", synset[top_categories[-3]])
print("\t#4:", synset[top_categories[-4]])
print("\t#5:", synset[top_categories[-5]])
# This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification
# accuracy but is meant to catch changes to the
# quantization pass that would degrade accuracy in the CI.
cat_detected = False
for k in top_categories[-5:]:
for b in range(env.BATCH):
top_categories = np.argsort(tvm_output.asnumpy()[b])
# Report top-5 classification results
print("\n{} prediction for sample {}".format(model, b))
print("\t#1:", synset[top_categories[-1]])
print("\t#2:", synset[top_categories[-2]])
print("\t#3:", synset[top_categories[-3]])
print("\t#4:", synset[top_categories[-4]])
print("\t#5:", synset[top_categories[-5]])
# This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification
# accuracy but is meant to catch changes to the
# quantization pass that would degrade accuracy in the CI.
cat_detected = False
for k in top_categories[-5:]:
if "cat" in synset[k]:
cat_detected = True
assert(cat_detected)
assert(cat_detected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment