Commit 5c410c4c by Wuwei Lin Committed by Tianqi Chen

[TOPI][CUDA] int8 group conv2d (#2075)

parent 3ee13fc5
......@@ -108,6 +108,9 @@ def compute_conv2d(attrs, inputs, _):
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
elif layout == "NCHW":
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
......@@ -143,6 +146,8 @@ def schedule_conv2d(attrs, outs, target):
return topi.generic.schedule_depthwise_conv2d_nchw(outs)
elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
elif layout == "NCHW":
return topi.generic.schedule_group_conv2d_nchw(outs)
else:
raise ValueError("No compatible schedule")
......
......@@ -58,7 +58,8 @@ class TaskExtractEnv:
# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
self.symbol2topi = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}
......@@ -67,6 +68,7 @@ class TaskExtractEnv:
self.topi_to_task = {
topi.nn.conv2d: "topi_nn_conv2d",
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
}
......@@ -76,6 +78,7 @@ class TaskExtractEnv:
topi.generic.schedule_conv2d_nhwc],
topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
topi.generic.schedule_depthwise_conv2d_nhwc],
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
}
......@@ -143,6 +146,15 @@ class TaskExtractEnv:
s = topi.generic.schedule_depthwise_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_group_conv2d_nchw")
def _topi_nn_group_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.group_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_group_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_conv2d_transpose_nchw")
def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
......
......@@ -2,10 +2,11 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, group_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
......
......@@ -173,6 +173,25 @@ def schedule_depthwise_conv2d_nhwc(outs):
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_group_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
......
......@@ -403,3 +403,80 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, di
4-D with shape [batch, out_height, out_width, out_channel]
"""
raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
@tvm.target.generic_func
def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
"""Group convolution operator in NCHW layout.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
dilation : int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
groups : int
number of groups
out_dtype : str
The output type. This is used for mixed precision.
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
if out_dtype is None:
out_dtype = Input.dtype
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilation, int) or len(dilation) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape)
assert in_channel % groups == 0, "input channels must divide group size"
assert num_filter % groups == 0, "output channels must divide group size"
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
out_channel = num_filter
out_height = simplify(
(in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1)
out_width = simplify(
(in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1)
# compute graph
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(Input, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel // groups), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
return tvm.compute(
(batch, out_channel, out_height, out_width),
lambda nn, ff, yy, xx: tvm.sum(
temp[nn, ff // (num_filter//groups) * (in_channel//groups) + rc,
yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, ry, rx].astype(out_dtype),
axis=[rc, ry, rx]), tag="conv2d_nchw")
......@@ -4,8 +4,8 @@ import numpy as np
import scipy.signal
def conv2d_nchw_python(a_np, w_np, stride, padding):
"""Convolution operator in HWCN layout.
def _conv2d_nchw_python(a_np, w_np, stride, padding):
"""Convolution operator in NCHW layout.
Parameters
----------
......@@ -66,3 +66,36 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
b_np[n, f] += out[::stride_h, ::stride_w]
return b_np
def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
"""Convolution operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
w_np : numpy.ndarray
4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str or a list/tuple of two ints
Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
a_slices = np.array_split(a_np, groups, axis=1)
w_slices = np.array_split(w_np, groups, axis=0)
b_slices = [_conv2d_nchw_python(a_slice, w_slice, stride, padding)
for a_slice, w_slice in zip(a_slices, w_slices)]
b_np = np.concatenate(b_slices, axis=1)
return b_np
"""Common utility for topi test"""
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
def get_all_backend():
"""return all supported target
......@@ -10,3 +14,14 @@ def get_all_backend():
"""
return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu']
class NCHWcInt8Fallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = FallbackConfigEntity()
cfg.template_key = 'int8'
self.memory[key] = cfg
return cfg
......@@ -9,7 +9,7 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
from common import get_all_backend, NCHWcInt8Fallback
oc_block_factor = 4
......@@ -88,17 +88,6 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
check_device(device)
class NCHWcInt8Fallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = FallbackConfigEntity()
cfg.template_key = 'int8'
self.memory[key] = cfg
return cfg
def test_conv2d_nchw():
with NCHWcInt8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
......
"""Example code to do group convolution."""
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend, NCHWcInt8Fallback
def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter,
kernel, stride, padding, dilation, groups))
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
print("Skip because int8 intrinsics are not available")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_group_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ["llvm"]:
check_device(device)
oc_block_factor = 4
def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter,
kernel, stride, padding, dilation, groups))
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W', dtype='int8')
bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',
dtype='int8')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype)
# convert to NCHWc
_, _, out_height, out_width = c_np.shape
c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
out_height, out_width)).transpose(0, 1, 3, 4, 2)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
print("Skip because int8 intrinsics are not available")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_group_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ["cuda"]:
check_device(device)
def test_group_conv2d_nchw():
# ResNeXt-50 workload
verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw(1, 256, 56, 256, 3, 2, 1, 1, 32)
verify_group_conv2d_nchw(1, 256, 28, 256, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw(1, 512, 28, 512, 3, 2, 1, 1, 32)
verify_group_conv2d_nchw(1, 512, 14, 512, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
verify_group_conv2d_nchw(1, 1024, 7, 1024, 3, 1, 1, 1, 32)
# bias, relu
verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
add_bias=True)
# dilation
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)
# batch size
verify_group_conv2d_nchw(2, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_nchw(9, 128, 56, 128, 3, 1, 1, 1, 32)
def test_group_conv2d_NCHWc_int8():
with NCHWcInt8Fallback():
# ResNeXt-50 workload
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 256, 56, 256, 3, 2, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 256, 28, 256, 3, 1, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 512, 28, 512, 3, 2, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 512, 14, 512, 3, 1, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32)
# bias, relu
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
add_bias=True)
# dilation
verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)
# batch size
verify_group_conv2d_NCHWc_int8(2, 128, 56, 128, 3, 1, 1, 1, 32)
verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32)
if __name__ == "__main__":
test_group_conv2d_nchw()
test_group_conv2d_NCHWc_int8()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment