Unverified Commit 681df4fc by Animesh Jain Committed by GitHub

[Strategy] Support for Int8 schedules - CUDA/x86 (#5031)

* [CUDA] Op strategy changes for Int8 schedules.

* Applying Haichen's suggestions.

* Make 4D output work for task extraction.

* Make x86 work.

* Fix lint.

* Lint fixes.

* Tests, comments, out channel a multiple of 4.

* Topi test.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-38-96.us-west-2.compute.internal>
parent 923b4a26
......@@ -1373,8 +1373,8 @@ def _qnn_conv(inputs, attrs, subgraphs, params):
# 3) Clip/cast to change the out dtype.
_res = relay.clip(_res,
a_min=float(tvm.api.min_value(out_dtype).value),
a_max=float(tvm.api.max_value(out_dtype).value))
a_min=float(tvm.tir.op.min_value(out_dtype).value),
a_max=float(tvm.tir.op.max_value(out_dtype).value))
_res = relay.cast(_res, out_dtype)
return _res
......@@ -1647,8 +1647,8 @@ def _qnn_fully_connected(inputs, attrs, subgraphs, params):
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
a_min=tvm.api.min_value('int32').value,
a_max=tvm.api.max_value('int32').value)
a_min=tvm.tir.op.min_value('int32').value,
a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)
enable_float_output = attrs.get_bool('enable_float_output', False)
......
......@@ -85,8 +85,14 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
assert data.dtype == kernel.dtype
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.cuda")
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
......
......@@ -264,3 +264,17 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
#####################
# CUDA legalizations.
#####################
@qnn_conv2d_legalize.register('cuda')
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
@qnn_dense_legalize.register('cuda')
def _qnn_dense_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
......@@ -177,6 +177,13 @@ def test_qnn_legalize_qnn_conv2d():
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
###########################################
# Check transformations for CUDA platforms.
###########################################
with tvm.target.create('cuda'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
def test_qnn_legalize_qnn_dense():
def _get_mod(data_dtype, kernel_dtype):
......@@ -257,6 +264,13 @@ def test_qnn_legalize_qnn_dense():
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
###########################################
# Check transformations for CUDA platforms.
###########################################
with tvm.target.create('cuda'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext()
if __name__ == "__main__":
test_qnn_legalize()
......
......@@ -26,6 +26,7 @@ from tvm import autotvm
from .. import nn
from ..util import get_const_tuple
from .conv2d_winograd import _infer_tile_size
from ..nn import conv2d_legalize
logger = logging.getLogger('topi')
......@@ -135,3 +136,82 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.conv2d(*inputs, **new_attrs)
return None
@conv2d_legalize.register("cuda")
def _conv2d_legalize(attrs, inputs, arg_types):
"""Legalizes Conv2D op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
if not (dilation[0] == 1 and dilation[1] == 1):
return None
# No legalization for depthwise convolutions yet.
groups = attrs.get_int("groups")
if groups != 1:
return None
# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype
# Collect the output tensor.
output_tensor = arg_types[2]
# Collect the input exprs.
data, kernel = inputs
# Get the conv attrs
new_attrs = {k: attrs[k] for k in attrs.keys()}
# Get data layout. Return None if not NCHW
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']
# Pad input and output channels to use int8 schedule.
if data_dtype in ['int8', 'uint8']:
if data_layout == 'NCHW' and kernel_layout == "OIHW":
oc_modified = False
in_channel = data_tensor.shape[1].value
out_channel = kernel_tensor.shape[0].value
# Pad input channel
if in_channel % 4 != 0:
new_in_channel = ((in_channel + 4) // 4) * 4
diff = new_in_channel - in_channel
pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
data = relay.nn.pad(data, pad_width=pad_width)
kernel = relay.nn.pad(kernel, pad_width=pad_width)
# Pad output channel
new_out_channel = out_channel
if out_channel % 4 != 0:
new_out_channel = ((out_channel + 4) // 4) * 4
diff = new_out_channel - out_channel
kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
oc_modified = True
if oc_modified:
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
return None
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=no-value-for-parameter
"""Int8 conv2d in NCHWc layout"""
import tvm
from tvm import te
......@@ -23,10 +24,23 @@ from tvm import autotvm
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline
def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'):
"""Compute conv2d internally using conv2d_nchwc layout for int8 dtype"""
assert data.dtype in ('int8', 'uint8')
assert kernel.dtype in ('int8', 'uint8')
assert data.dtype == kernel.dtype
packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype)
return unpack_NCHWc_to_nchw(packed_out, out_dtype)
def schedule_conv2d_nchw_int8(outs):
"""Create schedule for tensors"""
return schedule_conv2d_NCHWc_int8(outs)
@autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
"""Convolution operator in NCHW[x]c layout for int8.
......@@ -205,7 +219,13 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
output = s.outputs[0].output(0)
# tile and bind spatial axes
if len(s[output].op.axis) == 5:
n, f, y, x, c = s[output].op.axis
else:
# For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks
# are created from scratch, therefore the real auto-tuning will still happen on 5D output.
n, f, y, x = s[output].op.axis
cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
......
......@@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
# conv2d_nchwc_int8 has 7D kernel
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
......@@ -189,6 +190,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oc_f_inner)
if C != O:
out_ndim = len(s[O].op.axis)
if out_ndim == 5:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
......@@ -196,6 +199,17 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
elif out_ndim == 4:
batch, oc, oh, ow = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
raise ValueError("Unsupported output ndim: %s" % out_ndim)
return s
......@@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
# Conv2d int8 schedule has 7D kernel
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
......@@ -277,6 +292,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oh_inner)
if C != O:
out_ndim = len(s[O].op.axis)
if out_ndim == 5:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
......@@ -286,5 +303,18 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
elif out_ndim == 4:
batch, oc, oh, ow = s[O].op.axis
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
raise ValueError("Unsupported output ndim: %s" % out_ndim)
return s
......@@ -108,6 +108,76 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
check_device(device)
def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
in_height = in_width = in_size
A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8')
bias = te.placeholder((num_filter, 1, 1), name='bias', dtype='int8')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
print("Skip because int8 intrinsics are not available")
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.cuda.conv2d_nchw_int8(A, W, (stride, stride), padding, (dilation, dilation),
dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.cuda.schedule_conv2d_nchw_int8([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ["cuda"]:
check_device(device)
def test_conv2d_nchw():
with Int8Fallback():
# ResNet18 workloads where channels in / out are multiple of oc_block_factor
......@@ -204,6 +274,17 @@ def test_conv2d_nchw():
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
# Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
# performing basic testing - one test for all different scenarios - batch, dilation etc..
verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, dilation=2)
verify_conv2d_nchw_int8(9, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw_int8(4, 4, 4, 4, 4, 4, 4)
verify_conv2d_nchw_int8(1, 32, 149, 32, 3, 1, 0)
verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0)
verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1))
if __name__ == "__main__":
test_conv2d_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment