Unverified Commit 5cc17649 by Yizhi Liu Committed by GitHub

[topi] add ARM v8.2 udot (uint8) support (#3978)

* [topi] add ARM v8.2 udot (uint8) support

* fix test case

* fix common conv2d schedule

* add back fp32_time in test

* fix lint

* fix doc, add support for int32_lanes=4, signed int

* fix lint

* add ic_bn % 4 checker in schedule
parent 85a1d3ff
@@ -3,6 +3,7 @@
from . import conv2d
from . import depthwise_conv2d
from . import conv2d_transpose
from . import conv2d_int8
from . import bitserial_conv2d
from . import bitserial_dense
from . import injective
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on ARM"""
import tvm
from tvm import autotvm
from .. import generic, tag
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from ..generic import conv2d as conv2d_generic
from .. import nn
from ..nn.conv2d import _get_workload as _get_conv2d_workload
from .tensor_intrin import dot_int8_int8_int32
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
"""
Get default int8 schedule config for the workload
"""
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=2, num_int8_elements=4)
else:
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=2, num_int8_elements=4)
@autotvm.register_topi_compute(conv2d_NCHWc_int8, ['arm_cpu'], 'direct')
def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
# layout and out_layout are not used here,
# but are kept for debugging convenience when dumping the autotvm workload
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# If no config was set, we can fallback to NCHW config.
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
strides, padding, out_dtype)
return nn.conv2d_NCHWc_int8_compute(data,
kernel,
strides,
padding,
dilation,
layout,
out_layout,
out_dtype)
@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, ['arm_cpu'], ['direct'])
def _schedule_conv2d_NCHWc_int8(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_NCHWc_int8' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]]
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
dtype = "uint" if data.dtype == "uint8" else "int"
if kh == 1 and kw == 1:
conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
*args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
else:
conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
*args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))
scheduled_ops.append(op)
traverse(outs[0].op)
return s
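For orientation, here is a minimal usage sketch of the compute/schedule pair above. It is not part of the patch: the shapes, padding and the ic_bn = oc_bn = 8 blocking are illustrative assumptions, and building the cross-compiled function assumes an LLVM with the AArch64 backend. The data is NCHW8c uint8 and the kernel uses the 7-D layout (oc_chunk, ic_chunk, kh, kw, ic_bn//4, oc_bn, 4) expected by the int8 compute.

import tvm
import topi

# Hypothetical 64-in/64-out channel 3x3 convolution, ic_bn = oc_bn = 8, n_elems = 4.
data = tvm.placeholder((1, 8, 56, 56, 8), dtype='uint8', name='data')
kernel = tvm.placeholder((8, 8, 3, 3, 2, 8, 4), dtype='uint8', name='kernel')

target = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'
with tvm.target.create(target):
    # Declare the NCHWc int8 convolution and build it with the schedule above.
    conv = topi.nn.conv2d_NCHWc_int8(data, kernel, strides=1, padding=1, dilation=(1, 1),
                                     layout='NCHWc', out_layout='NCHWc', out_dtype='uint32')
    s = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[conv])
    func = tvm.build(s, [data, kernel, conv], target=target, name='conv')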
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on ARM"""
import tvm
def dot_int8_int8_int32(int32_lanes, dtype='uint'):
"""
Int8 dot product by every 4 elements using the ARM v8.2 udot/sdot instructions.
This function takes two arrays of int8 datatype -- data[4] and
kernel[int32_lanes][4] -- and computes a dot product of data[4] with every
4-element row of kernel, producing output[int32_lanes] of int32/uint32 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_int8_int8_int32(int8 data[4], int8 kernel[int32_lanes][4], int32 output[int32_lanes]){
for (int i = 0; i < int32_lanes; i++){
output[i] = 0;
for (int k = 0; k < 4; k++){
output[i] += data[k] * kernel[i][k]
}
}
}
Physically, the kernel array sits in a vector register and
the data[4] is broadcast to another vector register. This
function returns a TensorIntrin that can be used to tensorize
a schedule.
Parameters
----------
int32_lanes : int
Number of int32/uint32 lanes produced by one intrinsic call
dtype : str, optional, {"uint", "int"}
Whether the intrinsic operates on unsigned or signed int8 inputs
Returns
-------
intrin : TensorIntrin
The ARM (u)int8 TensorIntrin that can be used in a tensorize schedule
"""
num_int8_elements = 4 # 4 int8 elements in int32
data = tvm.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data')
kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel')
k = tvm.reduce_axis((0, num_int8_elements), name='k')
C = tvm.compute((int32_lanes,),
lambda i: tvm.sum(data[k].astype('%s32' % dtype) *
kernel[i, k].astype('%s32' % dtype),
axis=k), name="C")
a_buffer = tvm.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer",
offset_factor=1,
strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer",
offset_factor=1,
strides=[tvm.var('s'), 1])
def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes))))
return ib.get()
dtype_a = '%s8x%d' % (dtype, num_int8_elements)
dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
dtype_c = '%s32x%d' % (dtype, int32_lanes)
a_int8 = ins[0].vload([0], dtype_a)
re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
# broadcast a
vec_ai32 = re_int32.astype(dtype_c)
vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], dtype_b)
vec_c = outs[0].vload([0], dtype_c)
inst = 'udot' if dtype == 'uint' else 'sdot'
inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
inst, int32_lanes, int32_lanes * num_int8_elements)
vdot = tvm.call_llvm_intrin(dtype_c,
inst,
tvm.const(2, 'uint32'),
vec_c, vec_a, vec_b)
ib.emit(outs[0].vstore(0, vdot))
return ib.get()
# body, reset, update
return _instr(0), _instr(1), _instr(2)
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
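As a sanity reference (not part of the patch), the computation that the udot/sdot intrinsic above tensorizes can be written in plain NumPy; with int32_lanes=4 it is a 4x4 (u)int8 matrix times a 4-element (u)int8 vector, widened to (u)int32:

import numpy as np

int32_lanes, num_int8_elements = 4, 4
data = np.random.randint(0, 128, size=(num_int8_elements,), dtype=np.uint8)
kernel = np.random.randint(0, 128, size=(int32_lanes, num_int8_elements), dtype=np.uint8)

# Scalar form, matching the pseudo code in the docstring above.
output = np.zeros(int32_lanes, dtype=np.uint32)
for i in range(int32_lanes):
    for k in range(num_int8_elements):
        output[i] += np.uint32(data[k]) * np.uint32(kernel[i, k])

# Equivalent widened matrix-vector product.
assert np.array_equal(output, kernel.astype(np.uint32) @ data.astype(np.uint32))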
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Generic convolution schedules"""
from __future__ import absolute_import as _abs
import tvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..util import get_const_tuple
def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
"""Fallback schedule for conv2d int8 on cpu.
Normally the inner most pattern takes two int8/uint8 tensors
data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
produces a dot product int32/uint32 output[int32_lanes].
Parameters
----------
int32_lanes : int
Number of int32/uint32 values produced by the intrinsic.
This is related to the output channel.
num_int8_elements : int
Number of int8/uint8 input elements multiplied and reduced by the intrinsic.
This is related to the input channel.
"""
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
assert wkl.out_filter % int32_lanes == 0, \
"wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)
assert wkl.in_filter % num_int8_elements == 0, \
"wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements)
oc_bn = int32_lanes
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)
def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
"""Fallback schedule for 1x1 conv2d int8 on cpu.
Normally the inner most pattern takes two int8/uint8 tensors
data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
produces a dot product int32/uint32 output[int32_lanes].
Parameters
----------
int32_lanes : int
Number of int32/uint32 values produced by the intrinsic.
This is related to the output channel.
num_int8_elements : int
Number of int8/uint8 input elements multiplied and reduced by the intrinsic.
This is related to the input channel.
"""
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
assert wkl.out_filter % int32_lanes == 0, \
"wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes)
assert wkl.in_filter % num_int8_elements == 0, \
"wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements)
oc_bn = int32_lanes
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
for ow_factor in range(out_width, 0, -1):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
return
raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None):
"""
Defines the INT8 schedule shared by Intel and ARM machines.
Uses the Intel/ARM dot-product intrinsics to perform the INT8 operations.
More details: https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
# schedule 5-D NCHW[x]c conv
C, O = conv_out, last
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(batch, oc_chunk, oh)
s[C].vectorize(oc_block)
if C == O:
s[C].parallel(parallel_axis)
s[CC].compute_at(s[C], ow_chunk)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
assert oc_bn % int32_lanes == 0
assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw)
else:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
if intrin is not None:
s[CC].tensorize(oc_s_inner, intrin)
s[CC].unroll(ow_block)
s[CC].unroll(oc_f_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None):
"""
Defines the 1x1 conv INT8 schedule shared by Intel and ARM machines.
Uses the Intel/ARM dot-product intrinsics to perform the INT8 operations.
More details: https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
C, O = conv_out, last
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
s[CC].compute_at(s[C], parallel_axis)
if C == O:
s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
assert oc_bn % int32_lanes == 0
assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer)
if intrin is not None:
s[CC].tensorize(oc_s_inner, intrin)
s[CC].unroll(ow_inner)
s[CC].unroll(oh_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
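The point of moving these schedules into topi/generic is that both backends now share the same loop structure and only inject their own dot-product intrinsic and lane count. The two call sites, as they appear in the diffs above:

# ARM (topi/arm_cpu): 4 int32 lanes, ARM v8.2 udot/sdot intrinsic
conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
    *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype))

# x86 (topi/x86): 16 int32 lanes, AVX-512 int8 dot-product intrinsic
conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last,
                                                   int32_lanes=16,
                                                   intrin=dot_16x1x16_int8_int8_int32())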
@@ -595,19 +595,11 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
target = tvm.target.current_target(allow_none=False)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group
# Since the weight is 7-D and the last element size is 4, we have to
# check that ic_bn is a multiple of 4.
# Similarly, oc_bn has to be a multiple of 4.
assert ic_bn % 4 == 0
assert oc_bn % 16 == 0
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
......
@@ -22,6 +22,7 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.pad import pad
from ..nn.util import infer_pad, get_pad_tuple
from ..generic import conv2d as conv2d_generic
from ..util import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .util import get_fp32_len
@@ -57,36 +58,6 @@ def _fallback_schedule(cfg, wkl):
raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _fallback_schedule_int8(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
oc_bn = 16
assert wkl.out_filter % oc_bn == 0
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
assert wkl.in_filter % 4 == 0
for ow_factor in range(out_width, 0, -1):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
return
raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
# fetch schedule
ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
@@ -210,71 +181,9 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
Defines the schedule for INT8 for Intel machines.
Uses the Intel intrinsics to perform INT8 operations.
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
int32_lanes = 16
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
C, O = conv_out, last
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
s[CC].compute_at(s[C], parallel_axis)
if C == O:
s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
# Skylake and future processors have 16 vector lanes
assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer)
pc = dot_16x1x16_int8_int8_int32()
s[CC].tensorize(oc_s_inner, pc)
s[CC].unroll(ow_inner)
s[CC].unroll(oh_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last,
int32_lanes=16,
intrin=dot_16x1x16_int8_int8_int32())
def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
......
@@ -21,6 +21,7 @@ import tvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad
from ..generic import conv2d as conv2d_generic
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .util import get_fp32_len
@@ -56,7 +57,6 @@ def _fallback_schedule(cfg, wkl):
def _fallback_schedule_int8(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
@@ -207,68 +207,6 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
"""
Defines the schedule for INT8 for Intel machines.
Uses the Intel intrinsics to perform INT8 operations.
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
int32_lanes = 16
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
# schedule 5-D NCHW[x]c conv
C, O = conv_out, last
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(batch, oc_chunk, oh)
s[C].vectorize(oc_block)
if C == O:
s[C].parallel(parallel_axis)
s[CC].compute_at(s[C], ow_chunk)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
# Skylake and future processors have 16 vector lanes
assert oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw)
else:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
pc = dot_16x1x16_int8_int8_int32()
s[CC].tensorize(oc_s_inner, pc)
s[CC].unroll(ow_block)
s[CC].unroll(oc_f_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last,
int32_lanes=16,
intrin=dot_16x1x16_int8_int8_int32())
@@ -24,6 +24,7 @@ from tvm.autotvm.task import get_config
from tvm.autotvm.task.topi_integration import deserialize_args
from ..nn.conv2d import _get_workload as _get_conv2d_workload
from .. import generic, tag
from ..generic import conv2d as conv2d_generic
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from .. import nn
@@ -38,9 +39,11 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule_int8(cfg, wkl)
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4)
else:
conv2d_avx_common._fallback_schedule_int8(cfg, wkl)
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4)
def _is_int8_hw_support(data_dtype, kernel_dtype):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
""" Conv Int8 functional and performance testing"""
import sys
import logging
import numpy as np
import tvm
import topi
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
LOGGER = logging.getLogger('test_conv_int8_arm')
LOGGER.disabled = False
# All the WORKLOADS from Resnet except first layer
# Workload is ['height', 'width', 'in_filter', 'out_filter',
# 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
(56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
(28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
(28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
(28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
(14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
(28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
(14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
(14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
(7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
(14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
(7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
]
TARGET_NAME = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'
NUM_VEC_LANES = 16
CTX = tvm.context(TARGET_NAME, 0)
def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
hstride, wstride, out_dtype):
"""
Finds out the shape of all data structures
"""
data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
if out_dtype == 'int32' or out_dtype == 'uint32':
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
elif out_dtype == 'float32':
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES, NUM_VEC_LANES)
out_height = (im_height + 2 * hpad - k_h) // hstride + 1
out_width = (im_width + 2 * wpad - k_w) // wstride + 1
o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
return (data_shape, kernel_shape, o_shape)
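# Worked example (illustration only, not part of the original file): the first ResNet
# workload (56, 56, 64, 64, 3, 3, 1, 1, 1, 1) with NUM_VEC_LANES = 16 and out_dtype
# 'int32' gives
#   data_shape   = (1, 4, 56, 56, 16)        # NCHW16c input
#   kernel_shape = (4, 4, 3, 3, 4, 16, 4)    # 7-D int8 kernel layout, n_elems = 4
#   o_shape      = (1, 4, 56, 56, 16)        # pad=1, stride=1 keeps 56x56
# e.g. print(get_shape(56, 56, 64, 64, 3, 3, 1, 1, 1, 1, 'int32'))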
def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
"""
Runs the inference and checks the functional correctness between
compute and schedule outputs
"""
(data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
out_filter, k_h, k_w, hpad, wpad,
hstride, wstride, out_dtype)
# Create TVM placeholders
data = tvm.placeholder(data_shape, name='data', dtype=data_dtype)
kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
# Create the numpy arrays to be used for executing conv models
if data_dtype == 'float32':
data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
else:
data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype))
kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype))
# c_orig will be used for declaration output
# c_sch will be used for scheduled computation output
c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
with tvm.target.create(TARGET_NAME):
if out_dtype == "float32":
conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride,
padding=hpad, dilation=(1, 1),
layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
else:
conv = topi.nn.conv2d_NCHWc_int8(data, kernel, strides=hstride,
padding=hpad, dilation=(1, 1),
layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype)
out = topi.nn.relu(conv)
sch = tvm.create_schedule(out.op)
func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
func(data_array, kernel_array, c_orig)
LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
# Generate and run the optimized schedule
if out_dtype == "float32":
sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out])
else:
sconv = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[out])
func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
func(data_array, kernel_array, c_sch)
# Functional check
if data_dtype == 'uint8':
np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
else:
assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
evaluator = func.time_evaluator(func.entry_name, CTX, number=1000)
LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
return evaluator(data_array, kernel_array, c_sch).mean
if __name__ == "__main__":
LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
SPEEDUP_ARRAY = []
for i, wkl in enumerate(WORKLOADS):
for dtype in ["uint", "int"]:
fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
int8_time = run_inference('%s8' % dtype, '%s8' % dtype, '%s32' % dtype, *wkl)
kernel_h = wkl[4]
kernel_w = wkl[5]
LOGGER.info("[%s] Workload#" % dtype + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
+ str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
SPEEDUP_ARRAY.append(fp32_time/int8_time)
LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))