Commit 06f91dd2 by Wuwei Lin Committed by Tianqi Chen

[TOPI] Add conv2d int8 template (#1735)

parent 46363d0a
......@@ -90,10 +90,12 @@ def compute_conv2d(attrs, inputs, _):
kernel_layout = attrs["kernel_layout"]
out_dtype = attrs["out_dtype"]
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert layout == "NCHW" or layout == "NHWC"
assert layout in ["NCHW", "NHWC", "NCHW4c"]
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
elif layout == "NCHW4c" and (dilation_h > 1 or dilation_w > 1):
raise ValueError("not support dilate now")
elif dilation == (1, 1):
kernel = inputs[1]
elif layout == "NCHW":
......@@ -101,7 +103,12 @@ def compute_conv2d(attrs, inputs, _):
else: #layout == NHWC
kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
if groups == 1:
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc_int8_prepacked(inputs[0], kernel, strides, padding,
layout, out_dtype=out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], kernel, strides, padding, layout, out_dtype=out_dtype)
elif layout == "NCHW" and \
......@@ -120,7 +127,7 @@ def compute_conv2d(attrs, inputs, _):
if attrs.get_bool("use_bias"):
bias = inputs[2]
expand_axis = 1 if layout == "NCHW" else 0
expand_axis = 1 if layout in ["NCHW", "NCHW4c"] else 0
bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2)
out = topi.add(out, bias)
return out
......@@ -136,6 +143,8 @@ def schedule_conv2d(attrs, outs, target):
with tvm.target.create(target):
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
elif groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_NCHWc_int8_prepacked(outs)
elif groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
elif groups == channels and layout == "NCHW":
......
......@@ -344,7 +344,6 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2);
NNVM_REGISTER_OP(_contrib_conv2d_winograd_weight_transform)
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
......
......@@ -9,9 +9,10 @@ from ..util import get_const_int, get_const_tuple, traverse_inline
from .conv2d_direct import schedule_direct_cuda
from .conv2d_winograd import winograd_cuda, schedule_winograd_cuda
from .conv2d_int8 import conv2d_NCHWc_int8, schedule_conv2d_NCHWc_int8
@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd'])
@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd', 'int8'])
def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
"""Conv2D operator for cuda backend.
......@@ -21,10 +22,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
The config for this template
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
4-D with shape [batch, in_channel, in_height, in_width] or
5-D with shape [batch, ic_chunk, in_height, in_width, ic_block]
kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
filter_width, num_filter_block, in_channel_block]
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
......@@ -98,6 +102,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
if cfg.template_key == 'winograd':
return winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype,
pre_computed=False)
if cfg.template_key == 'int8':
return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, layout, out_dtype,
pre_computed=False)
if layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, strides, padding, out_dtype)
......@@ -108,7 +115,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, ["cuda", "gpu"],
["direct", 'winograd'])
["direct", 'winograd', "int8"])
def schedule_conv2d_nchw_cuda(cfg, outs):
"""TOPI schedule callback of conv2d for cuda gpu
......@@ -138,6 +145,8 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
schedule_direct_cuda(cfg, s, op.output(0))
if op.tag == 'conv2d_nchw_winograd':
schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
if op.tag == "conv2d_NCHWc_int8":
schedule_conv2d_NCHWc_int8(cfg, s, op.output(0), pre_computed=False)
traverse_inline(s, outs[0].op, _callback)
return s
......@@ -2,6 +2,7 @@
"""The templates for cuda conv2d operators"""
import tvm
from tvm import autotvm
from ..util import get_const_tuple
def schedule_direct_cuda(cfg, s, conv):
"""schedule optimized for batch size = 1"""
......@@ -94,3 +95,7 @@ def schedule_direct_cuda(cfg, s, conv):
# unroll
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
N, CO, OH, OW = get_const_tuple(output.shape)
_, KH, KW, CI = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
......@@ -375,6 +375,13 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
if cfg.template_key == 'direct':
return None
if cfg.template_key == 'int8':
assert 'cuda' in tvm.target.current_target().keys
new_attrs['layout'] = 'NCHW4c'
new_attrs['out_layout'] = 'NCHW4c'
new_attrs['kernel_layout'] = 'OIHW4o4i'
return sym.conv2d(*copy_inputs, **new_attrs)
# pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1])
......
......@@ -140,6 +140,24 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
@tvm.target.generic_func
def schedule_conv2d_NCHWc_int8_prepacked(outs):
"""Schedule for conv2d NCHWc int8 with prepacked data and kernel
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
......
......@@ -423,3 +423,37 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding,
4-D with shape [batch, out_height, out_width, out_channel]
"""
raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
@tvm.target.generic_func
def conv2d_NCHWc_int8_prepacked(data, kernel, stride, padding, layout, out_dtype):
"""Convolution operator in NCHW[x]c layout for int8. Data and kernel should be packed in
advance.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
filter_width, num_filter_block, in_channel_block]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding: int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
layout : str
layout of data
out_dtype: str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.conv2d_NCHWc_int8_prepacked")
"""Example code to do convolution."""
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
oc_block_factor = 4
def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',
dtype='int8')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
# convert to NCHWc
_, _, out_height, out_width = c_np.shape
c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
out_height, out_width)).transpose(0, 1, 3, 4, 2)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
print("Skip because int8 intrinsics are not available")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
C = topi.nn.conv2d(A, dW, (stride, stride), (padding, padding),
layout='NCHW', out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ["cuda"]:
check_device(device)
class NCHWcInt8Fallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = FallbackConfigEntity()
cfg.template_key = 'int8'
self.memory[key] = cfg
return cfg
def test_conv2d_nchw():
with NCHWcInt8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_NCHWc_int8(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_NCHWc_int8(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_NCHWc_int8(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_NCHWc_int8(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_NCHWc_int8(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_NCHWc_int8(1, 512, 7, 512, 3, 1, 1)
# bias, relu
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
# batch size
verify_conv2d_NCHWc_int8(4, 64, 56, 64, 3, 1, 1)
verify_conv2d_NCHWc_int8(9, 64, 56, 64, 3, 1, 1)
# weird workloads
verify_conv2d_NCHWc_int8(4, 4, 4, 4, 4, 4, 4)
# inception v3 workloads where channels in / out are multiple of oc_block_factor
verify_conv2d_NCHWc_int8(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc_int8(1, 32, 147, 64, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 64, 73, 80, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 80, 73, 192, 3, 1, 0)
verify_conv2d_NCHWc_int8(1, 192, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 192, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 48, 35, 64, 5, 1, 2)
verify_conv2d_NCHWc_int8(1, 64, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 96, 35, 96, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 192, 35, 32, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 256, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 256, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 288, 35, 384, 3, 2, 0)
verify_conv2d_NCHWc_int8(1, 96, 35, 96, 3, 2, 0)
verify_conv2d_NCHWc_int8(1, 768, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 768, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 128, 17, 128, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 128, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc_int8(1, 128, 17, 128, 7, 1, 3)
verify_conv2d_NCHWc_int8(1, 128, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 768, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 160, 17, 160, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 160, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc_int8(1, 160, 17, 160, 7, 1, 3)
verify_conv2d_NCHWc_int8(1, 160, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 192, 17, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 192, 17, 192, 7, 1, 3)
verify_conv2d_NCHWc_int8(1, 192, 17, 320, 3, 2, 0)
verify_conv2d_NCHWc_int8(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_NCHWc_int8(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 384, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 384, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 1280, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 448, 8, 384, 3, 1, 1)
verify_conv2d_NCHWc_int8(1, 1280, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 2048, 8, 320, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment