Commit 17351875 by Meghan Cowan, committed by Tianqi Chen

[TOPI] bitserial_conv2d move to autotvm template and updates (#2819)

parent cefe07e2
@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
        workload = tuple([args_to_workload(a) for a in x])
    elif isinstance(x, (str, int, float, np.int, np.float)):
        workload = x
-   elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
+   elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
        workload = x.value
    elif x is None:
        workload = 0
@@ -68,6 +68,8 @@ class TaskExtractEnv:
            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
            topi.nn.dense: "topi_nn_dense",
+           topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
+           topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
            topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
        }

@@ -79,6 +81,8 @@ class TaskExtractEnv:
            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
            topi.nn.dense: [topi.generic.schedule_dense],
+           topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
+           topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
            topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
        }
@@ -174,6 +178,24 @@ class TaskExtractEnv:
                return s, [data, weight, bias, C]
            return s, [data, weight, C]

+       @register("topi_nn_bitserial_conv2d_nhwc")
+       def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
+           args = deserialize_args(args)
+           C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
+           s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
+           data = args[0]
+           kernel = args[1]
+           return s, [data, kernel, C]
+
+       @register("topi_nn_bitserial_conv2d_nchw")
+       def _topi_bitserial_conv2d_nchw(*args, **kwargs):
+           args = deserialize_args(args)
+           C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
+           s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
+           data = args[0]
+           kernel = args[1]
+           return s, [data, kernel, C]
+
        @register("topi_nn_deformable_conv2d_nchw")
        def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
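These two wrappers rebuild the compute and schedule from the serialized task arguments during tuning; outside of tuning, the same pair is reached through the normal TOPI dispatch. The following is a minimal sketch, not part of the diff: shapes and bit widths are made up, and tvm.placeholder / tvm.target.create / tvm.lower are assumptions about the TVM Python API of this era.

# Illustrative only: exercise the NHWC bitserial template the same way the
# autotvm wrapper above does.  Without a tuning log, autotvm falls back to a
# default configuration for the 'direct' template.
import tvm
import topi

data = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='data')     # NHWC, batch 1
kernel = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='kernel')  # HWIO

with tvm.target.create('llvm -device=arm_cpu'):
    out = topi.nn.bitserial_conv2d_nhwc(data, kernel, 1, 1, 2, 1,
                                        'uint8', 'int16', True)
    s = topi.generic.schedule_bitserial_conv2d_nhwc([out])
    print(tvm.lower(s, [data, kernel, out], simple_mode=True))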
# pylint: disable=invalid-name,unused-variable,invalid-name
-"""Bitserial conv2d schedule on raspberry pi"""
+"""Bitserial conv2d schedule on arm cpu"""
from __future__ import absolute_import as _abs
-from collections import namedtuple
import tvm
+from tvm import autotvm
from .. import tag
from ..nn.pad import pad
-from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload, bitpack
-from ..nn.bitserial_conv2d import SpatialPackNCHW, _WORKLOADS, spatial_pack_nchw
+from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc
from ..nn.util import get_pad_tuple
-from ..util import get_const_int
+from ..util import get_const_int, get_const_tuple
from .. import generic
-RaspSpatialPack = namedtuple('SpatialPack',
-                             ['vh', 'vw', 'vc', 'ba', 'bc', 'split_ci', 'kfactor'])
+def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True):
+    if use_bitpack:
_QUANTIZED_SCHEDULES_NHWC = [
RaspSpatialPack(2, 2, 8, 1, 1, False, 8),
RaspSpatialPack(1, 4, 8, 4, 1, False, 8),
RaspSpatialPack(1, 4, 8, 1, 16, False, 8),
RaspSpatialPack(1, 4, 8, 4, 8, False, 8),
RaspSpatialPack(1, 7, 8, 3, 8, False, 16),
RaspSpatialPack(1, 2, 8, 1, 8, False, 16),
RaspSpatialPack(2, 1, 8, 1, 4, False, 16),
RaspSpatialPack(1, 7, 8, 1, 1, True, 16),
RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
RaspSpatialPack(1, 1, 8, 1, 8, True, 16),
RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
]
_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),
]
@_get_schedule.register("arm_cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch
@bitserial_conv2d.register("arm_cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
if dorefa:
assert layout == "NCHW", "Cannot support dorea with NHWC layout yet"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
if layout == "NCHW":
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
return _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits,
weight_bits, out_dtype)
-def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC):
-    kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
+        kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
+    else:
+        kernel_q = kernel
    KH, KW, KB, CI, CO = kernel_q.shape
    kvshape = (CO//VC, KH, KW, KB, VC, CI)
    return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \
                       kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec')

-def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, out_dtype):
+@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
+def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits,
+                      pack_dtype, out_dtype, unipolar):
    """ Compute convolution with pack on spatial axes. """
    assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
-    sch = _get_schedule(wkl, "NHWC")
-    VH = sch.vh
-    VW = sch.vw
-    VC = sch.vc
-    data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
-    kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC)
-    N, H, W, IB, CI = data_q.shape
-    OCO, KH, KW, KB, VC, _ = kernel_vec.shape
-    CO = OCO * VC
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    assert pack_dtype == 'uint8', "only support packing into uint8 bits"
+    assert out_dtype == 'int16', "only support output type of int16"
+
+    N, H, W, CI = get_const_tuple(data.shape)
+    if len(kernel.shape) == 4:
+        KH, KW, _, CO = get_const_tuple(kernel.shape)
+        CI_packed = CI // 8
+    else:
+        KH, KW, KB, CI_packed, CO = get_const_tuple(kernel.shape)
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
@@ -102,75 +46,151 @@ def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bi
        HSTR, WSTR = stride, stride
    HCAT, WCAT = KH-1, KW-1

-    PAD_H = H + 2*HPAD
-    PAD_W = W + 2*WPAD
-    OH = (H + 2*HPAD - KH) // HSTR + 1
-    OW = (W + 2*WPAD - KW) // WSTR + 1
+    PAD_H = H + (TPAD + DPAD)
+    PAD_W = W + (LPAD + RPAD)
+    OH = (PAD_H - KH) // HSTR + 1
+    OW = (PAD_W - KW) // WSTR + 1
oshape = (1, OH, OW, CO)
# Pad input channels of weights and data when it is not a multiple of 8
if CI_packed % 8 != 0:
CI_PAD = CI_packed % 8
CI_packed += CI_PAD
else:
CI_PAD = 0
# ==================== define configuration space ====================
n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
filter=lambda x: x.size[-1] == 8)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
re_axes = cfg.define_reorder("reorder_0",
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
policy='candidate', candidate=[
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
[n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# ====================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])
N, H, W, IB, CI = data_q.shape
OCO, KH, KW, KB, VC, CI = kernel_vec.shape
dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI) dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI)
ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC) ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC)
oshape = (1, OH, OW, CO)
if (HPAD != 0 and WPAD != 0): if (TPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad") data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, CI_PAD), name="data_pad")
elif CI_PAD != 0:
data_pad = pad(data_q, (0, 0, 0, 0, 0), (0, 0, 0, 0, CI_PAD), name="data_pad")
else: else:
data_pad = data_q data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \ data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \
data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec') data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec')
ci = tvm.reduce_axis((0, CI), name='ci') ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh') dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw') dw = tvm.reduce_axis((0, KW), name='dw')
ib = tvm.reduce_axis((0, IB), name='ib') ib = tvm.reduce_axis((0, IB), name='ib')
kb = tvm.reduce_axis((0, KB), name='kb') kb = tvm.reduce_axis((0, KB), name='kb')
def _conv(n, h, w, co, vh, vw, vc): def _bipolar_conv(n, h, w, co, vh, vw, vc):
return tvm.sum((tvm.popcount( return tvm.sum((tvm.popcount(
kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') & kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16')) data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16'))
<< (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci]) << (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci])
def _unipolar_conv(n, h, w, co, vh, vw, vc):
return tvm.sum(
((tvm.popcount(kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')) -
tvm.popcount(~kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci]).astype('int16'))
<< (kb + ib).astype('int16')), axis=[dh, dw, kb, ib, ci])
if unipolar:
conv_vec = tvm.compute(ovshape, _unipolar_conv, name='conv_vec', tag='unipolar')
else:
conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')
-    conv = tvm.compute(ovshape, _conv, name='conv')
-
-    return tvm.compute(oshape, lambda n, h, w, co:
-                       conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
-                       name='output_vec', tag='spatial_bitserial_conv_nhwc')
+    conv = tvm.compute(oshape, lambda n, h, w, co:
+                       conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
+                       name='conv', tag='spatial_bitserial_conv_nhwc')
+    return conv
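The only difference between _unipolar_conv and _bipolar_conv above is the second popcount: subtracting popcount(x & ~w) makes a cleared weight bit contribute -1 instead of 0 to the dot product. A quick pure-Python check of that identity (illustrative, not part of the commit):

# For activation bits x_c in {0,1} and weight bits w_c in {0,1}:
#   popcount(x & w) - popcount(x & ~w) == sum_c x_c * (2*w_c - 1)
import random

def popcount(v):
    return bin(v).count('1')

random.seed(0)
for _ in range(1000):
    x = random.getrandbits(8)                      # one packed byte of activation bits
    w = random.getrandbits(8)                      # one packed byte of weight bits
    signed = popcount(x & w) - popcount(x & ~w & 0xFF)
    ref = sum(((x >> c) & 1) * (2 * ((w >> c) & 1) - 1) for c in range(8))
    assert signed == ref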
def _intrin_popcount(m, k_i, w_b, x_b): def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
dtype = 'uint8' pack_dtype = 'uint8'
w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w') w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x') x = tvm.placeholder((x_b, k_i,), dtype=pack_dtype, name='x')
k = tvm.reduce_axis((0, k_i), name='k') k = tvm.reduce_axis((0, k_i), name='k')
bw = tvm.reduce_axis((0, w_b), name='bw') bw = tvm.reduce_axis((0, w_b), name='bw')
bx = tvm.reduce_axis((0, x_b), name='bx') bx = tvm.reduce_axis((0, x_b), name='bx')
if unipolar:
dtype = 'int16'
z = tvm.compute((m,), lambda i: z = tvm.compute((m,), lambda i:
tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') & tvm.sum((tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) -
x[bx, k].astype('uint16')) tvm.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)))
<< (bw+bx).astype('uint16'), axis=[bw, bx, k]), name='z') << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
else:
dtype = 'uint16'
z = tvm.compute((m,), lambda i:
tvm.sum(tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
<< (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype, Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W", name="W",
offset_factor=k_i, offset_factor=k_i,
strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) # stride can be inferred
Xb = tvm.decl_buffer(x.shape, x.dtype, Xb = tvm.decl_buffer(x.shape, x.dtype,
name="X", name="X",
offset_factor=k_i, offset_factor=k_i,
strides=[tvm.var('ldw'), 1]) strides=[tvm.var('ldw'), 1])
Zb = tvm.decl_buffer(z.shape, z.dtype,
name="Z",
offset_factor=1,
strides=[1])
def _intrin_func(ins, outs): def _intrin_func(ins, outs):
ww, xx = ins ww, xx = ins
zz = outs[0] zz = outs[0]
vpadd = "llvm.arm.neon.vpadd.v8u8"
vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
args_1 = tvm.const(1, 'uint32') args_1 = tvm.const(1, 'uint32')
args_2 = tvm.const(2, 'uint32') args_2 = tvm.const(2, 'uint32')
if unipolar:
vpadd = "llvm.arm.neon.vpadd.v8i8"
vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16"
full_dtype = 'int8x16'
half_dtype = 'int8x8'
return_dtype = 'int16x8'
else:
vpadd = "llvm.arm.neon.vpadd.v8u8"
vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
full_dtype = 'uint8x16'
half_dtype = 'uint8x8'
return_dtype = 'uint16x8'
def _instr(index): def _instr(index):
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
if index == 1: if index == 1: # reduce reset
irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8'))) irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
return irb.get() return irb.get()
# body and reduce update
cnts8 = [None] * 8 cnts8 = [None] * 8
cnts4 = [None] * 4 cnts4 = [None] * 4
cnts2 = [None] * 2 cnts2 = [None] * 2
...@@ -178,154 +198,108 @@ def _intrin_popcount(m, k_i, w_b, x_b): ...@@ -178,154 +198,108 @@ def _intrin_popcount(m, k_i, w_b, x_b):
for bx in range(x_b): for bx in range(x_b):
if k_i == 16: if k_i == 16:
for i in range(m): for i in range(m):
ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16') w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
cnts = tvm.popcount(ands) x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts) if unipolar:
lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts) cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
else:
cnts = tvm.popcount(w_ & x_)
upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
cnts8[i] = upper_half + lower_half cnts8[i] = upper_half + lower_half
for i in range(m//2): for i in range(m//2):
cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts8[i*2], cnts8[i*2+1]) args_1, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4): for i in range(m//4):
cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts4[i*2], cnts4[i*2+1]) args_1, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1]) cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, dtype) shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
out = tvm.call_llvm_intrin('uint16x8', vpadalu, out = tvm.call_llvm_intrin(return_dtype, vpadalu,
args_2, zz.vload(0, 'uint16x8'), shifted_cnts) args_2, zz.vload(0, return_dtype), shifted_cnts)
else: # ki == 8 else: # ki == 8
for i in range(m): for i in range(m):
ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8') w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
cnts8[i] = tvm.popcount(ands) x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
if unipolar:
cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
else:
cnts8[i] = tvm.popcount(w_ & x_)
for i in range(m//2): for i in range(m//2):
cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts8[i*2], cnts8[i*2+1]) args_1, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4): for i in range(m//4):
cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts4[i*2], cnts4[i*2+1]) args_1, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1]) cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, dtype) shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
out = tvm.call_llvm_intrin('uint16x8', vpadalu, out = tvm.call_llvm_intrin(return_dtype, vpadalu,
args_2, zz.vload(0, 'uint16x8'), shifted_cnts) args_2, zz.vload(0, return_dtype), shifted_cnts)
irb.emit(zz.vstore(0, out)) irb.emit(zz.vstore(0, out))
            return irb.get()
        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
-        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb})
+        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb})
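For reference, the reduction tree that _intrin_popcount builds for m == 8 outputs can be modelled outside TVM: two rounds of pairwise vpadd fold the per-output popcount lanes together, and the final pairwise add-accumulate (the vpadalu call) collapses the combined 16-lane vector into the eight 16-bit accumulators. A NumPy sketch of just that lane arithmetic, with made-up counts (not the actual NEON code):

import numpy as np

def vpadd(a, b):
    # NEON vpadd: pairwise-add within each 8-lane input, then concatenate
    return np.concatenate([a.reshape(4, 2).sum(axis=1),
                           b.reshape(4, 2).sum(axis=1)])

def vpadal(acc, v16):
    # NEON vpadal: pairwise-add 16 narrow lanes, accumulate into 8 wide lanes
    return acc + v16.reshape(8, 2).sum(axis=1)

rng = np.random.RandomState(0)
cnts8 = [rng.randint(0, 9, 8) for _ in range(8)]       # per-output popcount partials (m = 8)
cnts4 = [vpadd(cnts8[2*i], cnts8[2*i+1]) for i in range(4)]
cnts2 = [vpadd(cnts4[2*i], cnts4[2*i+1]) for i in range(2)]
acc = vpadal(np.zeros(8, dtype=np.int64), np.concatenate([cnts2[0], cnts2[1]]))
assert np.array_equal(acc, [c.sum() for c in cnts8])   # lane i holds output i's total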
# ARM specific schedule that using custom microkernel # ARM specific schedule that using custom microkernel
def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
kernel, kernel_q, kernel_vec, conv_out, output, last, unipolar):
conv_out, output, last): _, _, _, _, _, IB, CI = data_vec.shape
# no stride and padding info here _, KH, KW, KB, _, _ = kernel_vec.shape
_, H, W, IB, CI = data_q.shape
KH, KW, KB, _, CO = kernel_q.shape
KB = get_const_int(KB) KB = get_const_int(KB)
IB = get_const_int(IB) IB = get_const_int(IB)
if data_pad is None: VC = cfg["tile_co"].size[-1]
padding = (0, 0) VH = cfg["tile_oh"].size[-1]
_, in_h, in_w, _, _ = data_q.shape VW = cfg["tile_ow"].size[-1]
kern_h, kern_w, _, _ = kernel.shape
_, out_h, out_w, _ = output.shape ##### Schedule data padding and packing
hstride = (in_h - kern_h) // (out_h - 1)
wstride = (in_w - kern_w) // (out_w - 1)
stride = get_const_int(hstride), get_const_int(wstride)
else:
_, in_h, in_w, _, _ = data_q.shape
_, pad_h, pad_w, _, _ = data_pad.shape
hpad = (pad_h - in_h) // 2
wpad = (pad_w - in_w) // 2
padding = get_const_int(hpad), get_const_int(wpad)
_, in_h, in_w, _, _ = data_pad.shape
kern_h, kern_w, _, _ = kernel.shape
_, out_h, out_w, _ = output.shape
hstride = (in_h - kern_h) // (out_h - 1)
wstride = (in_w - kern_w) // (out_w - 1)
stride = get_const_int(hstride), get_const_int(wstride)
wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
##### Schedule data packing
if data_pad is not None: if data_pad is not None:
s[data_pad].compute_inline() s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis _, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1: cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oaxis = h oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
paxis = h s[data_vec].parallel(oh)
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule kernel packing #### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1: cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oaxis = co oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
paxis = co s[kernel_vec].parallel(oco)
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule Convolution ##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, kb, ib, ci = s[conv_out].op.reduce_axis kh, kw, kb, ib, ci = s[conv_out].op.reduce_axis
kfactor = sch.kfactor ci_o, ci_i = cfg['tile_ci'].apply(s, conv_out, ci)
if sch.split_ci: re_axes = cfg["reorder_0"].apply(s, conv_out,
oci, ici = s[conv_out].split(ci, kfactor) [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i])
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, oci, kb, ib, vc, ici)
else:
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, kb, ib, vc, ci)
pc = _intrin_popcount(8, kfactor, KB, IB) # Use microkernel
kfactor = cfg['tile_ci'].size[1]
if kfactor % 8 == 0:
pc = _intrin_popcount(VC, kfactor, KB, IB, unipolar)
s[conv_out].tensorize(kb, pc) s[conv_out].tensorize(kb, pc)
n, h, w, co = s[last].op.axis n, h, w, co = s[last].op.axis
co, vc = s[last].split(co, VC) co, vc = cfg['tile_co'].apply(s, last, co)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW) oh, vh = cfg['tile_oh'].apply(s, last, h)
s[last].reorder(n, oh, ow, co, vc, vh, vw) ow, vw = cfg['tile_ow'].apply(s, last, w)
s[last].vectorize(vw) s[last].reorder(n, oh, ow, co, vh, vw, vc)
s[last].vectorize(vc)
if last != output: if last != output:
s[last].compute_inline() s[last].compute_inline()
s[conv_out].compute_at(s[last], ow) s[conv_out].compute_at(s[last], co)
if co == 1: s[last].parallel(oh)
oaxis = oh
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho
s[last].parallel(paxis)
    s = s.normalize()
    return s
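The knobs defined in spatial_pack_nhwc (tile_co, tile_oh, tile_ow, tile_ci, reorder_0) are what this schedule consumes: tile_co is filtered so its inner size is exactly 8, matching the 8 output lanes the tensorized popcount microkernel accumulates, and tile_ci keeps the packed-channel inner extent at 8 or 16 so the uint8x8/uint8x16 loads line up. A small illustration of how one such configuration tiles an output; the concrete numbers are made up and this is not TVM code:

OH, OW, CO = 56, 56, 64                  # output height / width / channels
VH, VW, VC = 2, 4, 8                     # tile_oh / tile_ow / tile_co inner sizes; VC == 8
outer = (OH // VH, OW // VW, CO // VC)   # loops carried outside the microkernel
inner = (VH, VW, VC)                     # register-tiled block computed per call
print(outer, inner)                      # (28, 14, 8) (2, 4, 8)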
-@generic.schedule_bitserial_conv2d_nhwc.register(["arm_cpu"])
-def schedule_bitserial_conv2d_nhwc(outs):
-    """Raspverry pi schedule for bitserial conv2d"""
+@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
+def schedule_bitserial_conv2d_nhwc(cfg, outs):
+    """Arm cpu schedule for bitserial conv2d"""
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []
...@@ -344,10 +318,6 @@ def schedule_bitserial_conv2d_nhwc(outs): ...@@ -344,10 +318,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
conv_out = op.input_tensors[0] conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[0] kernel_vec = conv_out.op.input_tensors[0]
kernel_q = kernel_vec.op.input_tensors[0] kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[1] data_vec = conv_out.op.input_tensors[1]
data_q = data_vec.op.input_tensors[0] data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0] data = data_q.op.input_tensors[0]
...@@ -355,13 +325,10 @@ def schedule_bitserial_conv2d_nhwc(outs): ...@@ -355,13 +325,10 @@ def schedule_bitserial_conv2d_nhwc(outs):
if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag: if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag:
data_pad = data_q data_pad = data_q
data_q = data data_q = data
data = data_q.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0] data = data.op.input_tensors[0]
unipolar = "unipolar" in conv_out.op.tag
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
kernel, kernel_q, kernel_vec, conv_out, output, outs[0]) conv_out, output, outs[0], unipolar)
        scheduled_ops.append(op)

    traverse(outs[0].op)
......
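Once the template has been tuned, the same declaration and schedule are re-run under autotvm.apply_history_best so the best measured configuration is used instead of the fallback. A hedged sketch; the log file name, shapes and bit widths are hypothetical, and tvm.placeholder / tvm.target.create / tvm.lower are assumed from the TVM API of this era:

import tvm
import topi
from tvm import autotvm

data = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='data')
kernel = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='kernel')

# log produced by a prior tuning run (hypothetical path)
with autotvm.apply_history_best('bitserial_conv2d_arm.log'):
    with tvm.target.create('llvm -device=arm_cpu'):
        out = topi.nn.bitserial_conv2d_nhwc(data, kernel, 1, 1, 2, 1,
                                            'uint8', 'int16', True)
        s = topi.generic.schedule_bitserial_conv2d_nhwc([out])
        print(tvm.lower(s, [data, kernel, out], simple_mode=True))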
# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument # pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
"""Bitserial Conv2D operators""" """Bitserial Conv2D operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple
import numpy as np import numpy as np
import tvm import tvm
from tvm import autotvm
from topi.transform import concatenate from topi.transform import concatenate
from .pad import pad from .pad import pad
from .util import get_pad_tuple from .util import get_pad_tuple
from ..util import get_const_tuple, get_const_int from ..util import get_const_tuple, get_const_int
# workload description of conv2d
Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
SpatialPackNCHW = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc'])
SpatialPackNHWC = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc'])
_WORKLOADS = [
# workloads of resnet18 on imagenet
# input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride
Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('uint32', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workload of alexnet on cifar10
Workload('int32', 'int32', 27, 27, 96, 192, 5, 5, 2, 2, 1, 1),
Workload('int32', 'int32', 13, 13, 192, 384, 3, 3, 1, 1, 1, 1),
Workload('int32', 'int32', 13, 13, 384, 384, 3, 3, 1, 1, 1, 1),
Workload('int32', 'int32', 13, 13, 384, 256, 3, 3, 1, 1, 1, 1),
]
@tvm.target.generic_func @tvm.target.generic_func
def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype='uint32', out_dtype='int32', dorefa=True): pack_dtype='uint32', out_dtype='int16', unipolar=True):
"""Bitserial Conv2D operator. """Bitserial Conv2D operator.
Parameters Parameters
---------- ----------
input : tvm.Tensor input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width] or 4-D with shape [batch, in_channel, in_height, in_width]
[batch, in_height, in_width, in_channel]
filter : tvm.Tensor filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or 4-D with shape [num_filter, in_channel, filter_height, filter_width]
[filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width] stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints padding : int or a list/tuple of two or four ints
padding size, or [pad_height, pad_width] padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
layout : str
layout of data
activation_bits: int activation_bits: int
number of bits used for activations/input elements number of bits used for activations/input elements
...@@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits ...@@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
pack_dtype: str pack_dtype: str
bit packing type bit packing type
dorefa: bool unipolar: bool
preform the bitserial dot-product using 2 popcounts (required for DoReFa-Net) if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] or 4-D with shape [batch, out_channel, out_height, out_width]
[batch, out_height, out_width, out_channel]
""" """
# search platform specific declaration first assert isinstance(stride, int) or len(stride) == 2
# default declaration Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
if layout == 'NCHW': Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits, batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) num_filter, channel, kernel_h, kernel_w, weight_bits = Filter_q.shape
if layout == 'NHWC':
return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding, out_dtype, layout):
""" Get the workload structure. """
assert layout in ("NCHW", "NHWC"), \
"Only support layouts NCHW and NHWC"
if layout == "NCHW":
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
else: # NHWC
IH, IW = data.shape[1].value, data.shape[2].value
KH, KW, CI, CO = [x for x in get_const_tuple(kernel.shape)]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else: else:
HSTR, WSTR = stride, stride TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, 0, 0, TPAD, LPAD]
pad_after = [0, 0, 0, DPAD, RPAD]
PadInput_q = pad(Input_q, pad_before, pad_after, name="pad_temp")
# compute the output shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
out_channel = num_filter
out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
b1 = tvm.reduce_axis((0, activation_bits), name='b1')
b2 = tvm.reduce_axis((0, weight_bits), name='b2')
if unipolar:
def _conv(nn, ff, yy, xx):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum(
((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
Filter_q[ff, rc, ry, rx, b2]) -
tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
~Filter_q[ff, rc, ry, rx, b2]))
<< (b1b2)).astype(out_dtype),
axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
else:
def _conv(nn, ff, yy, xx):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum((tvm.popcount(
PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
Filter_q[ff, rc, ry, rx, b2])<< (b1b2)).astype(out_dtype),
axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
-    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    return tvm.compute((batch, out_channel, out_height, out_width), _conv,
+                       name="Conv2dOutput", tag="bitserial_conv2d_nchw")
@tvm.target.generic_func
-def _get_schedule(wkl, layout):
-    # pylint: disable=unreachable
-    """ Get the platform specific schedule. """
-    target = tvm.target.current_target()
-    raise RuntimeError(
-        "No schedule for current target:{}".format(target))
-    # This return has no use, merely to supress pylint warning
-    return wkl
-
-def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
-                      pack_dtype, out_dtype, dorefa=False):
+def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
+                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+    """Bitserial Conv2D operator.
+
+    Parameters
+    ----------
+    input : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    filter : tvm.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two or four ints
padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
activation_bits: int
number of bits used for activations/input elements
weight_bits: int
number of bits used for weight elements
out_dtype: str
return type of convolution
pack_dtype: str
bit packing type
unipolar: bool
if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
assert isinstance(stride, int) or len(stride) == 2
Input_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
if len(kernel.shape) == 4:
Filter_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
kernel_h, kernel_w, _, num_filter, _ = get_const_tuple(Filter_q.shape)
else:
Filter_q = kernel
kernel_h, kernel_w, _, _, num_filter = get_const_tuple(Filter_q.shape)
batch, in_height, in_width, in_channel_q, _ = get_const_tuple(Input_q.shape)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, TPAD, LPAD, 0, 0]
pad_after = [0, DPAD, RPAD, 0, 0]
# compute the output shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
out_channel = num_filter
out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput")
rc = tvm.reduce_axis((0, in_channel_q), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
b1 = tvm.reduce_axis((0, activation_bits), name='b1')
b2 = tvm.reduce_axis((0, weight_bits), name='b2')
if unipolar:
def _conv(nn, yy, xx, ff):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum(
((tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
Filter_q[ry, rx, rc, ff, b2]) -
tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
~Filter_q[ry, rx, rc, ff, b2]))
<< b1b2).astype(out_dtype),
axis=[rc, ry, rx, b2, b1])
else:
def _conv(nn, yy, xx, ff):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum((tvm.popcount(
PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
axis=[rc, ry, rx, b2, b1])
conv = tvm.compute((batch, out_height, out_width, out_channel), _conv,
name="Conv2dOutput", tag="bitserial_conv2d_nhwc")
return conv
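For clarity, here is a naive NumPy reference (not part of the diff) of what the NHWC compute above produces, written on unpacked bit planes instead of packed words. Padding is omitted for brevity, and the per-plane weighting 2**(b1+b2) mirrors the shift in the TVM expression:

import numpy as np

def bitserial_conv2d_nhwc_ref(data, kernel, stride, activation_bits, weight_bits,
                              unipolar=True):
    # data:   (N, H, W, CI) small unsigned ints (< 2**activation_bits)
    # kernel: (KH, KW, CI, CO) small unsigned ints (< 2**weight_bits)
    N, H, W, CI = data.shape
    KH, KW, _, CO = kernel.shape
    OH = (H - KH) // stride + 1
    OW = (W - KW) // stride + 1
    out = np.zeros((N, OH, OW, CO), dtype=np.int64)
    for b1 in range(activation_bits):
        a_bit = (data >> b1) & 1
        for b2 in range(weight_bits):
            w_bit = (kernel >> b2) & 1
            w_plane = 2 * w_bit - 1 if unipolar else w_bit   # -1/+1 vs 0/+1 weight bit
            for n in range(N):
                for y in range(OH):
                    for x in range(OW):
                        patch = a_bit[n, y*stride:y*stride+KH, x*stride:x*stride+KW, :]
                        out[n, y, x, :] += (
                            np.tensordot(patch, w_plane, axes=([0, 1, 2], [0, 1, 2]))
                            << (b1 + b2))
    return out

# tiny smoke test with 2-bit activations and 1-bit weights
rng = np.random.RandomState(0)
a = rng.randint(0, 4, size=(1, 8, 8, 16)).astype(np.int64)
w = rng.randint(0, 2, size=(3, 3, 16, 4)).astype(np.int64)
print(bitserial_conv2d_nhwc_ref(a, w, stride=1, activation_bits=2, weight_bits=1).shape)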
@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct')
def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype='uint32', out_dtype='int16', unipolar=True):
""" Compute convolution with pack on spatial axes. """ """ Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
# Check if kernel is already bitpacked
if len(kernel.shape) == 4:
kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
IB, _, CI, H, W = data_q.shape KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
KB, CO, _, KH, KW = kernel_q.shape else:
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) kernel_vec = kernel
OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape)
CO = OCO * VC
IB, N, CI, H, W = get_const_tuple(data_q.shape)
KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, 0, 0, TPAD, LPAD]
pad_after = [0, 0, 0, DPAD, RPAD]
if isinstance(stride, (tuple, list)): if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride HSTR, WSTR = stride
...@@ -142,36 +225,48 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, ...@@ -142,36 +225,48 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
HSTR, WSTR = stride, stride HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1 HCAT, WCAT = KH-1, KW-1
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NCHW") TH = H + TPAD + DPAD
sch = _get_schedule(wkl, "NCHW") TW = W + LPAD + RPAD
VH = sch.vh OH = (H + TPAD + DPAD - KH) // HSTR + 1
VW = sch.vw OW = (W + LPAD + RPAD - KW) // WSTR + 1
VC = sch.vc
# ==================== define configuration space ====================
TH = H + 2*HPAD n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
TW = W + 2*WPAD ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
OH = (H + 2*HPAD - KH) // HSTR + 1 ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
OW = (W + 2*WPAD - KW) // WSTR + 1
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
re_axes = cfg.define_reorder("reorder_0",
[n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
policy='interval_all', interval=(6, 11))
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# ====================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
dshape = (IB, 1, CI, H, W)
dpshape = (IB, 1, CI, TH, TW)
dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB) dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
kshape = (KB, CO, CI, KH, KW)
kvshape = (CO//VC, CI, KH, KW, KB, VC) kvshape = (CO//VC, CI, KH, KW, KB, VC)
ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC) ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
oshape = (1, CO, OH, OW) oshape = (1, CO, OH, OW)
DOPAD = (HPAD != 0 and WPAD != 0) if (TPAD != 0 and RPAD != 0):
if DOPAD: data_pad = pad(data_q, (0, 0, 0, TPAD, LPAD), (0, 0, 0, DPAD, RPAD), name="data_pad")
data_pad = pad(data_q, (0, 0, 0, HPAD, WPAD), name="data_pad")
else: else:
data_pad = data_q data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \ data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec') data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
if len(kernel.shape) == 4:
kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \ kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec') kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
...@@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, ...@@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
def _conv(n, co, h, w, vh, vw, vc): def _conv(n, co, h, w, vh, vw, vc):
b1b2 = (b1+b2).astype(out_dtype) b1b2 = (b1+b2).astype(out_dtype)
if dorefa: if unipolar:
return tvm.sum((tvm.popcount( return tvm.sum((tvm.popcount(
data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) & data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) - kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) -
...@@ -203,15 +298,28 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, ...@@ -203,15 +298,28 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
                       conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
                       name='conv_vec', tag='spatial_bitserial_conv_nchw')
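The unpack step above is just a data re-layout from the 7-D spatial-pack buffer back to plain NCHW. An illustrative NumPy equivalence, with made-up tile sizes (not part of the commit):

import numpy as np

CO, OH, OW = 16, 8, 8
VC, VH, VW = 8, 2, 4
packed = np.arange(1 * (CO//VC) * (OH//VH) * (OW//VW) * VH * VW * VC).reshape(
    1, CO//VC, OH//VH, OW//VW, VH, VW, VC)

# gather with the same index expression the compute uses
unpacked = np.empty((1, CO, OH, OW), dtype=packed.dtype)
for co in range(CO):
    for h in range(OH):
        for w in range(OW):
            unpacked[0, co, h, w] = packed[0, co//VC, h//VH, w//VW, h%VH, w%VW, co%VC]

# equivalent vectorized form: move vc, vh, vw next to their outer axes
ref = packed.transpose(0, 1, 6, 2, 4, 3, 5).reshape(1, CO, OH, OW)
assert np.array_equal(unpacked, ref)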
def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
pack_dtype, out_dtype, dorefa=False): def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype='uint32', out_dtype='int16', unipolar=True):
""" Compute convolution with pack on spatial axes. """ """ Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
pack_kernel = len(kernel.shape) == 4
if pack_kernel:
kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
_, H, W, CI, IB = data_q.shape else:
KH, KW, _, CO, KB = kernel_q.shape kernel_q = kernel
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape)
N, H, W, CI, IB = get_const_tuple(data_q.shape)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, TPAD, LPAD, 0, 0]
pad_after = [0, DPAD, RPAD, 0, 0]
if isinstance(stride, (tuple, list)): if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride HSTR, WSTR = stride
...@@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, ...@@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
HSTR, WSTR = stride, stride HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1 HCAT, WCAT = KH-1, KW-1
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC") PAD_H = H + (TPAD + DPAD)
sch = _get_schedule(wkl, "NHWC") PAD_W = W + (LPAD + RPAD)
VH = sch.vh OH = (PAD_H - KH) // HSTR + 1
VW = sch.vw OW = (PAD_W - KW) // WSTR + 1
VC = sch.vc oshape = (1, OH, OW, CO)
PAD_H = H + 2*HPAD # ==================== define configuration space ====================
PAD_W = W + 2*WPAD n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
OH = (H + 2*HPAD - KH) // HSTR + 1 ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
OW = (W + 2*WPAD - KW) // WSTR + 1 ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
re_axes = cfg.define_reorder("reorder_0",
[n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
policy='interval_all', interval=(3, 7))
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# ====================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB) dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
kvshape = (CO, KH, KW, CI, VC, KB) kvshape = (CO, KH, KW, CI, VC, KB)
ovshape = (1, OH, OW, CO, VH, VW, VC) ovshape = (1, OH, OW, CO, VH, VW, VC)
oshape = (1, OH, OW, CO) oshape = (1, OH, OW, CO)
if (HPAD != 0 and WPAD != 0): if (DPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad") data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, 0), name="data_pad")
else: else:
data_pad = data_q data_pad = data_q
...@@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, ...@@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
def _conv(n, h, w, co, vh, vw, vc): def _conv(n, h, w, co, vh, vw, vc):
b1b2 = (b1+b2).astype(out_dtype) b1b2 = (b1+b2).astype(out_dtype)
if dorefa: if unipolar:
return tvm.sum( return tvm.sum(
(tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) & ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) - kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) & tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]&
~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2, ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
axis=[dh, dw, ci, b1, b2]) axis=[dh, dw, ci, b1, b2])
return tvm.sum(tvm.popcount( return tvm.sum(tvm.popcount(
...@@ -273,6 +398,7 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, ...@@ -273,6 +398,7 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC], conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_bitserial_conv_nhwc') name='output_unpack', tag='spatial_bitserial_conv_nhwc')
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"): def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
"""Packs data into format necessary for bitserial computation """Packs data into format necessary for bitserial computation
pack_axis : int pack_axis : int
...@@ -334,8 +460,3 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"): ...@@ -334,8 +460,3 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
    if bits > 1:
        return concatenate(output_tuple, axis=bit_axis)
    return output_tuple
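bitpack is the entry point both spatial-pack paths rely on: it splits a tensor into `bits` bit planes and packs groups of elements along pack_axis into words of pack_type, inserting a new axis of length `bits` at bit_axis. A NumPy model of the resulting shape follows; the exact bit order inside a word is an implementation detail here, since the convolution only needs activations and weights packed consistently, and this helper is not the TVM implementation:

import numpy as np

def bitpack_ref(data, bits, pack_axis, bit_axis, pack_width=8):
    data = np.moveaxis(data, pack_axis, -1)          # put the packed axis last
    *lead, ci = data.shape
    assert ci % pack_width == 0
    grouped = data.reshape(*lead, ci // pack_width, pack_width)
    planes = []
    for b in range(bits):
        bit = (grouped >> b) & 1
        word = np.zeros(grouped.shape[:-1], dtype=np.uint8)
        for k in range(pack_width):                  # pack 8 consecutive bits into one byte
            word = (word << 1) | bit[..., k].astype(np.uint8)
        planes.append(np.moveaxis(word, -1, pack_axis))
    return np.stack(planes, axis=bit_axis)

x = np.random.RandomState(0).randint(0, 4, size=(1, 4, 4, 16)).astype(np.uint8)
print(bitpack_ref(x, bits=2, pack_axis=3, bit_axis=4).shape)   # (1, 4, 4, 2, 2)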
_SCH_TO_DECL_FUNC_QUANT = {
SpatialPackNCHW: spatial_pack_nchw,
SpatialPackNHWC: spatial_pack_nhwc,
}
# pylint: disable=invalid-name,unused-variable,invalid-name # pylint: disable=invalid-name,unused-variable,invalid-name
"""Bitserial conv2d schedule on x86""" """Bitserial conv2d schedule on x86"""
import tvm import tvm
from tvm import autotvm
from topi.util import get_const_int from topi.util import get_const_int
from .. import generic, tag from .. import generic, tag
from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload
from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC @autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nchw, ['cpu'], 'direct')
from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT @autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, ['cpu'], 'direct')
def schedule_bitserial_conv2d(cfg, outs):
_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(3, 3, 16, 3, 16),
SpatialPackNCHW(1, 1, 16, 2, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
]
_QUANTIZED_SCHEDULES_NHWC = [
# resnet
SpatialPackNHWC(2, 2, 8, 1, 1),
SpatialPackNHWC(1, 4, 8, 4, 1),
SpatialPackNHWC(1, 4, 8, 1, 16),
SpatialPackNHWC(1, 4, 8, 4, 8),
SpatialPackNHWC(1, 7, 8, 3, 8),
SpatialPackNHWC(1, 2, 8, 1, 8),
SpatialPackNHWC(2, 1, 8, 1, 4),
SpatialPackNHWC(1, 7, 8, 1, 1),
SpatialPackNHWC(1, 1, 8, 1, 16),
SpatialPackNHWC(1, 1, 8, 1, 8),
SpatialPackNHWC(1, 1, 8, 1, 16),
]
@_get_schedule.register("cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch
@bitserial_conv2d.register("cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits,
weight_bits, pack_dtype, out_dtype, dorefa)
@generic.schedule_bitserial_conv2d_nchw.register(["cpu"])
@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"])
def schedule_bitserial_conv2d(outs):
"""CPU schedule for bitserial convolutions NCHW and NHWC""" """CPU schedule for bitserial convolutions NCHW and NHWC"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = [] scheduled_ops = []
...@@ -88,7 +27,6 @@ def schedule_bitserial_conv2d(outs): ...@@ -88,7 +27,6 @@ def schedule_bitserial_conv2d(outs):
conv_out = op.input_tensors[0] conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1] kernel_vec = conv_out.op.input_tensors[1]
kernel_q = kernel_vec.op.input_tensors[0] kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0] data_vec = conv_out.op.input_tensors[0]
data_q = data_vec.op.input_tensors[0] data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0] data = data_q.op.input_tensors[0]
...@@ -97,28 +35,26 @@ def schedule_bitserial_conv2d(outs): ...@@ -97,28 +35,26 @@ def schedule_bitserial_conv2d(outs):
data_pad = data_q data_pad = data_q
data_q = data data_q = data
data = data_q.op.input_tensors[0] data = data_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]
if "QuantizeInput" in data.op.name: if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack # Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0] data = data.op.input_tensors[0]
if 'spatial_bitserial_conv_nchw' in op.tag: if 'spatial_bitserial_conv_nchw' in op.tag:
_schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, kernel_q, kernel_vec,
conv_out, output, outs[0]) conv_out, output, outs[0])
elif 'spatial_bitserial_conv_nhwc' in op.tag: elif 'spatial_bitserial_conv_nhwc' in op.tag:
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, kernel_q, kernel_vec,
conv_out, output, outs[0]) conv_out, output, outs[0])
            scheduled_ops.append(op)

    traverse(outs[0].op)
    return s
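A hedged end-to-end sketch of the x86 ('cpu') path registered here, building with the fallback configuration and timing the result. Shapes, bit widths and the uint32 pack type are made up, and tvm.nd / time_evaluator are assumed from the runtime API of this era:

import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

data = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='data')
kernel = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='kernel')

with tvm.target.create('llvm'):
    out = topi.nn.bitserial_conv2d_nhwc(data, kernel, 1, 1, 2, 1,
                                        'uint32', 'int16', True)
    s = topi.generic.schedule_bitserial_conv2d_nhwc([out])
    func = tvm.build(s, [data, kernel, out], 'llvm')

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(0, 4, size=(1, 56, 56, 64)).astype('uint32'), ctx)
w = tvm.nd.array(np.random.randint(0, 2, size=(3, 3, 64, 64)).astype('uint32'), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)
func(a, w, c)
timer = func.time_evaluator(func.entry_name, ctx, number=10)
print('mean time: %.3f ms' % (timer(a, w, c).mean * 1e3))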
def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, kernel_q, kernel_vec,
conv_out, output, last): conv_out, output, last):
IB, _, CI, IH, IW = data_q.shape IB, _, CI, IH, IW = data_q.shape
KB, CO, _, KH, KW = kernel_q.shape KB, CO, _, KH, KW = kernel_q.shape
...@@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, ...@@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
wstride = get_const_int((TW - KW) // (OW - 1)) wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride) stride = (hstride, wstride)
wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NCHW") VC = cfg["tile_co"].size[-1]
sch = _get_schedule(wkl, "NCHW") VH = cfg["tile_oh"].size[-1]
VH = sch.vh VW = cfg["tile_ow"].size[-1]
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
CC = s.cache_write(conv_out, "global")
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
s[conv_out].vectorize(vc)
s[CC].compute_at(s[conv_out], ow)
n, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw, b1, b2 = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, b1, b2, vc)
s[CC].unroll(b1)
s[CC].unroll(b2)
s[CC].vectorize(vc)
-    ##### Schedule A
+    ##### Schedule data padding and bitpacking
if data_pad is not None: if data_pad is not None:
s[data_pad].compute_inline() s[data_pad].compute_inline()
_, h, _, _, _, _, vw = s[data_vec].op.axis _, _, h, _, _, _, _ = s[data_vec].op.axis
s[data_vec].vectorize(vw) cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
if ba == 1: oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
oaxis = h if cfg["tile_ah"].size[1] == 1:
paxis = h oaxis = oh
paxis = oh
else: else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh oaxis = oh
paxis = ih paxis = ih
...@@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, ...@@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B ##### Schedule kernel bitpacking
co, _, _, _, _, vc = s[kernel_vec].op.axis co, _, _, _, _, _ = s[kernel_vec].op.axis
s[kernel_vec].vectorize(vc) cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
if bc == 1: oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
oaxis = co if cfg["tile_bco"].size[1] == 1:
paxis = co oaxis = oco
paxis = oco
else: else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco oaxis = oco
paxis = ico paxis = ico
...@@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, ...@@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule C ##### Schedule Convolution
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
ci, dh, dw, ib, kb = s[conv_out].op.reduce_axis
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg["reorder_0"].apply(s, conv_out, [n, co, oh, ow, vc, vh, vw, dh, dw, kb, ib, ci])
cfg["ann_reduce"].apply(s, conv_out, [kb, ib, dh, dw],
axis_lens=[get_const_int(kb.dom.extent),
get_const_int(ib.dom.extent),
get_const_int(dh.dom.extent),
get_const_int(dw.dom.extent)],
max_unroll=16,
cfg=cfg)
s[conv_out].vectorize(vc)
# Schedule output
n, co, h, w = s[last].op.axis n, co, h, w = s[last].op.axis
co, vc = s[last].split(co, VC) co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW) oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
...@@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, ...@@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[output].compute_inline() s[output].compute_inline()
s[conv_out].compute_at(s[last], ow) s[conv_out].compute_at(s[last], ow)
if bc == 1: oco, ico = cfg["tile_oh"].apply(s, last, co)
oaxis = co if cfg["tile_oh"].size[1] == 1:
paxis = co oaxis = oco
paxis = oco
else: else:
oco, ico = s[last].split(co, bc) oco, ico = s[last].split(co, bc)
oaxis = oco oaxis = oco
paxis = ico paxis = ico
s[last].parallel(paxis) s[last].parallel(oco)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
return s return s
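Note on the pattern used above: the fixed ba/bc/vh/vw/vc factors that used to come from the hand-written _get_schedule table are replaced by knobs that are defined on cfg and then applied to a stage. A minimal sketch of that define/apply pattern on a toy compute follows; the names toy_parallel_split and tile_h are illustrative and not part of this commit.

import tvm
from tvm import autotvm

@autotvm.template
def toy_parallel_split(n):
    data = tvm.placeholder((n, n), name="data")
    out = tvm.compute((n, n), lambda i, j: data[i, j] * 2, name="out")
    s = tvm.create_schedule(out.op)

    cfg = autotvm.get_config()
    h, _ = s[out].op.axis
    # Define a tunable split knob instead of a fixed factor (cf. tile_ah / tile_bco above).
    cfg.define_split("tile_h", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
    oh, ih = cfg["tile_h"].apply(s, out, h)
    s[out].parallel(oh)
    return s, [data, out]

# The template can then be turned into a tuning task, e.g.
# task = autotvm.task.create(toy_parallel_split, args=(1024,), target="llvm")

With no tuning log available, the fallback config still supplies a default value for each knob, which is what keeps the new schedules usable out of the box.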
def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, kernel_q, kernel_vec,
conv_out, output, last): conv_out, output, last):
# no stride and padding info here # no stride and padding info here
_, IH, IW, CI, IB = data_q.shape _, IH, IW, CI, IB = data_q.shape
KH, KW, _, CO, KB = kernel_q.shape KH, KW, _, CO, KB = kernel_q.shape
_, OH, OW, _ = output.shape _, OH, OW, _ = output.shape
# Infer padding and stride
if data_pad is None:
padding = (0, 0)
TH, TW = IH, IW
else:
_, TH, TW, _, _ = data_pad.shape
hpad = get_const_int((TH - IH) // 2)
wpad = get_const_int((TW - IW) // 2)
padding = (hpad, wpad)
hstride = get_const_int((TH - KH) // (OH - 1)) VC = cfg["tile_co"].size[-1]
wstride = get_const_int((TW - KW) // (OW - 1)) VH = cfg["tile_oh"].size[-1]
stride = (hstride, wstride) VW = cfg["tile_ow"].size[-1]
wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC") ##### Schedule data padding and packing
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
##### Schedule data packing
if data_pad is not None: if data_pad is not None:
s[data_pad].compute_inline() s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis _, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1: cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oaxis = h oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
paxis = h s[data_vec].parallel(oh)
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule kernel packing ##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1: cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oaxis = co oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
paxis = co s[kernel_vec].parallel(oco)
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule Convolution ##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2) # s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg["reorder_0"].apply(s, conv_out, [n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2])
cfg["ann_reduce"].apply(s, conv_out, [b1, b2, dh, dw],
axis_lens=[get_const_int(b1.dom.extent),
get_const_int(b2.dom.extent),
get_const_int(dh.dom.extent),
get_const_int(dw.dom.extent)],
max_unroll=16,
cfg=cfg)
s[conv_out].unroll(b1) s[conv_out].unroll(b1)
s[conv_out].unroll(b2) s[conv_out].unroll(b2)
...@@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, ...@@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
s[output].compute_inline() s[output].compute_inline()
s[conv_out].compute_at(s[last], ow) s[conv_out].compute_at(s[last], ow)
if bc == 1: oho, iho = cfg["tile_oh"].apply(s, last, oh) # reuse parameter
oaxis = oh s[last].parallel(oho)
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho
s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
return s return s
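The "reorder_0" and "ann_reduce" knobs are only applied in the schedules above; their definitions are expected to sit next to the compute declaration, which is not part of this hunk. A hedged sketch of the define/apply pattern for an annotation knob on a reduction axis is shown below, on an illustrative toy_reduce template rather than the actual bitserial compute.

import tvm
from tvm import autotvm

@autotvm.template
def toy_reduce(n, k):
    data = tvm.placeholder((n, k), name="data")
    rk = tvm.reduce_axis((0, k), name="rk")
    out = tvm.compute((n,), lambda i: tvm.sum(data[i, rk], axis=rk), name="out")
    s = tvm.create_schedule(out.op)

    cfg = autotvm.get_config()
    i = s[out].op.axis[0]
    rki = s[out].op.reduce_axis[0]
    # Knob definition (in TOPI this normally lives beside the compute declaration) ...
    cfg.define_annotate("ann_reduce", [cfg.axis(rki)], policy="try_unroll")
    # ... and application in the schedule, mirroring the call above for [b1, b2, dh, dw].
    cfg["ann_reduce"].apply(s, out, [rki],
                            axis_lens=[k],
                            max_unroll=16,
                            cfg=cfg)
    s[out].parallel(i)
    return s, [data, out]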
...@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype): ...@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype):
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype) return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa): activation_bits, weight_bits, unipolar):
in_height = in_width = in_size in_height = in_width = in_size
input_type = 'uint32' input_dtype = 'uint32'
out_dtype = 'int32' out_dtype = 'int32'
with tvm.target.create('llvm'): with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa) out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nchw([B]) s = topi.generic.schedule_bitserial_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
...@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, ...@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
@memoize("topi.tests.test_topi_bitseral_conv2d_nchw") @memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
def get_ref_data(): def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type) a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type) w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
if dorefa: if unipolar:
w_ = np.copy(w_np).astype(out_dtype) w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']): for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1 x[...] = 1 if x == 1 else -1
...@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, ...@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa): activation_bits, weight_bits, unipolar):
in_height = in_width = in_size in_height = in_width = in_size
input_type='uint32' input_dtype='uint32'
out_dtype='int32' out_dtype='int32'
with tvm.target.create('llvm'): with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
layout="NHWC", dorefa=dorefa) out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
...@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, ...@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
@memoize("topi.tests.test_topi_bitseral_conv2d_nhwc") @memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
def get_ref_data(): def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type) a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type) w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
if dorefa: if unipolar:
w_ = np.copy(w_np).astype(out_dtype) w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']): for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1 x[...] = 1 if x == 1 else -1
......
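The dorefa flag is renamed to unipolar throughout these tests. A small numpy sketch of the reference weight transform used in get_ref_data() above: when unipolar is set, 1-valued weight bits stay +1 and everything else becomes -1 before the reference convolution. The vectorized form below is equivalent to the np.nditer loop in the test; shapes are illustrative.

import numpy as np

w_np = np.random.randint(0, 2, size=(3, 3, 16, 32)).astype('uint32')  # 1-bit HWIO weights
w_ = np.where(w_np == 1, 1, -1).astype('int32')                       # unipolar +/-1 weights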
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import tvm import tvm
import topi import topi
import topi.testing import topi.testing
from topi.util import get_const_tuple
def generate_quantized_np(shape, bits, out_dtype): def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0) np.random.seed(0)
...@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype): ...@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype):
# Verify that certain special instructions from the tensorize pass exist # Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa): activation_bits, weight_bits, unipolar):
in_height = in_width = in_size in_height = in_width = in_size
input_type = 'uint32' input_type = 'uint32'
out_dtype = 'int32' out_dtype = 'int16'
with tvm.target.arm_cpu('rasp3b'): device = 'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon'
with tvm.target.create(device):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
layout="NHWC", dorefa=dorefa) pack_dtype='uint8', out_dtype='int16', unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b')) func = tvm.build(s, [A, W, B], device)
assembly = func.get_source('asm') assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly) matches = re.findall("vpadal", assembly)
...@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, ...@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
matches = re.findall("vpadd", assembly) matches = re.findall("vpadd", assembly)
assert (len(matches) > 0) assert (len(matches) > 0)
ctx = tvm.context(device, 0)
if 'arm' not in os.uname()[4]:
print ("Skipped running code, not an arm device")
return
print("Running on target: %s" % device)
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
if unipolar:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
else:
b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_bitserial_conv2d(): def test_bitserial_conv2d():
in_size = 56 in_size = 56
ic, oc = 64, 64 ic, oc = 64, 64
...@@ -45,6 +74,9 @@ def test_bitserial_conv2d(): ...@@ -45,6 +74,9 @@ def test_bitserial_conv2d():
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
if __name__ == "__main__": if __name__ == "__main__":
test_bitserial_conv2d() test_bitserial_conv2d()
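The arm test above now spells out the Raspberry Pi 3 target as an explicit target string instead of the tvm.target.arm_cpu('rasp3b') helper, and additionally executes the kernel with a numpy cross-check when run on an ARM host. A minimal sketch of reusing that string, assuming nothing beyond what the diff shows:

import tvm

device = 'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon'
with tvm.target.create(device):
    tgt = tvm.target.current_target()
    assert tgt.device_name == 'arm_cpu'  # keys the dispatch to the arm_cpu schedules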