Commit d39ac773 by ziheng Committed by Tianqi Chen

[TOPI] Enhance Conv2D for More Data Type (#922)

parent 7fd7db0f
...@@ -45,13 +45,37 @@ _WORKLOADS = [ ...@@ -45,13 +45,37 @@ _WORKLOADS = [
Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1), Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1), Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1), Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
# workloads of resnet18 on imagenet (int16->int32 version)
Workload('int16', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int16', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('int16', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('int16', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('int16', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('int16', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('int16', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('int16', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('int16', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('int16', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('int16', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workloads of mobile net on imagenet (int16->int32 version)
Workload('int16', 'int32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('int16', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('int16', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
] ]
# platform specific schedule # platform specific schedule
_CONV_SCHEDULE = {} _CONV_SCHEDULE = {}
@tvm.target.generic_func @tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
"""Conv2D operator. """Conv2D operator.
Parameters Parameters
...@@ -97,7 +121,9 @@ def _get_workload(data, kernel, stride, padding, out_dtype): ...@@ -97,7 +121,9 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
HSTR, WSTR = stride HSTR, WSTR = stride
else: else:
HSTR, WSTR = stride, stride HSTR, WSTR = stride, stride
assert data.dtype == kernel.dtype, "Do not support inputs with different data types now." assert data.dtype == kernel.dtype, \
"Do not support inputs with different data types now. " \
"{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
...@@ -111,8 +137,11 @@ def _get_schedule(wkl): ...@@ -111,8 +137,11 @@ def _get_schedule(wkl):
# This return has no use, merely to suppress pylint warning # This return has no use, merely to suppress pylint warning
return wkl return wkl
def _spatial_pack(data, kernel, stride, padding, out_dtype):
def _spatial_pack(data, kernel, stride, padding, out_dtype=None):
""" Compute convolution with pack on spatial axes. """ """ Compute convolution with pack on spatial axes. """
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding, out_dtype) wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl) sch = _get_schedule(wkl)
...@@ -172,8 +201,10 @@ def _spatial_pack(data, kernel, stride, padding, out_dtype): ...@@ -172,8 +201,10 @@ def _spatial_pack(data, kernel, stride, padding, out_dtype):
return output return output
def _im2col_pack(data, kernel, stride, padding, out_dtype): def _im2col_pack(data, kernel, stride, padding, out_dtype=None):
""" Compute convolution with im2col pack layout. """ """ Compute convolution with im2col pack layout. """
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "im2col pack convolution only support batch size=1" assert data.shape[0].value == 1, "im2col pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding, out_dtype) wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl) sch = _get_schedule(wkl)
...@@ -238,7 +269,7 @@ def _im2col_pack(data, kernel, stride, padding, out_dtype): ...@@ -238,7 +269,7 @@ def _im2col_pack(data, kernel, stride, padding, out_dtype):
return output return output
def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'): def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout. """Convolution operator in NCHW layout.
Parameters Parameters
...@@ -260,6 +291,8 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'): ...@@ -260,6 +291,8 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
Output : tvm.Tensor Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] 4-D with shape [batch, out_channel, out_height, out_width]
""" """
if out_dtype is None:
out_dtype = Input.dtype
assert isinstance(stride, int) or len(stride) == 2 assert isinstance(stride, int) or len(stride) == 2
batch, in_channel, in_height, in_width = Input.shape batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape num_filter, channel, kernel_h, kernel_w = Filter.shape
...@@ -289,7 +322,7 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'): ...@@ -289,7 +322,7 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
axis=[rc, ry, rx]), tag="conv2d_nchw") axis=[rc, ry, rx]), tag="conv2d_nchw")
def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'): def conv2d_hwcn(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in HWCN layout. """Convolution operator in HWCN layout.
Parameters Parameters
...@@ -311,6 +344,8 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'): ...@@ -311,6 +344,8 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
output : tvm.Tensor output : tvm.Tensor
4-D with shape [out_height, out_width, out_channel, batch] 4-D with shape [out_height, out_width, out_channel, batch]
""" """
if out_dtype is None:
out_dtype = Input.dtype
assert isinstance(stride, int) or len(stride) == 2 assert isinstance(stride, int) or len(stride) == 2
in_height, in_width, in_channel, batch = Input.shape in_height, in_width, in_channel, batch = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape kernel_h, kernel_w, channel, num_filter = Filter.shape
......
...@@ -37,6 +37,32 @@ _SCHEDULES = [ ...@@ -37,6 +37,32 @@ _SCHEDULES = [
SpatialPack(2, 2, 8, 1, 8, False), SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False), Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True), Im2ColPack(7, 4, 1, 4, True),
# int8 imagenet
SpatialPack(2, 2, 4, 19, 8, False),
SpatialPack(2, 2, 8, 1, 4, True),
SpatialPack(2, 2, 8, 7, 4, False),
SpatialPack(2, 4, 4, 7, 16, False),
SpatialPack(1, 7, 4, 14, 4, True),
SpatialPack(2, 2, 8, 5, 1, False),
SpatialPack(1, 2, 16, 3, 8, True),
SpatialPack(1, 7, 4, 1, 16, True),
SpatialPack(2, 2, 8, 2, 16, True),
SpatialPack(1, 1, 8, 4, 4, True),
SpatialPack(1, 1, 4, 1, 8, False),
SpatialPack(1, 1, 8, 1, 16, True),
# int8 mobilenet
SpatialPack(2, 2, 8, 8, 1, True),
SpatialPack(1, 7, 4, 16, 4, True),
SpatialPack(1, 4, 8, 1, 1, True),
SpatialPack(1, 4, 8, 1, 1, True),
SpatialPack(1, 4, 8, 4, 8, True),
SpatialPack(1, 4, 8, 7, 1, True),
SpatialPack(1, 2, 8, 2, 32, True),
SpatialPack(1, 2, 16, 2, 16, True),
SpatialPack(1, 1, 32, 1, 16, False),
SpatialPack(1, 1, 16, 1, 32, True),
] ]
@_get_schedule.register("rasp") @_get_schedule.register("rasp")
...@@ -50,6 +76,8 @@ def _schedule_conv2d(wkl): ...@@ -50,6 +76,8 @@ def _schedule_conv2d(wkl):
@conv2d.register("rasp") @conv2d.register("rasp")
def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype): def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
if out_dtype is None:
out_dtype = data.dtype
assert layout == 'NCHW', "only support NCHW convolution on rasp" assert layout == 'NCHW', "only support NCHW convolution on rasp"
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
wkl = _get_workload(data, kernel, stride, padding, out_dtype) wkl = _get_workload(data, kernel, stride, padding, out_dtype)
......
...@@ -24,6 +24,15 @@ _WORKLOADS = [ ...@@ -24,6 +24,15 @@ _WORKLOADS = [
_Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1), _Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2), _Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1), _Workload('float32', 'float32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
] ]
_SCHEDULES = [ _SCHEDULES = [
...@@ -36,6 +45,15 @@ _SCHEDULES = [ ...@@ -36,6 +45,15 @@ _SCHEDULES = [
_Schedule(1, 1, 8, 8, True), _Schedule(1, 1, 8, 8, True),
_Schedule(1, 1, 4, 1, False), _Schedule(1, 1, 4, 1, False),
_Schedule(1, 1, 4, 4, False), _Schedule(1, 1, 4, 4, False),
_Schedule(2, 4, 4, 2, False),
_Schedule(2, 7, 4, 1, True),
_Schedule(2, 4, 4, 4, False),
_Schedule(2, 2, 4, 4, False),
_Schedule(2, 2, 8, 4, False),
_Schedule(2, 2, 4, 4, True),
_Schedule(2, 2, 8, 4, False),
_Schedule(1, 2, 8, 4, True),
_Schedule(1, 1, 4, 8, True),
] ]
def _get_workload(data, kernel, stride, padding, out_dtype): def _get_workload(data, kernel, stride, padding, out_dtype):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment