Commit 9a620c69, authored by Lianmin Zheng and committed by Tianqi Chen

[TOPI] Update TopHub and benchmark (#1796)

parent 72cab4ee
...@@ -58,8 +58,10 @@ def evaluate_network(network, target, target_host, number): ...@@ -58,8 +58,10 @@ def evaluate_network(network, target, target_host, number):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices= parser.add_argument("--network", type=str, choices=
['resnet-18', 'resnet-34', 'vgg-16', ['resnet-18', 'resnet-34', 'resnet-50',
'mobilenet', 'mobilenet_v2', 'squeezenet v1.0', 'squeezenet v1.1']) 'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
'mobilenet', 'mobilenet_v2', 'squeezenet_v1.0', 'squeezenet_v1.1'],
help='The name of neural network')
parser.add_argument("--model", type=str, choices= parser.add_argument("--model", type=str, choices=
['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro', ['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
'pixel2', 'rasp3b', 'pynq'], default='rk3399', 'pixel2', 'rasp3b', 'pynq'], default='rk3399',
...@@ -68,7 +70,7 @@ if __name__ == "__main__": ...@@ -68,7 +70,7 @@ if __name__ == "__main__":
parser.add_argument("--host", type=str, default='localhost') parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=9190) parser.add_argument("--port", type=int, default=9190)
parser.add_argument("--rpc-key", type=str, required=True) parser.add_argument("--rpc-key", type=str, required=True)
parser.add_argument("--number", type=int, default=6) parser.add_argument("--number", type=int, default=3)
args = parser.parse_args() args = parser.parse_args()
dtype = 'float32' dtype = 'float32'
......
...@@ -17,8 +17,10 @@ from util import get_network ...@@ -17,8 +17,10 @@ from util import get_network
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices= parser.add_argument("--network", type=str, choices=
['resnet-18', 'resnet-34', 'resnet-50', 'vgg-16', 'vgg-19', ['resnet-18', 'resnet-34', 'resnet-50',
'inception_v3', 'mobilenet', 'mobilenet_v2', 'densenet-121']) 'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
'mobilenet', 'mobilenet_v2', 'squeezenet_v1.0', 'squeezenet_v1.1'],
help='The name of neural network')
parser.add_argument("--model", type=str, parser.add_argument("--model", type=str,
choices=['1080ti', 'titanx', 'gfx900'], default='1080ti', choices=['1080ti', 'titanx', 'gfx900'], default='1080ti',
help="The model of the test device. If your device is not listed in " help="The model of the test device. If your device is not listed in "
......
...@@ -58,8 +58,10 @@ def evaluate_network(network, target, target_host, number): ...@@ -58,8 +58,10 @@ def evaluate_network(network, target, target_host, number):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices= parser.add_argument("--network", type=str, choices=
['resnet-18', 'resnet-34', 'vgg-16', ['resnet-18', 'resnet-34', 'resnet-50',
'mobilenet', 'mobilenet_v2', 'squeezenet v1.1']) 'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
'mobilenet', 'mobilenet_v2', 'squeezenet_v1.0', 'squeezenet_v1.1'],
help='The name of neural network')
parser.add_argument("--model", type=str, choices= parser.add_argument("--model", type=str, choices=
['rk3399'], default='rk3399', ['rk3399'], default='rk3399',
help="The model of the test device. If your device is not listed in " help="The model of the test device. If your device is not listed in "
...@@ -67,7 +69,7 @@ if __name__ == "__main__": ...@@ -67,7 +69,7 @@ if __name__ == "__main__":
parser.add_argument("--host", type=str, default='localhost') parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=9190) parser.add_argument("--port", type=int, default=9190)
parser.add_argument("--rpc-key", type=str, required=True) parser.add_argument("--rpc-key", type=str, required=True)
parser.add_argument("--number", type=int, default=10) parser.add_argument("--number", type=int, default=30)
args = parser.parse_args() args = parser.parse_args()
dtype = 'float32' dtype = 'float32'
......
...@@ -20,12 +20,12 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub ...@@ -20,12 +20,12 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub
# the version of each package # the version of each package
PACKAGE_VERSION = { PACKAGE_VERSION = {
'arm_cpu': "v0.01", 'arm_cpu': "v0.03",
'cuda': "v0.02", 'cuda': "v0.02",
'rocm': "v0.01", 'rocm': "v0.01",
'opencl': "v0.01", 'opencl': "v0.01",
'mali': "v0.01", 'mali': "v0.02",
'vta': "v0.01", 'vta': "v0.01",
} }
...@@ -38,7 +38,7 @@ def _alias(name): ...@@ -38,7 +38,7 @@ def _alias(name):
'vtacpu': 'vta', 'vtacpu': 'vta',
'metal': 'opencl', 'metal': 'opencl',
'nvptx': 'cuda' 'nvptx': 'cuda',
} }
return table.get(name, name) return table.get(name, name)
...@@ -61,11 +61,12 @@ def context(target, extra_files=None): ...@@ -61,11 +61,12 @@ def context(target, extra_files=None):
if isinstance(target, str): if isinstance(target, str):
target = _target.create(target) target = _target.create(target)
possible_names = [str(target).split()[0]] possible_names = []
for opt in target.options: for opt in target.options:
if opt.startswith("-device"): if opt.startswith("-device"):
device = _alias(opt[8:]) device = _alias(opt[8:])
possible_names.append(device) possible_names.append(device)
possible_names.append(target.target_name)
all_packages = list(PACKAGE_VERSION.keys()) all_packages = list(PACKAGE_VERSION.keys())
for name in possible_names: for name in possible_names:
...@@ -75,6 +76,7 @@ def context(target, extra_files=None): ...@@ -75,6 +76,7 @@ def context(target, extra_files=None):
filename = "%s_%s.log" % (name, PACKAGE_VERSION[name]) filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
best_context.load(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename)) best_context.load(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename))
break # only load one file to avoid some fallback template mismatch problem
if extra_files: if extra_files:
for filename in extra_files: for filename in extra_files:
......
...@@ -506,8 +506,8 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): ...@@ -506,8 +506,8 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
##### REGISTER ALTER OP LAYOUT ##### ##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["arm_cpu", "mali"]) @conv2d_alter_layout.register(["arm_cpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos): def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
"""Alter op layout for pre-computing kernel transformation""" """Alter op layout for pre-computing kernel transformation"""
import nnvm.symbol as sym import nnvm.symbol as sym
copy_inputs = [s for s in inputs] copy_inputs = [s for s in inputs]
......
...@@ -9,11 +9,11 @@ from tvm.autotvm.task.space import get_factors ...@@ -9,11 +9,11 @@ from tvm.autotvm.task.space import get_factors
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_int, get_const_tuple, const_matrix from ..util import traverse_inline, get_const_int, get_const_tuple, const_matrix
from ..nn import conv2d, conv2d_winograd_without_weight_transform, \ from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
get_pad_tuple, pad get_pad_tuple, pad, conv2d_alter_layout
# reuse some compute declarations from ARM CPU # reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d import _conv_arg_to_workload, _decl_spatial_pack,\ from ..arm_cpu.conv2d import _conv_arg_to_workload, _decl_spatial_pack,\
_winograd_conv_arg_to_workload _winograd_conv_arg_to_workload, _alter_conv2d_layout_arm
@conv2d.register('mali') @conv2d.register('mali')
...@@ -410,6 +410,12 @@ def _schedule_winograd(cfg, s, op): ...@@ -410,6 +410,12 @@ def _schedule_winograd(cfg, s, op):
s[Y].compute_at(s[output], tt) s[Y].compute_at(s[output], tt)
@conv2d_alter_layout.register(["mali"])
def _alter_conv2d_layout(attrs, inputs, tinfos):
    """Alter the conv2d op layout for Mali by reusing the ARM CPU rule.

    The arguments are forwarded unchanged to ``_alter_conv2d_layout_arm``.

    Returns
    -------
    Whatever ``_alter_conv2d_layout_arm`` returns, or ``None`` when that
    call raises ``KeyError``.
    NOTE(review): ``None`` presumably tells the caller to keep the original
    layout -- confirm against the ``conv2d_alter_layout`` contract.
    """
    try:
        altered = _alter_conv2d_layout_arm(attrs, inputs, tinfos)
    except KeyError:
        # A KeyError is used here to filter out fallback opencl templates.
        altered = None
    return altered
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM ##### ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@conv2d_winograd_without_weight_transform.register(['mali']) @conv2d_winograd_without_weight_transform.register(['mali'])
......
...@@ -69,16 +69,11 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -69,16 +69,11 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in get_all_backend(): for device in get_all_backend():
check_device(device) with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def test_conv2d_nchw(): def test_conv2d_nchw():
# load tophub
ctx = autotvm.apply_history_best([])
for device in get_all_backend():
context = autotvm.tophub.context(device)
context.__enter__()
# ResNet18 workloads # ResNet18 workloads
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
......
...@@ -102,7 +102,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -102,7 +102,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
for device in get_all_backend(): for device in get_all_backend():
check_device(device) with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1): def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
...@@ -201,16 +202,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -201,16 +202,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
for device in get_all_backend(): for device in get_all_backend():
check_device(device) with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def test_depthwise_conv2d(): def test_depthwise_conv2d():
# load tophub
ctx = autotvm.apply_history_best([])
for device in get_all_backend():
context = autotvm.tophub.context(device)
context.__enter__()
# mobilenet workloads # mobilenet workloads
depthwise_conv2d_with_workload_nchw(1, 32, 112, 1, 3, 1, "SAME") depthwise_conv2d_with_workload_nchw(1, 32, 112, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(1, 64, 112, 1, 3, 2, "SAME") depthwise_conv2d_with_workload_nchw(1, 64, 112, 1, 3, 2, "SAME")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment