Commit 36d3a41e by Meghan Cowan, committed by Tianqi Chen

[TOPI] Bitserial low-precision convolution (#1332)

parent e806cd15
@@ -154,6 +154,31 @@ def call_extern(dtype, func_name, *args):
dtype, func_name, convert(args), _Call.Extern, None, 0)
def call_llvm_intrin(dtype, name, *args):
"""Build expression by calling an llvm intrinsic function
Parameters
----------
dtype : str
The data type of the result.
name : str
The name of the llvm intrinsic function.
args : list
Positional arguments.
Returns
-------
call : Expr
The call expression.
"""
import tvm
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
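# A minimal usage sketch, mirroring the llvm.ctpop test added later in this commit
# (the leading uint32 constant is simply passed through to codegen with the operands):
#     ib = tvm.ir_builder.create()
#     A = ib.pointer("uint8x8", name="A")
#     ib.emit(tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A))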
def exp(x):
"""Take exponetial of input x.
......
@@ -282,6 +282,15 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::shared_ptr<llvm::LLVMContext> ctx_;
};
unsigned LookupLLVMIntrinsic(const std::string& name) {
return llvm::Function::lookupIntrinsicID(name);
}
TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
});
TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
......
@@ -17,6 +17,16 @@ def test_llvm_intrin():
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
def test_llvm_lookup_intrin():
ib = tvm.ir_builder.create()
m = tvm.var("m")
A = ib.pointer("uint8x8", name="A")
x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm")
def test_llvm_add_pipeline():
nn = 1024
n = tvm.convert(nn)
@@ -324,3 +334,4 @@ if __name__ == "__main__":
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
test_llvm_lookup_intrin()
@@ -143,6 +143,41 @@ def schedule_depthwise_conv2d_nhwc(outs):
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_conv2d_nhwc(outs):
"""Schedule for bitserial_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
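# A minimal sketch of how a backend overrides these generic fallbacks, mirroring
# the x86 registration added later in this commit:
#     @generic.schedule_bitserial_conv2d_nhwc.register(["cpu"])
#     def schedule_bitserial_conv2d(outs):
#         s = tvm.create_schedule([x.op for x in outs])
#         ...
#         return s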
@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
......
@@ -16,4 +16,5 @@ from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
from .local_response_norm import *
from .bitserial_conv2d import *
from .l2_normalize import *
# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
"""Bitserial Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import numpy as np
import tvm
from topi.transform import concatenate
from .pad import pad
from .util import get_pad_tuple
from ..util import get_const_tuple, get_const_int
# workload description of bitserial conv2d
Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
SpatialPackNCHW = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc'])
SpatialPackNHWC = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc'])
_WORKLOADS = [
# workloads of resnet18 on imagenet
# input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride
Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('uint32', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workload of alexnet on cifar10
Workload('int32', 'int32', 27, 27, 96, 192, 5, 5, 2, 2, 1, 1),
Workload('int32', 'int32', 13, 13, 192, 384, 3, 3, 1, 1, 1, 1),
Workload('int32', 'int32', 13, 13, 384, 384, 3, 3, 1, 1, 1, 1),
Workload('int32', 'int32', 13, 13, 384, 256, 3, 3, 1, 1, 1, 1),
]
@tvm.target.generic_func
def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype='uint32', out_dtype='int32', dorefa=True):
"""Bitserial Conv2D operator.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width] or
[batch, in_height, in_width, in_channel]
kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
[filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
layout : str
layout of the input data, 'NCHW' or 'NHWC'
activation_bits : int
number of bits used for activation/input elements
weight_bits : int
number of bits used for weight elements
out_dtype : str
return type of convolution
pack_dtype : str
dtype used for bit packing
dorefa : bool
perform the bitserial dot-product using two popcounts (required for DoReFa-Net)
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] or
[batch, out_height, out_width, out_channel]
"""
# default declaration, used when no target-specific declaration is registered
if layout == 'NCHW':
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
elif layout == 'NHWC':
return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding, out_dtype, layout):
""" Get the workload structure. """
assert layout == "NCHW" or layout == "NHWC", \
"Only support layouts NCHW and NHWC"
if layout == "NCHW":
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
else: # NHWC
IH, IW = data.shape[1].value, data.shape[2].value
KH, KW, CI, CO = [x for x in get_const_tuple(kernel.shape)]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_schedule(wkl, layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress a pylint warning
return wkl
def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype, out_dtype, dorefa=False):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only supports batch size=1"
data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
IB, _, CI, H, W = data_q.shape
KB, CO, _, KH, KW = kernel_q.shape
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NCHW")
sch = _get_schedule(wkl, "NCHW")
VH = sch.vh
VW = sch.vw
VC = sch.vc
TH = H + 2*HPAD
TW = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
dshape = (IB, 1, CI, H, W)
dpshape = (IB, 1, CI, TH, TW)
dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
kshape = (KB, CO, CI, KH, KW)
kvshape = (CO//VC, CI, KH, KW, KB, VC)
ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
oshape = (1, CO, OH, OW)
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data_q, (0, 0, 0, HPAD, WPAD), name="data_pad")
else:
data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw')
b1 = tvm.reduce_axis((0, IB), name='ib')
b2 = tvm.reduce_axis((0, KB), name='kb')
def _conv(n, co, h, w, vh, vw, vc):
b1b2 = (b1+b2).astype(out_dtype)
if dorefa:
return tvm.sum((tvm.popcount(
data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) -
tvm.popcount(
data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype)
& ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2,
axis=[ci, dh, dw, b1, b2])
return tvm.sum((tvm.popcount(
data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
axis=[ci, dh, dw, b1, b2])
conv = tvm.compute(ovshape, _conv, name='conv_out')
return tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='conv_vec', tag='spatial_bitserial_conv_nchw')
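# A minimal numeric sketch of the identity _conv relies on when dorefa=True,
# assuming a single pair of bit planes (weight bit 1 decodes to +1, 0 to -1);
# the << (b1 + b2) shift then weights each plane pair by its bit significance:
#     a, w = 0b1101, 0b1011                   # activation bits / weight bits
#     pos = bin(a & w).count('1')             # popcount(a &  w) == 2
#     neg = bin(a & ~w & 0b1111).count('1')   # popcount(a & ~w) == 1
#     pos - neg == 1 == dot([1, 1, 0, 1], [+1, -1, +1, +1])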
def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype, out_dtype, dorefa=False):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only supports batch size=1"
data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
_, H, W, CI, IB = data_q.shape
KH, KW, _, CO, KB = kernel_q.shape
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
PAD_H = H + 2*HPAD
PAD_W = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
kvshape = (CO, KH, KW, CI, VC, KB)
ovshape = (1, OH, OW, CO, VH, VW, VC)
oshape = (1, OH, OW, CO)
if (HPAD != 0 or WPAD != 0):
data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad")
else:
data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \
data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda co, dh, dw, ci, vc, b: \
kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw')
b1 = tvm.reduce_axis((0, IB), name='ib')
b2 = tvm.reduce_axis((0, KB), name='kb')
def _conv(n, h, w, co, vh, vw, vc):
b1b2 = (b1+b2).astype(out_dtype)
if dorefa:
return tvm.sum(
(tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) -
tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2,
axis=[dh, dw, ci, b1, b2])
return tvm.sum(tvm.popcount(
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
axis=[dh, dw, ci, b1, b2])
conv = tvm.compute(ovshape, _conv, name='conv')
return tvm.compute(oshape, lambda n, h, w, co:
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
"""Packs data into format necessary for bitserial computation
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data"""
ishape = data.shape
n = len(ishape)
if pack_type == 'uint8':
data_width = 8
elif pack_type == 'uint16':
data_width = 16
elif pack_type == 'uint32':
data_width = 32
elif pack_type == 'uint64':
data_width = 64
else:
raise ValueError("Unsupported pack_type: {}".format(pack_type))
# Data must be in multiples of the data_width
assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"
shape_vec = list(ishape)
shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
shape_vec.insert(bit_axis, 1)
bitserial_oshape = tuple(shape_vec)
masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
# the pack axis index shifts by one if the bit axis is inserted before it
if bit_axis <= pack_axis:
pack_axis += 1
def _bitpack(*indices):
packed_data = [tvm.const(0, pack_type)] * bits
for k in range(data_width):
# Translate indices for packed data back to original
idx = [0] * n
j = 0
for i in range(n+1):
if i == bit_axis:
continue
elif i == pack_axis:
idx[j] = indices[i] * data_width + k
else:
idx[j] = indices[i]
j += 1
element = data(*idx)
for b in range(bits):
extracted_bit = ((element & tvm.const(masks[b])) >> b).astype(pack_type)
packed_data[b] = (packed_data[b] | extracted_bit)
if k < data_width - 1:
packed_data[b] = packed_data[b] << 1
if k == data_width - 1:
return tuple(packed_data)
return tuple(packed_data)
output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')
if bits > 1:
return concatenate(output_tuple, axis=bit_axis)
return output_tuple
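# A minimal layout sketch for bitpack (assumed shapes, matching how spatial_pack_nchw
# calls it): packing 2-bit NCHW activations along the channel axis into uint32 words
# yields one leading plane per bit and in_channel // 32 packed words per plane.
#     data = tvm.placeholder((1, 64, 56, 56), dtype='uint32', name='data')
#     data_q = bitpack(data, 2, pack_axis=1, bit_axis=0, pack_type='uint32')
#     # get_const_tuple(data_q.shape) == (2, 1, 2, 56, 56)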
_SCH_TO_DECL_FUNC_QUANT = {
SpatialPackNCHW: spatial_pack_nchw,
SpatialPackNHWC: spatial_pack_nhwc,
}
@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .conv2d import schedule_conv2d_nchw
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw
from .bitserial_conv2d import schedule_bitserial_conv2d_nhwc
# pylint: disable=invalid-name,unused-variable
"""Bitserial conv2d schedule on raspberry pi"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .. import tag
from ..nn.pad import pad
from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload, bitpack
from ..nn.bitserial_conv2d import SpatialPackNCHW, _WORKLOADS, spatial_pack_nchw
from ..nn.util import get_pad_tuple
from ..util import get_const_int
from .. import generic
RaspSpatialPack = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc', 'split_ci', 'kfactor'])
_QUANTIZED_SCHEDULES_NHWC = [
RaspSpatialPack(2, 2, 8, 1, 1, False, 8),
RaspSpatialPack(1, 4, 8, 4, 1, False, 8),
RaspSpatialPack(1, 4, 8, 1, 16, False, 8),
RaspSpatialPack(1, 4, 8, 4, 8, False, 8),
RaspSpatialPack(1, 7, 8, 3, 8, False, 16),
RaspSpatialPack(1, 2, 8, 1, 8, False, 16),
RaspSpatialPack(2, 1, 8, 1, 4, False, 16),
RaspSpatialPack(1, 7, 8, 1, 1, True, 16),
RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
RaspSpatialPack(1, 1, 8, 1, 8, True, 16),
RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
]
_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),
]
@_get_schedule.register("rasp")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch
@bitserial_conv2d.register("rasp")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only supports batch size=1 convolution on rasp"
assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC"
if dorefa:
assert layout == "NCHW", "Cannot support dorea with NHWC layout yet"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
if layout == "NCHW":
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
return _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits,
weight_bits, out_dtype)
def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC):
kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
KH, KW, KB, CI, CO = kernel_q.shape
kvshape = (CO//VC, KH, KW, KB, VC, CI)
return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \
kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec')
def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, out_dtype):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only supports batch size=1"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC)
N, H, W, IB, CI = data_q.shape
OCO, KH, KW, KB, VC, _ = kernel_vec.shape
CO = OCO * VC
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1
PAD_H = H + 2*HPAD
PAD_W = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI)
ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC)
oshape = (1, OH, OW, CO)
if (HPAD != 0 or WPAD != 0):
data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad")
else:
data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \
data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw')
ib = tvm.reduce_axis((0, IB), name='ib')
kb = tvm.reduce_axis((0, KB), name='kb')
def _conv(n, h, w, co, vh, vw, vc):
return tvm.sum((tvm.popcount(
kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16'))
<< (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci])
conv = tvm.compute(ovshape, _conv, name='conv')
return tvm.compute(oshape, lambda n, h, w, co:
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
name='output_vec', tag='spatial_bitserial_conv_nhwc')
def _intrin_popcount(m, k_i, w_b, x_b):
dtype = 'uint8'
w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w')
x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x')
k = tvm.reduce_axis((0, k_i), name='k')
bw = tvm.reduce_axis((0, w_b), name='bw')
bx = tvm.reduce_axis((0, x_b), name='bx')
z = tvm.compute((m,), lambda i:
tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') &
x[bx, k].astype('uint16'))
<< (bw+bx).astype('uint16'), axis=[bw, bx, k]), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=k_i,
strides=[tvm.var('ldw'), tvm.var('ldw'), 1])
Xb = tvm.decl_buffer(x.shape, x.dtype,
name="X",
offset_factor=k_i,
strides=[tvm.var('ldw'), 1])
def _intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
vpadd = "llvm.arm.neon.vpadd.v8u8"
vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
args_1 = tvm.const(1, 'uint32')
args_2 = tvm.const(2, 'uint32')
def _instr(index):
irb = tvm.ir_builder.create()
if index == 1:
irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
return irb.get()
cnts8 = [None] * 8
cnts4 = [None] * 4
cnts2 = [None] * 2
for bw in range(w_b):
for bx in range(x_b):
if k_i == 16:
for i in range(m):
ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
cnts = tvm.popcount(ands)
upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
cnts8[i] = upper_half + lower_half
for i in range(m//2):
cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
args_1, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
args_1, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, dtype)
out = tvm.call_llvm_intrin('uint16x8', vpadalu,
args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
else: # ki == 8
for i in range(m):
ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
cnts8[i] = tvm.popcount(ands)
for i in range(m//2):
cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
args_1, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
args_1, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, dtype)
out = tvm.call_llvm_intrin('uint16x8', vpadalu,
args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
irb.emit(zz.vstore(0, out))
return irb.get()
# body, reset, update
return _instr(0), _instr(1), _instr(2)
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb})
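# A minimal sketch of how this declaration is meant to be used (mirroring the
# schedule below): the innermost kb/ib/vc/ci loops of conv_out are replaced by
# the NEON popcount microkernel, producing 8 output channels per call.
#     pc = _intrin_popcount(8, kfactor, KB, IB)
#     s[conv_out].tensorize(kb, pc)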
# ARM-specific schedule that uses a custom microkernel
def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
# no stride and padding info here
_, H, W, IB, CI = data_q.shape
KH, KW, KB, _, CO = kernel_q.shape
KB = get_const_int(KB)
IB = get_const_int(IB)
if data_pad is None:
padding = (0, 0)
_, in_h, in_w, _, _ = data_q.shape
kern_h, kern_w, _, _ = kernel.shape
_, out_h, out_w, _ = output.shape
hstride = (in_h - kern_h) // (out_h - 1)
wstride = (in_w - kern_w) // (out_w - 1)
stride = get_const_int(hstride), get_const_int(wstride)
else:
_, in_h, in_w, _, _ = data_q.shape
_, pad_h, pad_w, _, _ = data_pad.shape
hpad = (pad_h - in_h) // 2
wpad = (pad_w - in_w) // 2
padding = get_const_int(hpad), get_const_int(wpad)
_, in_h, in_w, _, _ = data_pad.shape
kern_h, kern_w, _, _ = kernel.shape
_, out_h, out_w, _ = output.shape
hstride = (in_h - kern_h) // (out_h - 1)
wstride = (in_w - kern_w) // (out_w - 1)
stride = get_const_int(hstride), get_const_int(wstride)
wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
##### Schedule data packing
if data_pad is not None:
s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, kb, ib, ci = s[conv_out].op.reduce_axis
kfactor = sch.kfactor
if sch.split_ci:
oci, ici = s[conv_out].split(ci, kfactor)
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, oci, kb, ib, vc, ici)
else:
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, kb, ib, vc, ci)
pc = _intrin_popcount(8, kfactor, KB, IB)
s[conv_out].tensorize(kb, pc)
n, h, w, co = s[last].op.axis
co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
s[last].reorder(n, oh, ow, co, vc, vh, vw)
s[last].vectorize(vw)
if last != output:
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)
if bc == 1:
oaxis = oh
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho
s[last].parallel(paxis)
s = s.normalize()
return s
@generic.schedule_bitserial_conv2d_nhwc.register(["rasp"])
def schedule_bitserial_conv2d_nhwc(outs):
"""Raspverry pi schedule for bitserial conv2d"""
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'spatial_bitserial_conv_nhwc' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[0]
kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[1]
data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0]
data_pad = None
if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag:
data_pad = data_q
data_q = data
data = data_q.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0]
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, conv_out, output, outs[0])
traverse(outs[0].op)
return s
@@ -8,3 +8,4 @@ from .binary_dense import schedule_binary_dense
from .nn import *
from .injective import *
from .pooling import schedule_pool, schedule_global_pool
from .bitserial_conv2d import schedule_bitserial_conv2d
# pylint: disable=invalid-name,unused-variable
"""Bitserial conv2d schedule on x86"""
import tvm
from topi.util import get_const_int
from .. import generic, tag
from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload
from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC
from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT
_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(3, 3, 16, 3, 16),
SpatialPackNCHW(1, 1, 16, 2, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
]
_QUANTIZED_SCHEDULES_NHWC = [
# resnet
SpatialPackNHWC(2, 2, 8, 1, 1),
SpatialPackNHWC(1, 4, 8, 4, 1),
SpatialPackNHWC(1, 4, 8, 1, 16),
SpatialPackNHWC(1, 4, 8, 4, 8),
SpatialPackNHWC(1, 7, 8, 3, 8),
SpatialPackNHWC(1, 2, 8, 1, 8),
SpatialPackNHWC(2, 1, 8, 1, 4),
SpatialPackNHWC(1, 7, 8, 1, 1),
SpatialPackNHWC(1, 1, 8, 1, 16),
SpatialPackNHWC(1, 1, 8, 1, 8),
SpatialPackNHWC(1, 1, 8, 1, 16),
]
@_get_schedule.register("cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch
@bitserial_conv2d.register("cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only supports batch size=1 convolution on x86"
assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits,
weight_bits, pack_dtype, out_dtype, dorefa)
@generic.schedule_bitserial_conv2d_nchw.register(["cpu"])
@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"])
def schedule_bitserial_conv2d(outs):
"""CPU schedule for bitserial convolutions NCHW and NHWC"""
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Traverse operators from computation graph"""
output = op.output(0)
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif 'spatial_bitserial_conv_nchw' in op.tag or 'spatial_bitserial_conv_nhwc' in op.tag:
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0]
data_pad = None
if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag:
data_pad = data_q
data_q = data
data = data_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0]
if 'spatial_bitserial_conv_nchw' in op.tag:
_schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
elif 'spatial_bitserial_conv_nhwc' in op.tag:
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
traverse(outs[0].op)
return s
def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
IB, _, CI, IH, IW = data_q.shape
KB, CO, _, KH, KW = kernel_q.shape
_, _, OH, OW = output.shape
# Infer padding and stride
if data_pad is None:
padding = (0, 0)
TH, TW = IH, IW
else:
_, _, _, TH, TW = data_pad.shape
hpad = get_const_int((TH - IH) // 2)
wpad = get_const_int((TW - IW) // 2)
padding = (hpad, wpad)
hstride = get_const_int((TH - KH) // (OH - 1))
wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride)
wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NCHW")
sch = _get_schedule(wkl, "NCHW")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
CC = s.cache_write(conv_out, "global")
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
s[conv_out].vectorize(vc)
s[CC].compute_at(s[conv_out], ow)
n, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw, b1, b2 = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, b1, b2, vc)
s[CC].unroll(b1)
s[CC].unroll(b2)
s[CC].vectorize(vc)
##### Schedule A
if data_pad is not None:
s[data_pad].compute_inline()
_, h, _, _, _, _, vw = s[data_vec].op.axis
s[data_vec].vectorize(vw)
if ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
co, _, _, _, _, vc = s[kernel_vec].op.axis
s[kernel_vec].vectorize(vc)
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule C
n, co, h, w = s[last].op.axis
co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
s[last].reorder(n, co, oh, ow, vh, vw, vc)
if last != output:
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[last].split(co, bc)
oaxis = oco
paxis = ico
s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
return s
def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
# no stride and padding info here
_, IH, IW, CI, IB = data_q.shape
KH, KW, _, CO, KB = kernel_q.shape
_, OH, OW, _ = output.shape
# Infer padding and stride
if data_pad is None:
padding = (0, 0)
TH, TW = IH, IW
else:
_, TH, TW, _, _ = data_pad.shape
hpad = get_const_int((TH - IH) // 2)
wpad = get_const_int((TW - IW) // 2)
padding = (hpad, wpad)
hstride = get_const_int((TH - KH) // (OH - 1))
wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride)
wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
##### Schedule data packing
if data_pad is not None:
s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
s[conv_out].unroll(b1)
s[conv_out].unroll(b2)
s[conv_out].vectorize(vc)
##### Schedule output
n, h, w, co = s[last].op.axis
co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
s[last].reorder(n, oh, ow, co, vh, vw, vc)
s[last].vectorize(vc)
if last != output:
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)
if bc == 1:
oaxis = oh
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho
s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
return s
import os
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from tvm.contrib import util
def generate_quantized_np(shape, bits, out_dtype):
min_val = 0
max_val = 1 << bits
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
if dorefa:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
b_np = topi.testing.conv2d_nchw_python(a_np.astype(out_dtype), w_, stride, padding)
else:
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], "llvm")
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
if dorefa:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
else:
b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], 'llvm')
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_bitserial_conv2d():
in_size = 56
ic, oc = 64, 64
k = 3
stride = 1
pad = 1
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
if __name__ == "__main__":
test_bitserial_conv2d()
\ No newline at end of file
import os
import re
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib import util
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0)
min_val = 0
max_val = 1 << bits
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
in_height = in_width = in_size
input_type='uint32'
out_dtype='int32'
with tvm.target.rasp():
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], target)
assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly)
assert (len(matches) > 0)
matches = re.findall("vcnt", assembly)
assert (len(matches) > 0)
matches = re.findall("vpadd", assembly)
assert (len(matches) > 0)
def test_bitserial_conv2d():
in_size = 56
ic, oc = 64, 64
k = 3
stride = 1
pad = 1
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
if __name__ == "__main__":
test_bitserial_conv2d()