Commit 703ed9b7 by Animesh Jain Committed by Yizhi Liu

[x86 schedule] Fallback schedule for Int8 depthwise. (#4733)

parent 67b97e5a
......@@ -1182,6 +1182,35 @@ def test_conv2d_int8_intrinsics():
assert "vpmulld" in asm and "vpadd" in asm
def test_depthwise_conv2d_int8():
input_dtype = 'uint8'
weight_dtype = 'int8'
output_dtype = 'int32'
data_shape = (1, 64, 56, 56)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
kernel_shape = (64, 1, 3, 3)
weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, weight,
kernel_size=(3, 3),
groups=64,
padding=(1, 1),
dilation=(1, 1),
out_dtype=output_dtype)
func = relay.Function([x, weight], y)
wdata = np.random.rand(*kernel_shape) * 10
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
llvm_version = tvm.codegen.llvm_version_major()
for target in targets:
if llvm_version >= 8:
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)
def test_bitserial_conv2d_infer_type():
# Basic shape test with ambiguous batch.
n, c, h, w = tvm.size_var("n"), 32, 224, 224
......@@ -1234,3 +1263,4 @@ if __name__ == "__main__":
test_upsampling()
test_upsampling3d()
test_conv2d_int8_intrinsics()
test_depthwise_conv2d_int8()
......@@ -28,6 +28,7 @@ from ..generic import conv2d as conv2d_generic
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from .. import nn
from . import conv2d_avx_1x1, conv2d_avx_common
......@@ -36,7 +37,12 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
"""
Get default schedule config for the workload
"""
assert not is_depthwise, "Depthwise Int8 not supported"
if is_depthwise:
# Fallback to FP32 default config until a VNNI schedule is defined.
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
from .depthwise_conv2d import _fallback_schedule
_fallback_schedule(cfg, wkl)
else:
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment