Commit 72ad9a38 by Animesh Jain, committed by Yizhi Liu

INT8 conv operator implementation with NCHWc data layout for Intel machines (#1680)

* Int8 implementation for convolution operator on Intel Skylake

* PR changes

* Fixing an error

* Minor typos fix

* Removing the broadcast16 CPP code. Using astype feature instead

* Replacing constant by variable name num_elements_intel

* Name fixes and tensorize update rule updated

* Fixing the bug about checking skylake

* Replacing bitcast with reinterpret

* Isolating INT8 and FP32 schedules to ease out future AutoTVM PR merge

* Putting check_skylake function in the x86 directory

* Added documentation and organizing files to better locations

* Tensor intrin renaming. Avoid code duplication for intrin by kernel reshaping
parent bde53033
......@@ -79,12 +79,27 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
    "Do not support inputs with different data types now. " \
    "{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_workload_int8(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
    "Do not support inputs with different data types now. " \
    "{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
......@@ -118,6 +133,17 @@ def _get_schedule_NCHWc(wkl, layout, out_layout):
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress pylint warning
return wkl
def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout.
......
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Checks different x86 targets for target specific schedules"""
def check_skylake(target):
"""
Checks if the target is Skylake
"""
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
return True
return False
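# Usage sketch (illustrative only, not part of the library): check_skylake
# returns True only when the -mcpu=skylake-avx512 option is present, as with
# the target string used in the tests below.
#
#   target = tvm.target.create('llvm -mcpu=skylake-avx512')
#   assert check_skylake(target)
#   assert not check_skylake(tvm.target.create('llvm -mcpu=core-avx2'))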
......@@ -3,11 +3,14 @@
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
import topi
from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
......@@ -229,3 +232,117 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
s[O].parallel(parallel_axis)
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
""" Declaration for int8 conv"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
# The Intel int8 instruction computes a dot product over groups of 4 int8 values
n_elems = 4
assert sch.ic_bn % n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter // sch.ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
# Reshape the kernel: the last two dimensions (k_h x k_w) are 1x1 and carry no data
k_shape = kernel.shape
kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3],
k_shape[4] * k_shape[5] * k_shape[6]))
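# e.g. (matching the shapes produced by get_shape in the test below) an
# (OC//16, IC//16, 4, 16, 4, 1, 1) kernel becomes (OC//16, IC//16, 4, 16, 4)
# here, since k_h = k_w = 1.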
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
"""
Defines the INT8 schedule for Intel machines.
Uses Intel AVX512 intrinsics for the INT8 operations.
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
parallel_axis = s[A].fuse(ic_chunk, ih)
s[A].parallel(parallel_axis)
C, O = conv_out, last
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh_outer)
s[CC].compute_at(s[C], parallel_axis)
if C == O:
s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer)
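# The innermost oc_s_inner (16 lanes) x ic_s_inner (4 elements) reduction now
# matches the 16x1x16 tensor intrinsic declared in tensor_intrin.py, so
# tensorize can replace it with a single AVX512 instruction sequence.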
pc = dot_16x1x16_int8_int8_int32()
s[CC].tensorize(oc_s_inner, pc)
s[CC].unroll(ow_inner)
s[CC].unroll(oh_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
......@@ -8,6 +8,8 @@ from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
......@@ -252,3 +254,124 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
s[O].parallel(parallel_axis)
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
"""
This function sets up the compute for INT8 conv2d
Inputs are in the INT8 datatype
Output is in the INT32 datatype
"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
# pad data
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
# convolution
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
# The Intel int8 instruction computes a dot product over groups of 4 int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert sch.ic_bn % n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter // sch.ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
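# e.g. with sch.ic_bn = 16 and n_elems = 4, an input channel index decomposes
# as ic = ic_outer * 16 + ic_f_inner * 4 + ic_s_inner, matching the packed
# NCHW[x]c data layout and the kernel layout indexed below.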
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
"""
Defines the INT8 schedule for Intel machines.
Uses Intel AVX512 intrinsics for the INT8 operations.
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
# Currently INT8 operations are supported only for Skylake
# In the future the _intrin_reduce4int8 will be updated for VNNI instructions
# In case of an unsupported target, the schedule falls back to the original
# compute
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
parallel_axis = s[A].fuse(ic_chunk, ih)
s[A].parallel(parallel_axis)
# schedule 5-D NCHW[x]c conv
C, O = conv_out, last
CC = s.cache_write(C, 'global')
_, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh)
s[C].vectorize(oc_block)
if C == O:
s[C].parallel(parallel_axis)
s[CC].compute_at(s[C], ow_chunk)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if sch.unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw)
else:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
pc = dot_16x1x16_int8_int8_int32()
s[CC].tensorize(oc_s_inner, pc)
s[CC].unroll(ow_block)
s[CC].unroll(oc_f_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
"""Core kernel of dot product of 4 Int8 operations"""
#pylint: disable=invalid-name
import tvm
def dot_16x1x16_int8_int8_int32():
"""
Int8 dot product by every 4 elements using AVX512 Skylake instructions.
This function takes two arrays -- data[4] of uint8 datatype and
kernel[16][4] of int8 datatype -- and computes a dot product of data[4]
with every 4 elements of kernel, resulting in output[16] of int32 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
int32 output[16]){
for (int i = 0; i < 16; i++){
output[i] = 0;
for (int k = 0; k < 4; k++){
output[i] += data[k] * kernel[i][k];
}
}
}
Physically, the kernel array sits in an AVX512 vector register and
the data[4] is broadcasted to another AVX512 vector register. This
function returns a TensorIntrin that can be used to tensorize
a schedule.
Returns
-------
intrin : TensorIntrin
The Skylake int8 TensorIntrin that can be used in tensorizing schedule
"""
int32_lanes = 16 # 16 int32 lanes in AVX512
num_int8_elements = 4 # 4 int8 elements packed in each int32
data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
k = tvm.reduce_axis((0, num_int8_elements), name='k')
C = tvm.compute((int32_lanes,),
lambda i: tvm.sum(data[k].astype('int32') *
kernel[i, k].astype('int32'),
axis=k),
name="C")
a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
offset_factor=1,
strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
offset_factor=1,
strides=[tvm.var('ldw'), 1])
def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
return ib.get()
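# Broadcast the four uint8 data values across all AVX512 lanes: reinterpret
# data[4] as a single int32, broadcast it to the 16 int32 lanes, then
# reinterpret the register as 64 int8 lanes so it aligns with the kernel rows.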
a_int8 = ins[0].vload([0], "uint8x4")
re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x16')
vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], "int8x64")
vec_one = tvm.const(1, "int16x32")
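# pmaddubs.w multiplies unsigned x signed int8 pairs and adds adjacent
# products into int16 lanes; pmaddw.d with a vector of ones then adds
# adjacent int16 lanes into the final 16 int32 dot products.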
pair_reduction = tvm.call_llvm_intrin('int16x32',
'llvm.x86.avx512.pmaddubs.w.512',
tvm.const(0, 'uint32'),
vec_a, vec_b)
quad_reduction = tvm.call_llvm_intrin('int32x16',
'llvm.x86.avx512.pmaddw.d.512',
tvm.const(0, 'uint32'),
pair_reduction, vec_one)
if index == 0:
ib.emit(outs[0].vstore(0, quad_reduction))
else:
ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
return ib.get()
# body, reset, update
return _instr(0), _instr(1), _instr(2)
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
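# A numpy cross-check of the intrinsic's semantics (an illustrative sketch,
# not part of the library):
#
#   import numpy as np
#   data = np.random.randint(0, 128, 4).astype('uint8')
#   kernel = np.random.randint(-128, 128, (16, 4)).astype('int8')
#   expected = (data.astype('int32') * kernel.astype('int32')).sum(axis=1)
#   # 'expected' is the int32[16] result the tensorized schedule produces.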
#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
""" Conv Int8 functional and performance testing"""
import sys
import logging
import numpy as np
import tvm
import topi
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
LOGGER = logging.getLogger('test_conv_int8_intel')
LOGGER.disabled = False
# All the WORKLOADS from ResNet except the first layer
# Workload is ['height', 'width', 'in_filter', 'out_filter',
# 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
(56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
(28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
(28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
(28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
(14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
(28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
(14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
(14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
(7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
(14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
(7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
]
TARGET_NAME = 'llvm -mcpu=skylake-avx512'
NUM_VEC_LANES = 16
CTX = tvm.context(TARGET_NAME, 0)
def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
hstride, wstride, out_dtype):
"""
Computes the shapes of all data structures
"""
data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
if out_dtype == 'int32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4,
NUM_VEC_LANES, 4, k_h, k_w)
elif out_dtype == 'float32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES, NUM_VEC_LANES)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES,
NUM_VEC_LANES, k_h, k_w)
out_height = (im_height + 2 * hpad - k_h) // hstride + 1
out_width = (im_width + 2 * wpad - k_w) // wstride + 1
o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
return (data_shape, kernel_shape, o_shape)
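# Example (using WORKLOADS[0] below): get_shape(56, 56, 64, 64, 3, 3, 1, 1, 1,
# 1, 'int32') returns data (1, 4, 56, 56, 16), kernel (4, 4, 3, 3, 4, 16, 4)
# and output (1, 4, 56, 56, 16).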
def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
"""
Runs the inference and checks the functional correctness between
compute and schedule outputs
"""
(data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
out_filter, k_h, k_w, hpad, wpad,
hstride, wstride, out_dtype)
# Create TVM placeholders
data = tvm.placeholder(data_shape, name='data', dtype=data_dtype)
kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
# Create the numpy arrays to be used for executing conv models
if data_dtype == 'float32':
data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
else:
data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype))
kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype))
# c_orig will be used for the declaration output
# c_sch will be used for scheduled computation output
c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
with tvm.target.create(TARGET_NAME):
conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter,
kernel_size=(k_h, k_w), stride=hstride,
padding=hpad, layout='NCHWc',
out_layout='NCHWc', out_dtype=out_dtype)
out = topi.nn.relu(conv)
sch = tvm.create_schedule(out.op)
func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
func(data_array, kernel_array, c_orig)
LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
# Generate and run the optimized schedule
sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter,
kernel_size=(k_h, k_w),
strides=hstride,
padding=hpad,
layout='NCHWc',
out_layout='NCHWc',
outs=[out])
func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
func(data_array, kernel_array, c_sch)
# Functional check
if data_dtype == 'uint8':
np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
else:
assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
evaluator = func.time_evaluator(func.entry_name, CTX, number=1000)
LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
return evaluator(data_array, kernel_array, c_sch).mean
if __name__ == "__main__":
LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
SPEEDUP_ARRAY = []
for i, wkl in enumerate(WORKLOADS):
fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
int8_time = run_inference('uint8', 'int8', 'int32', *wkl)
kernel_h = wkl[4]
kernel_w = wkl[5]
LOGGER.info("Workload#" + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
+ str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
SPEEDUP_ARRAY.append(fp32_time/int8_time)
LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))