Commit 06f91dd2 authored by Wuwei Lin, committed by Tianqi Chen

[TOPI] Add conv2d int8 template (#1735)

parent 46363d0a
@@ -90,10 +90,12 @@ def compute_conv2d(attrs, inputs, _):
     kernel_layout = attrs["kernel_layout"]
     out_dtype = attrs["out_dtype"]
     out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
-    assert layout == "NCHW" or layout == "NHWC"
+    assert layout in ["NCHW", "NHWC", "NCHW4c"]
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
         raise ValueError("dilation should be positive value")
+    elif layout == "NCHW4c" and (dilation_h > 1 or dilation_w > 1):
+        raise ValueError("not support dilate now")
     elif dilation == (1, 1):
         kernel = inputs[1]
     elif layout == "NCHW":
@@ -101,7 +103,12 @@ def compute_conv2d(attrs, inputs, _):
     else: #layout == NHWC
         kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
-    if groups == 1:
+    if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
+        # pylint: disable=assignment-from-no-return
+        out = topi.nn.conv2d_NCHWc_int8_prepacked(inputs[0], kernel, strides, padding,
+                                                  layout, out_dtype=out_dtype)
+        # pylint: enable=assignment-from-no-return
+    elif groups == 1:
         out = topi.nn.conv2d(
             inputs[0], kernel, strides, padding, layout, out_dtype=out_dtype)
     elif layout == "NCHW" and \
@@ -120,7 +127,7 @@ def compute_conv2d(attrs, inputs, _):
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
-        expand_axis = 1 if layout == "NCHW" else 0
+        expand_axis = 1 if layout in ["NCHW", "NCHW4c"] else 0
         bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2)
         out = topi.add(out, bias)
     return out
@@ -136,6 +143,8 @@ def schedule_conv2d(attrs, outs, target):
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
+        elif groups == 1 and layout == "NCHW4c":
+            return topi.generic.schedule_conv2d_NCHWc_int8_prepacked(outs)
         elif groups == 1 and layout == "NHWC":
            return topi.generic.schedule_conv2d_nhwc(outs)
         elif groups == channels and layout == "NCHW":
...
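For reference, a minimal numpy sketch (not nnvm code) of why expand_dims(bias, axis=1, num_newaxis=2) also works in the NCHW4c path, assuming the bias arrives channel-packed with shape (out_channel_chunk, oc_block); the shapes below are illustrative only.

# Sketch: broadcasting a packed bias over an NCHW4c tensor.
import numpy as np

N, C_chunk, H, W, c_block = 1, 16, 8, 8, 4                  # an NCHW4c-shaped output
out = np.zeros((N, C_chunk, H, W, c_block), dtype="int32")
bias = np.arange(C_chunk * c_block, dtype="int32").reshape(C_chunk, c_block)

# expand_dims(axis=1, num_newaxis=2): (C_chunk, c_block) -> (C_chunk, 1, 1, c_block),
# which broadcasts against (N, C_chunk, H, W, c_block).
out_with_bias = out + bias.reshape(C_chunk, 1, 1, c_block)
assert out_with_bias.shape == (N, C_chunk, H, W, c_block)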
@@ -344,7 +344,6 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
 .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
 .set_support_level(2);
 
-
 NNVM_REGISTER_OP(_contrib_conv2d_winograd_weight_transform)
 .describe(R"code(Weight transformation of winograd fast convolution algorithm.
 Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
...
@@ -9,9 +9,10 @@ from ..util import get_const_int, get_const_tuple, traverse_inline
 from .conv2d_direct import schedule_direct_cuda
 from .conv2d_winograd import winograd_cuda, schedule_winograd_cuda
+from .conv2d_int8 import conv2d_NCHWc_int8, schedule_conv2d_NCHWc_int8
 
 
-@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd'])
+@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd', 'int8'])
 def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
     """Conv2D operator for cuda backend.
@@ -21,10 +22,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
         The config for this template
 
     data : tvm.Tensor
-        4-D with shape [batch, in_channel, in_height, in_width]
+        4-D with shape [batch, in_channel, in_height, in_width] or
+        5-D with shape [batch, ic_chunk, in_height, in_width, ic_block]
 
     kernel : tvm.Tensor
-        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
+        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
+        filter_width, num_filter_block, in_channel_block]
 
     strides : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
@@ -98,6 +102,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
     if cfg.template_key == 'winograd':
         return winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype,
                              pre_computed=False)
+    if cfg.template_key == 'int8':
+        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, layout, out_dtype,
+                                 pre_computed=False)
 
     if layout == 'NCHW':
         return nn.conv2d_nchw(data, kernel, strides, padding, out_dtype)
@@ -108,7 +115,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
 @autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, ["cuda", "gpu"],
-                                ["direct", 'winograd'])
+                                ["direct", 'winograd', "int8"])
 def schedule_conv2d_nchw_cuda(cfg, outs):
     """TOPI schedule callback of conv2d for cuda gpu
@@ -138,6 +145,8 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
             schedule_direct_cuda(cfg, s, op.output(0))
         if op.tag == 'conv2d_nchw_winograd':
             schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
+        if op.tag == "conv2d_NCHWc_int8":
+            schedule_conv2d_NCHWc_int8(cfg, s, op.output(0), pre_computed=False)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -2,6 +2,7 @@
 """The templates for cuda conv2d operators"""
 import tvm
 from tvm import autotvm
+from ..util import get_const_tuple
 
 def schedule_direct_cuda(cfg, s, conv):
     """schedule optimized for batch size = 1"""
@@ -94,3 +95,7 @@ def schedule_direct_cuda(cfg, s, conv):
     # unroll
     s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
     s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    N, CO, OH, OW = get_const_tuple(output.shape)
+    _, KH, KW, CI = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
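The cfg.add_flop call above registers the cost of the convolution with AutoTVM as 2 * N * OH * OW * CO * CI * KH * KW (a multiply and an add per MAC). For a sense of scale, a quick check for one of the ResNet-18 workloads exercised by the test below; this snippet is illustrative only and not part of the patch.

# Evaluate the FLOP formula above for batch 1, 64 -> 64 channels, 56x56 output, 3x3 kernel.
N, CO, OH, OW = 1, 64, 56, 56
CI, KH, KW = 64, 3, 3
print(2 * N * OH * OW * CO * CI * KH * KW)   # 231211008, roughly 0.23 GFLOP for this layer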
@@ -375,6 +375,13 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     if cfg.template_key == 'direct':
         return None
 
+    if cfg.template_key == 'int8':
+        assert 'cuda' in tvm.target.current_target().keys
+        new_attrs['layout'] = 'NCHW4c'
+        new_attrs['out_layout'] = 'NCHW4c'
+        new_attrs['kernel_layout'] = 'OIHW4o4i'
+        return sym.conv2d(*copy_inputs, **new_attrs)
+
     # pre-compute weight transformation in winograd
     tile_size = _infer_tile_size(tinfos[0], tinfos[1])
...
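A numpy sketch of the packing implied by the layouts selected above, for illustration only: the actual repacking is done by the layout passes and the packed compute, and the exact permutation here is an assumption that matches the 6-D kernel shape documented further down.

# NCHW -> NCHW4c for data, OIHW -> OIHW4o4i for weights (block size 4 on both channel axes).
import numpy as np

ic_block = oc_block = 4
data = np.zeros((1, 64, 56, 56), dtype="int8")        # NCHW
kernel = np.zeros((64, 64, 3, 3), dtype="int8")       # OIHW

n, c, h, w = data.shape
data_nchw4c = data.reshape(n, c // ic_block, ic_block, h, w).transpose(0, 1, 3, 4, 2)
assert data_nchw4c.shape == (1, 16, 56, 56, 4)        # NCHW4c

o, i, kh, kw = kernel.shape
kernel_oihw4o4i = kernel.reshape(o // oc_block, oc_block,
                                 i // ic_block, ic_block,
                                 kh, kw).transpose(0, 2, 4, 5, 1, 3)
assert kernel_oihw4o4i.shape == (16, 16, 3, 3, 4, 4)  # OIHW4o4i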
@@ -140,6 +140,24 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
 
 @tvm.target.generic_func
+def schedule_conv2d_NCHWc_int8_prepacked(outs):
+    """Schedule for conv2d NCHWc int8 with prepacked data and kernel
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of this operator
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
 def schedule_conv2d_transpose_nchw(outs):
     """Schedule for conv2d_transpose_nchw
...
@@ -423,3 +423,37 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding,
         4-D with shape [batch, out_height, out_width, out_channel]
     """
     raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
+
+
+@tvm.target.generic_func
+def conv2d_NCHWc_int8_prepacked(data, kernel, stride, padding, layout, out_dtype):
+    """Convolution operator in NCHW[x]c layout for int8. Data and kernel should be packed in
+    advance.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
+
+    kernel : tvm.Tensor
+        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
+        filter_width, num_filter_block, in_channel_block]
+
+    stride : int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+
+    padding: int or a list/tuple of two ints
+        padding size, or [pad_height, pad_width]
+
+    layout : str
+        layout of data
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
+    """
+    raise ValueError("missing register for topi.nn.conv2d_NCHWc_int8_prepacked")
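A small helper, for illustration only, that derives the 5-D output shape documented above from the packed input shapes using the usual convolution size formula; the helper name and its arguments are hypothetical and not part of the patch.

# Output shape of the prepacked NCHW[x]c int8 conv, from 5-D data and 6-D kernel shapes.
def nchwc_int8_out_shape(data_shape, kernel_shape, stride, padding):
    n, ic_chunk, ih, iw, ic_block = data_shape
    oc_chunk, _, kh, kw, oc_block, _ = kernel_shape
    sh, sw = (stride, stride) if isinstance(stride, int) else stride
    ph, pw = (padding, padding) if isinstance(padding, int) else padding
    oh = (ih + 2 * ph - kh) // sh + 1
    ow = (iw + 2 * pw - kw) // sw + 1
    return (n, oc_chunk, oh, ow, oc_block)

# e.g. a 1x16x56x56x4 input with a 16x16x3x3x4x4 kernel, stride 1, padding 1
assert nchwc_int8_out_shape((1, 16, 56, 56, 4), (16, 16, 3, 3, 4, 4), 1, 1) == (1, 16, 56, 56, 4)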
"""Example code to do convolution."""
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
oc_block_factor = 4

def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))

    in_height = in_width = in_size

    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
    bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',
                           dtype='int8')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
    def get_ref_data():
        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)

        # convert to NCHWc
        _, _, out_height, out_width = c_np.shape
        c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
                             out_height, out_width)).transpose(0, 1, 3, 4, 2)

        if add_bias:
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)

        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
            print("Skip because int8 intrinsics are not available")
            return

        print("Running on target: %s" % device)
        with tvm.target.create(device):
            dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
            C = topi.nn.conv2d(A, dW, (stride, stride), (padding, padding),
                               layout='NCHW', out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_conv2d_nchw([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
            func(a, w, c)
        np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    for device in ["cuda"]:
        check_device(device)

class NCHWcInt8Fallback(autotvm.FallbackContext):
    def _query_inside(self, target, workload):
        key = (target, workload)
        if key in self.memory:
            return self.memory[key]
        cfg = FallbackConfigEntity()
        cfg.template_key = 'int8'
        self.memory[key] = cfg
        return cfg

def test_conv2d_nchw():
    with NCHWcInt8Fallback():
        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 64, 56, 128, 3, 2, 1)
        verify_conv2d_NCHWc_int8(1, 64, 56, 128, 1, 2, 0)
        verify_conv2d_NCHWc_int8(1, 128, 28, 128, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 128, 28, 256, 3, 2, 1)
        verify_conv2d_NCHWc_int8(1, 128, 28, 256, 1, 2, 0)
        verify_conv2d_NCHWc_int8(1, 256, 14, 256, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 256, 14, 512, 3, 2, 1)
        verify_conv2d_NCHWc_int8(1, 256, 14, 512, 1, 2, 0)
        verify_conv2d_NCHWc_int8(1, 512, 7, 512, 3, 1, 1)

        # bias, relu
        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_bias=True)
        verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)

        # batch size
        verify_conv2d_NCHWc_int8(4, 64, 56, 64, 3, 1, 1)
        verify_conv2d_NCHWc_int8(9, 64, 56, 64, 3, 1, 1)

        # weird workloads
        verify_conv2d_NCHWc_int8(4, 4, 4, 4, 4, 4, 4)

        # inception v3 workloads where channels in / out are multiple of oc_block_factor
        verify_conv2d_NCHWc_int8(1, 32, 149, 32, 3, 1, 0)
        verify_conv2d_NCHWc_int8(1, 32, 147, 64, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 64, 73, 80, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 80, 73, 192, 3, 1, 0)
        verify_conv2d_NCHWc_int8(1, 192, 35, 64, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 192, 35, 48, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 48, 35, 64, 5, 1, 2)
        verify_conv2d_NCHWc_int8(1, 64, 35, 96, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 96, 35, 96, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 192, 35, 32, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 256, 35, 64, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 256, 35, 48, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 288, 35, 64, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 288, 35, 48, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 288, 35, 384, 3, 2, 0)
        verify_conv2d_NCHWc_int8(1, 96, 35, 96, 3, 2, 0)
        verify_conv2d_NCHWc_int8(1, 768, 17, 192, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 768, 17, 128, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 128, 17, 128, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 128, 17, 192, 7, 1, 3)
        verify_conv2d_NCHWc_int8(1, 128, 17, 128, 7, 1, 3)
        verify_conv2d_NCHWc_int8(1, 128, 17, 192, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 768, 17, 160, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 160, 17, 160, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 160, 17, 192, 7, 1, 3)
        verify_conv2d_NCHWc_int8(1, 160, 17, 160, 7, 1, 3)
        verify_conv2d_NCHWc_int8(1, 160, 17, 192, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 192, 17, 192, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 192, 17, 192, 7, 1, 3)
        verify_conv2d_NCHWc_int8(1, 192, 17, 320, 3, 2, 0)
        verify_conv2d_NCHWc_int8(1, 192, 17, 192, 3, 2, 0)
        verify_conv2d_NCHWc_int8(1, 1280, 8, 320, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 1280, 8, 384, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 384, 8, 384, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 384, 8, 384, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 1280, 8, 448, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 448, 8, 384, 3, 1, 1)
        verify_conv2d_NCHWc_int8(1, 1280, 8, 192, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 2048, 8, 320, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 2048, 8, 384, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 2048, 8, 448, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
        verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)


if __name__ == "__main__":
    test_conv2d_nchw()