Commit 17351875 by Meghan Cowan, committed by Tianqi Chen

[TOPI] bitserial_conv2d move to autotvm template and updates (#2819)

parent cefe07e2
......@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
......
......@@ -68,6 +68,8 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
}
......@@ -79,6 +81,8 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
......@@ -174,6 +178,24 @@ class TaskExtractEnv:
return s, [data, weight, bias, C]
return s, [data, weight, C]
@register("topi_nn_bitserial_conv2d_nhwc")
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
@register("topi_nn_bitserial_conv2d_nchw")
def _topi_bitserial_conv2d_nchw(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
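# Illustrative aside, not part of this commit: once the two templates above are
# registered, a standalone tuning task can be created by name and tuned like any
# other AutoTVM task.  The ('TENSOR', shape, dtype) serialization, the shapes,
# bit widths and the target string below are assumptions for this sketch.
from tvm import autotvm

task = autotvm.task.create(
    "topi_nn_bitserial_conv2d_nhwc",
    args=(('TENSOR', (1, 56, 56, 64), 'uint8'),   # data placeholder (serialized)
          ('TENSOR', (3, 3, 64, 64), 'uint8'),    # kernel placeholder (serialized)
          1, 1, 2, 1,                             # stride, padding, activation_bits, weight_bits
          'uint8', 'int16', True),                # pack_dtype, out_dtype, unipolar
    target='llvm -device=arm_cpu',
    template_key='direct')

measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(),
                                        runner=autotvm.LocalRunner(number=5))
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('bitserial_nhwc.log')])
# The best records in 'bitserial_nhwc.log' can later be applied with
# autotvm.apply_history_best('bitserial_nhwc.log') around the compute/schedule calls.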
@register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
......
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Bitserial conv2d schedule on raspberry pi"""
"""Bitserial conv2d schedule on arm cpu"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from tvm import autotvm
from .. import tag
from ..nn.pad import pad
from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload, bitpack
from ..nn.bitserial_conv2d import SpatialPackNCHW, _WORKLOADS, spatial_pack_nchw
from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc
from ..nn.util import get_pad_tuple
from ..util import get_const_int
from ..util import get_const_int, get_const_tuple
from .. import generic
RaspSpatialPack = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc', 'split_ci', 'kfactor'])
_QUANTIZED_SCHEDULES_NHWC = [
RaspSpatialPack(2, 2, 8, 1, 1, False, 8),
RaspSpatialPack(1, 4, 8, 4, 1, False, 8),
RaspSpatialPack(1, 4, 8, 1, 16, False, 8),
RaspSpatialPack(1, 4, 8, 4, 8, False, 8),
RaspSpatialPack(1, 7, 8, 3, 8, False, 16),
RaspSpatialPack(1, 2, 8, 1, 8, False, 16),
RaspSpatialPack(2, 1, 8, 1, 4, False, 16),
RaspSpatialPack(1, 7, 8, 1, 1, True, 16),
RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
RaspSpatialPack(1, 1, 8, 1, 8, True, 16),
RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
]
_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),
]
@_get_schedule.register("arm_cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch
@bitserial_conv2d.register("arm_cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
if dorefa:
assert layout == "NCHW", "Cannot support dorea with NHWC layout yet"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
if layout == "NCHW":
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
return _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits,
weight_bits, out_dtype)
def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC):
kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True):
if use_bitpack:
kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
else:
kernel_q = kernel
KH, KW, KB, CI, CO = kernel_q.shape
kvshape = (CO//VC, KH, KW, KB, VC, CI)
return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \
kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec')
def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, out_dtype):
@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype, out_dtype, unipolar):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
assert pack_dtype == 'uint8', "only support packing into uint8 bits"
assert out_dtype == 'int16', "only support output type of int16"
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC)
N, H, W, IB, CI = data_q.shape
OCO, KH, KW, KB, VC, _ = kernel_vec.shape
N, H, W, CI = get_const_tuple(data.shape)
if len(kernel.shape) == 4:
KH, KW, _, CO = get_const_tuple(kernel.shape)
CI_packed = CI // 8
else:
KH, KW, KB, CI_packed, CO = get_const_tuple(kernel.shape)
CO = OCO * VC
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
......@@ -102,75 +46,151 @@ def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bi
HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1
PAD_H = H + 2*HPAD
PAD_W = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
PAD_H = H + (TPAD + DPAD)
PAD_W = W + (LPAD + RPAD)
OH = (PAD_H - KH) // HSTR + 1
OW = (PAD_W - KW) // WSTR + 1
oshape = (1, OH, OW, CO)
# Pad input channels of weights and data when they are not a multiple of 8
if CI_packed % 8 != 0:
CI_PAD = CI_packed % 8
CI_packed += CI_PAD
else:
CI_PAD = 0
# ==================== define configuration space ====================
n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
filter=lambda x: x.size[-1] == 8)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
re_axes = cfg.define_reorder("reorder_0",
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
policy='candidate', candidate=[
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
[n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# ====================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])
N, H, W, IB, CI = data_q.shape
OCO, KH, KW, KB, VC, CI = kernel_vec.shape
dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI)
ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC)
oshape = (1, OH, OW, CO)
if (HPAD != 0 and WPAD != 0):
data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad")
if (TPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, CI_PAD), name="data_pad")
elif CI_PAD != 0:
data_pad = pad(data_q, (0, 0, 0, 0, 0), (0, 0, 0, 0, CI_PAD), name="data_pad")
else:
data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \
data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw')
ib = tvm.reduce_axis((0, IB), name='ib')
kb = tvm.reduce_axis((0, KB), name='kb')
def _conv(n, h, w, co, vh, vw, vc):
def _bipolar_conv(n, h, w, co, vh, vw, vc):
return tvm.sum((tvm.popcount(
kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16'))
<< (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci])
def _unipolar_conv(n, h, w, co, vh, vw, vc):
return tvm.sum(
((tvm.popcount(kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')) -
tvm.popcount(~kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci]).astype('int16'))
<< (kb + ib).astype('int16')), axis=[dh, dw, kb, ib, ci])
if unipolar:
conv_vec = tvm.compute(ovshape, _unipolar_conv, name='conv_vec', tag='unipolar')
else:
conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')
conv = tvm.compute(ovshape, _conv, name='conv')
conv = tvm.compute(oshape, lambda n, h, w, co:
conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
name='conv', tag='spatial_bitserial_conv_nhwc')
return tvm.compute(oshape, lambda n, h, w, co:
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
name='output_vec', tag='spatial_bitserial_conv_nhwc')
return conv
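# Illustrative aside, not part of this commit: a small numpy check of the identity
# that _unipolar_conv and _bipolar_conv above rely on.  For unsigned multi-bit
# activations and 1-bit weights, bipolar treats the weight bit as a 0/1 value,
# while unipolar interprets it as -1/+1, which turns each bit-plane contribution
# into popcount(a_b & w) - popcount(a_b & ~w).
import numpy as np

rng = np.random.RandomState(0)
n, activation_bits = 16, 2
acts = rng.randint(0, 2 ** activation_bits, size=n)    # quantized activations
w_bits = rng.randint(0, 2, size=n)                      # 1-bit weights
w_signed = np.where(w_bits == 1, 1, -1)                 # unipolar interpretation of the bits

bipolar = unipolar = 0
for b in range(activation_bits):
    plane = (acts >> b) & 1                             # bit plane b of the activations
    pos = np.count_nonzero(plane & w_bits)              # popcount(a_b & w)
    neg = np.count_nonzero(plane & (1 - w_bits))        # popcount(a_b & ~w)
    bipolar += pos << b
    unipolar += (pos - neg) << b

assert bipolar == np.dot(acts, w_bits)                  # weights read as 0/1
assert unipolar == np.dot(acts, w_signed)               # weights read as -1/+1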
def _intrin_popcount(m, k_i, w_b, x_b):
dtype = 'uint8'
w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w')
x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x')
def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
pack_dtype = 'uint8'
w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
x = tvm.placeholder((x_b, k_i,), dtype=pack_dtype, name='x')
k = tvm.reduce_axis((0, k_i), name='k')
bw = tvm.reduce_axis((0, w_b), name='bw')
bx = tvm.reduce_axis((0, x_b), name='bx')
z = tvm.compute((m,), lambda i:
tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') &
x[bx, k].astype('uint16'))
<< (bw+bx).astype('uint16'), axis=[bw, bx, k]), name='z')
if unipolar:
dtype = 'int16'
z = tvm.compute((m,), lambda i:
tvm.sum((tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) -
tvm.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)))
<< (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
else:
dtype = 'uint16'
z = tvm.compute((m,), lambda i:
tvm.sum(tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
<< (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=k_i,
strides=[tvm.var('ldw'), tvm.var('ldw'), 1])
strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) # stride can be inferred
Xb = tvm.decl_buffer(x.shape, x.dtype,
name="X",
offset_factor=k_i,
strides=[tvm.var('ldw'), 1])
Zb = tvm.decl_buffer(z.shape, z.dtype,
name="Z",
offset_factor=1,
strides=[1])
def _intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
vpadd = "llvm.arm.neon.vpadd.v8u8"
vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
args_1 = tvm.const(1, 'uint32')
args_2 = tvm.const(2, 'uint32')
if unipolar:
vpadd = "llvm.arm.neon.vpadd.v8i8"
vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16"
full_dtype = 'int8x16'
half_dtype = 'int8x8'
return_dtype = 'int16x8'
else:
vpadd = "llvm.arm.neon.vpadd.v8u8"
vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
full_dtype = 'uint8x16'
half_dtype = 'uint8x8'
return_dtype = 'uint16x8'
def _instr(index):
irb = tvm.ir_builder.create()
if index == 1:
irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
if index == 1: # reduce reset
irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
return irb.get()
# body and reduce update
cnts8 = [None] * 8
cnts4 = [None] * 4
cnts2 = [None] * 2
......@@ -178,154 +198,108 @@ def _intrin_popcount(m, k_i, w_b, x_b):
for bx in range(x_b):
if k_i == 16:
for i in range(m):
ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
cnts = tvm.popcount(ands)
upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
if unipolar:
cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
else:
cnts = tvm.popcount(w_ & x_)
upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
cnts8[i] = upper_half + lower_half
for i in range(m//2):
cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, dtype)
out = tvm.call_llvm_intrin('uint16x8', vpadalu,
args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
out = tvm.call_llvm_intrin(return_dtype, vpadalu,
args_2, zz.vload(0, return_dtype), shifted_cnts)
else: # ki == 8
for i in range(m):
ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
cnts8[i] = tvm.popcount(ands)
w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
if unipolar:
cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
else:
cnts8[i] = tvm.popcount(w_ & x_)
for i in range(m//2):
cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, dtype)
out = tvm.call_llvm_intrin('uint16x8', vpadalu,
args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
out = tvm.call_llvm_intrin(return_dtype, vpadalu,
args_2, zz.vload(0, return_dtype), shifted_cnts)
irb.emit(zz.vstore(0, out))
return irb.get()
# body, reset, update
return _instr(0), _instr(1), _instr(2)
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb})
return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb})
# ARM specific schedule that uses a custom microkernel
def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
# no stride and padding info here
_, H, W, IB, CI = data_q.shape
KH, KW, KB, _, CO = kernel_q.shape
def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
conv_out, output, last, unipolar):
_, _, _, _, _, IB, CI = data_vec.shape
_, KH, KW, KB, _, _ = kernel_vec.shape
KB = get_const_int(KB)
IB = get_const_int(IB)
if data_pad is None:
padding = (0, 0)
_, in_h, in_w, _, _ = data_q.shape
kern_h, kern_w, _, _ = kernel.shape
_, out_h, out_w, _ = output.shape
hstride = (in_h - kern_h) // (out_h - 1)
wstride = (in_w - kern_w) // (out_w - 1)
stride = get_const_int(hstride), get_const_int(wstride)
else:
_, in_h, in_w, _, _ = data_q.shape
_, pad_h, pad_w, _, _ = data_pad.shape
hpad = (pad_h - in_h) // 2
wpad = (pad_w - in_w) // 2
padding = get_const_int(hpad), get_const_int(wpad)
_, in_h, in_w, _, _ = data_pad.shape
kern_h, kern_w, _, _ = kernel.shape
_, out_h, out_w, _ = output.shape
hstride = (in_h - kern_h) // (out_h - 1)
wstride = (in_w - kern_w) // (out_w - 1)
stride = get_const_int(hstride), get_const_int(wstride)
wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
##### Schedule data packing
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
##### Schedule data padding and packing
if data_pad is not None:
s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)
##### Schedule kernel packing
#### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)
##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, kb, ib, ci = s[conv_out].op.reduce_axis
kh, kw, kb, ib, ci = s[conv_out].op.reduce_axis
kfactor = sch.kfactor
if sch.split_ci:
oci, ici = s[conv_out].split(ci, kfactor)
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, oci, kb, ib, vc, ici)
else:
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, kb, ib, vc, ci)
ci_o, ci_i = cfg['tile_ci'].apply(s, conv_out, ci)
re_axes = cfg["reorder_0"].apply(s, conv_out,
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i])
pc = _intrin_popcount(8, kfactor, KB, IB)
s[conv_out].tensorize(kb, pc)
# Use microkernel
kfactor = cfg['tile_ci'].size[1]
if kfactor % 8 == 0:
pc = _intrin_popcount(VC, kfactor, KB, IB, unipolar)
s[conv_out].tensorize(kb, pc)
n, h, w, co = s[last].op.axis
co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
s[last].reorder(n, oh, ow, co, vc, vh, vw)
s[last].vectorize(vw)
co, vc = cfg['tile_co'].apply(s, last, co)
oh, vh = cfg['tile_oh'].apply(s, last, h)
ow, vw = cfg['tile_ow'].apply(s, last, w)
s[last].reorder(n, oh, ow, co, vh, vw, vc)
s[last].vectorize(vc)
if last != output:
s[last].compute_inline()
s[conv_out].compute_at(s[last], ow)
if co == 1:
oaxis = oh
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho
s[last].parallel(paxis)
s[conv_out].compute_at(s[last], co)
s[last].parallel(oh)
s = s.normalize()
return s
@generic.schedule_bitserial_conv2d_nhwc.register(["arm_cpu"])
def schedule_bitserial_conv2d_nhwc(outs):
"""Raspverry pi schedule for bitserial conv2d"""
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
def schedule_bitserial_conv2d_nhwc(cfg, outs):
"""Arm cpu schedule for bitserial conv2d"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
......@@ -344,10 +318,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[0]
kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[1]
data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0]
......@@ -355,13 +325,10 @@ def schedule_bitserial_conv2d_nhwc(outs):
if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag:
data_pad = data_q
data_q = data
data = data_q.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0]
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, conv_out, output, outs[0])
unipolar = "unipolar" in conv_out.op.tag
_schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
conv_out, output, outs[0], unipolar)
scheduled_ops.append(op)
traverse(outs[0].op)
......
# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
"""Bitserial Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import numpy as np
import tvm
from tvm import autotvm
from topi.transform import concatenate
from .pad import pad
from .util import get_pad_tuple
from ..util import get_const_tuple, get_const_int
# workload description of conv2d
Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
SpatialPackNCHW = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc'])
SpatialPackNHWC = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc'])
_WORKLOADS = [
# workloads of resnet18 on imagenet
# input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride
Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('uint32', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('uint32', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('uint32', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint32', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workload of alexnet on cifar10
Workload('int32', 'int32', 27, 27, 96, 192, 5, 5, 2, 2, 1, 1),
Workload('int32', 'int32', 13, 13, 192, 384, 3, 3, 1, 1, 1, 1),
Workload('int32', 'int32', 13, 13, 384, 384, 3, 3, 1, 1, 1, 1),
Workload('int32', 'int32', 13, 13, 384, 256, 3, 3, 1, 1, 1, 1),
]
@tvm.target.generic_func
def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype='uint32', out_dtype='int32', dorefa=True):
def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype='uint32', out_dtype='int16', unipolar=True):
"""Bitserial Conv2D operator.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width] or
[batch, in_height, in_width, in_channel]
4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
[filter_height, filter_width, in_channel, num_filter]
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
layout : str
layout of data
padding : int or a list/tuple of two or four ints
padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
activation_bits: int
number of bits used for activations/input elements
......@@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
pack_dtype: str
bit packing type
dorefa: bool
perform the bitserial dot-product using 2 popcounts (required for DoReFa-Net)
unipolar: bool
if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] or
[batch, out_height, out_width, out_channel]
4-D with shape [batch, out_channel, out_height, out_width]
"""
# search platform specific declaration first
# default declaration
if layout == 'NCHW':
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
if layout == 'NHWC':
return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding, out_dtype, layout):
""" Get the workload structure. """
assert layout in ("NCHW", "NHWC"), \
"Only support layouts NCHW and NHWC"
if layout == "NCHW":
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
else: # NHWC
IH, IW = data.shape[1].value, data.shape[2].value
KH, KW, CI, CO = [x for x in get_const_tuple(kernel.shape)]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
assert isinstance(stride, int) or len(stride) == 2
Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
num_filter, channel, kernel_h, kernel_w, weight_bits = Filter_q.shape
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
HSTR, WSTR = stride, stride
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, 0, 0, TPAD, LPAD]
pad_after = [0, 0, 0, DPAD, RPAD]
PadInput_q = pad(Input_q, pad_before, pad_after, name="pad_temp")
# compute the output shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
out_channel = num_filter
out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
b1 = tvm.reduce_axis((0, activation_bits), name='b1')
b2 = tvm.reduce_axis((0, weight_bits), name='b2')
if unipolar:
def _conv(nn, ff, yy, xx):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum(
((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
Filter_q[ff, rc, ry, rx, b2]) -
tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
~Filter_q[ff, rc, ry, rx, b2]))
<< (b1b2)).astype(out_dtype),
axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
else:
def _conv(nn, ff, yy, xx):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum((tvm.popcount(
PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
Filter_q[ff, rc, ry, rx, b2])<< (b1b2)).astype(out_dtype),
axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
return tvm.compute((batch, out_channel, out_height, out_width), _conv,
name="Conv2dOutput", tag="bitserial_conv2d_nchw")
@tvm.target.generic_func
def _get_schedule(wkl, layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to suppress pylint warning
return wkl
def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype, out_dtype, dorefa=False):
def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype='uint32', out_dtype='int16', unipolar=True):
"""Bitserial Conv2D operator.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two or four ints
padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
activation_bits: int
number of bits used for activations/input elements
weight_bits: int
number of bits used for weight elements
out_dtype: str
return type of convolution
pack_dtype: str
bit packing type
unipolar: bool
if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
assert isinstance(stride, int) or len(stride) == 2
Input_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
if len(kernel.shape) == 4:
Filter_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
kernel_h, kernel_w, _, num_filter, _ = get_const_tuple(Filter_q.shape)
else:
Filter_q = kernel
kernel_h, kernel_w, _, _, num_filter = get_const_tuple(Filter_q.shape)
batch, in_height, in_width, in_channel_q, _ = get_const_tuple(Input_q.shape)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, TPAD, LPAD, 0, 0]
pad_after = [0, DPAD, RPAD, 0, 0]
# compute the output shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
out_channel = num_filter
out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput")
rc = tvm.reduce_axis((0, in_channel_q), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
b1 = tvm.reduce_axis((0, activation_bits), name='b1')
b2 = tvm.reduce_axis((0, weight_bits), name='b2')
if unipolar:
def _conv(nn, yy, xx, ff):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum(
((tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
Filter_q[ry, rx, rc, ff, b2]) -
tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
~Filter_q[ry, rx, rc, ff, b2]))
<< b1b2).astype(out_dtype),
axis=[rc, ry, rx, b2, b1])
else:
def _conv(nn, yy, xx, ff):
b1b2 = (b1+b2).astype(out_dtype)
return tvm.sum((tvm.popcount(
PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
axis=[rc, ry, rx, b2, b1])
conv = tvm.compute((batch, out_height, out_width, out_channel), _conv,
name="Conv2dOutput", tag="bitserial_conv2d_nhwc")
return conv
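# Illustrative aside, not part of this commit: minimal usage of the NHWC op with the
# generic schedule, mirroring the unit tests further below.  Shapes, bit widths and
# the plain 'llvm' target are placeholder choices for this sketch.
import tvm
import topi

with tvm.target.create('llvm'):
    A = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='A')
    W = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='W')
    B = topi.nn.bitserial_conv2d_nhwc(A, W, stride=1, padding=1,
                                      activation_bits=2, weight_bits=1,
                                      out_dtype='int16', unipolar=True)
    s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
    func = tvm.build(s, [A, W, B], 'llvm')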
@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct')
def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype='uint32', out_dtype='int16', unipolar=True):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
IB, _, CI, H, W = data_q.shape
KB, CO, _, KH, KW = kernel_q.shape
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
# Check if kernel is already bitpacked
if len(kernel.shape) == 4:
kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
else:
kernel_vec = kernel
OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape)
CO = OCO * VC
IB, N, CI, H, W = get_const_tuple(data_q.shape)
KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, 0, 0, TPAD, LPAD]
pad_after = [0, 0, 0, DPAD, RPAD]
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
......@@ -142,38 +225,50 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NCHW")
sch = _get_schedule(wkl, "NCHW")
VH = sch.vh
VW = sch.vw
VC = sch.vc
TH = H + 2*HPAD
TW = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
TH = H + TPAD + DPAD
TW = W + LPAD + RPAD
OH = (H + TPAD + DPAD - KH) // HSTR + 1
OW = (W + LPAD + RPAD - KW) // WSTR + 1
# ==================== define configuration space ====================
n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
re_axes = cfg.define_reorder("reorder_0",
[n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
policy='interval_all', interval=(6, 11))
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# ====================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
dshape = (IB, 1, CI, H, W)
dpshape = (IB, 1, CI, TH, TW)
dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
kshape = (KB, CO, CI, KH, KW)
kvshape = (CO//VC, CI, KH, KW, KB, VC)
ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
oshape = (1, CO, OH, OW)
DOPAD = (HPAD != 0 and WPAD != 0)
if DOPAD:
data_pad = pad(data_q, (0, 0, 0, HPAD, WPAD), name="data_pad")
if (TPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, 0, 0, TPAD, LPAD), (0, 0, 0, DPAD, RPAD), name="data_pad")
else:
data_pad = data_q
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
if len(kernel.shape) == 4:
kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh')
......@@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
def _conv(n, co, h, w, vh, vw, vc):
b1b2 = (b1+b2).astype(out_dtype)
if dorefa:
if unipolar:
return tvm.sum((tvm.popcount(
data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) -
......@@ -203,15 +298,28 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='conv_vec', tag='spatial_bitserial_conv_nchw')
def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype, out_dtype, dorefa=False):
@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
pack_dtype='uint32', out_dtype='int16', unipolar=True):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
_, H, W, CI, IB = data_q.shape
KH, KW, _, CO, KB = kernel_q.shape
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
pack_kernel = len(kernel.shape) == 4
if pack_kernel:
kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
else:
kernel_q = kernel
KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape)
N, H, W, CI, IB = get_const_tuple(data_q.shape)
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
else:
TPAD, LPAD, DPAD, RPAD = padding
pad_before = [0, TPAD, LPAD, 0, 0]
pad_after = [0, DPAD, RPAD, 0, 0]
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
......@@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
HSTR, WSTR = stride, stride
HCAT, WCAT = KH-1, KW-1
wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
PAD_H = H + (TPAD + DPAD)
PAD_W = W + (LPAD + RPAD)
OH = (PAD_H - KH) // HSTR + 1
OW = (PAD_W - KW) // WSTR + 1
oshape = (1, OH, OW, CO)
PAD_H = H + 2*HPAD
PAD_W = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
# ==================== define configuration space ====================
n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
re_axes = cfg.define_reorder("reorder_0",
[n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
policy='interval_all', interval=(3, 7))
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# ====================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
kvshape = (CO, KH, KW, CI, VC, KB)
ovshape = (1, OH, OW, CO, VH, VW, VC)
oshape = (1, OH, OW, CO)
if (HPAD != 0 and WPAD != 0):
data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad")
if (DPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, 0), name="data_pad")
else:
data_pad = data_q
......@@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
def _conv(n, h, w, co, vh, vw, vc):
b1b2 = (b1+b2).astype(out_dtype)
if dorefa:
if unipolar:
return tvm.sum(
(tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) -
tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2,
((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]&
~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
axis=[dh, dw, ci, b1, b2])
return tvm.sum(tvm.popcount(
......@@ -273,6 +398,7 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
"""Packs data into format necessary for bitserial computation
pack_axis : int
......@@ -334,8 +460,3 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
if bits > 1:
return concatenate(output_tuple, axis=bit_axis)
return output_tuple
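# Illustrative aside, not part of this commit: a numpy sketch of what bitpack does
# conceptually for pack_type='uint8'.  Each bit plane of the quantized input is packed
# eight elements at a time along pack_axis, and the planes are stacked on the new
# bit_axis.  The bit ordering inside each packed byte here (MSB-first) is an assumption
# for illustration only; the real kernel defines its own ordering.
import numpy as np

def bitpack_last_axis(data, bits):
    """Pack the last axis of `data` (values < 2**bits) into uint8 bit planes."""
    n = data.shape[-1]
    assert n % 8 == 0, "sketch assumes the packed axis is a multiple of 8"
    weights = 1 << np.arange(7, -1, -1)                  # assumed MSB-first within a byte
    planes = []
    for b in range(bits):
        plane = (data >> b) & 1                          # one bit per element
        plane = plane.reshape(data.shape[:-1] + (n // 8, 8))
        planes.append((plane * weights).sum(axis=-1).astype(np.uint8))
    return np.stack(planes, axis=-2)                     # new bit axis next to the packed axis

acts = np.random.randint(0, 4, size=(1, 4, 4, 16))       # 2-bit activations, NHWC
packed = bitpack_last_axis(acts, bits=2)
print(packed.shape)                                      # (1, 4, 4, 2, 2): 2 bit planes x 2 uint8 words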
_SCH_TO_DECL_FUNC_QUANT = {
SpatialPackNCHW: spatial_pack_nchw,
SpatialPackNHWC: spatial_pack_nhwc,
}
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Bitserial conv2d schedule on x86"""
import tvm
from tvm import autotvm
from topi.util import get_const_int
from .. import generic, tag
from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload
from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC
from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT
_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(3, 3, 16, 3, 16),
SpatialPackNCHW(1, 1, 16, 2, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
]
_QUANTIZED_SCHEDULES_NHWC = [
# resnet
SpatialPackNHWC(2, 2, 8, 1, 1),
SpatialPackNHWC(1, 4, 8, 4, 1),
SpatialPackNHWC(1, 4, 8, 1, 16),
SpatialPackNHWC(1, 4, 8, 4, 8),
SpatialPackNHWC(1, 7, 8, 3, 8),
SpatialPackNHWC(1, 2, 8, 1, 8),
SpatialPackNHWC(2, 1, 8, 1, 4),
SpatialPackNHWC(1, 7, 8, 1, 1),
SpatialPackNHWC(1, 1, 8, 1, 16),
SpatialPackNHWC(1, 1, 8, 1, 8),
SpatialPackNHWC(1, 1, 8, 1, 16),
]
@_get_schedule.register("cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch
@bitserial_conv2d.register("cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits,
weight_bits, pack_dtype, out_dtype, dorefa)
@generic.schedule_bitserial_conv2d_nchw.register(["cpu"])
@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"])
def schedule_bitserial_conv2d(outs):
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nchw, ['cpu'], 'direct')
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, ['cpu'], 'direct')
def schedule_bitserial_conv2d(cfg, outs):
"""CPU schedule for bitserial convolutions NCHW and NHWC"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
......@@ -88,7 +27,6 @@ def schedule_bitserial_conv2d(outs):
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0]
......@@ -97,29 +35,27 @@ def schedule_bitserial_conv2d(outs):
data_pad = data_q
data_q = data
data = data_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0]
if 'spatial_bitserial_conv_nchw' in op.tag:
_schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
_schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, outs[0])
elif 'spatial_bitserial_conv_nhwc' in op.tag:
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
_schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op)
return s
def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, last):
IB, _, CI, IH, IW = data_q.shape
KB, CO, _, KH, KW = kernel_q.shape
_, _, OH, OW = output.shape
......@@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride)
wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NCHW")
sch = _get_schedule(wkl, "NCHW")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
CC = s.cache_write(conv_out, "global")
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
s[conv_out].vectorize(vc)
s[CC].compute_at(s[conv_out], ow)
n, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw, b1, b2 = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, b1, b2, vc)
s[CC].unroll(b1)
s[CC].unroll(b2)
s[CC].vectorize(vc)
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
##### Schedule A
##### Schedule Data padding and bitpacking
if data_pad is not None:
s[data_pad].compute_inline()
_, h, _, _, _, _, vw = s[data_vec].op.axis
s[data_vec].vectorize(vw)
if ba == 1:
oaxis = h
paxis = h
_, _, h, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
if cfg["tile_ah"].size[1] == 1:
oaxis = oh
paxis = oh
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
......@@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
co, _, _, _, _, vc = s[kernel_vec].op.axis
s[kernel_vec].vectorize(vc)
if bc == 1:
oaxis = co
paxis = co
##### Schedule Kernel bitpacking
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
if cfg["tile_bco"].size[1] == 1:
oaxis = oco
paxis = oco
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
......@@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule C
##### Schedule Convolution
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
ci, dh, dw, ib, kb = s[conv_out].op.reduce_axis
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg["reorder_0"].apply(s, conv_out, [n, co, oh, ow, vc, vh, vw, dh, dw, kb, ib, ci])
cfg["ann_reduce"].apply(s, conv_out, [kb, ib, dh, dw],
axis_lens=[get_const_int(kb.dom.extent),
get_const_int(ib.dom.extent),
get_const_int(dh.dom.extent),
get_const_int(dw.dom.extent)],
max_unroll=16,
cfg=cfg)
s[conv_out].vectorize(vc)
##### Schedule output
n, co, h, w = s[last].op.axis
co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
......@@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)
if bc == 1:
oaxis = co
paxis = co
oco, ico = cfg["tile_oh"].apply(s, last, co)
if cfg["tile_oh"].size[1] == 1:
oaxis = oco
paxis = oco
else:
oco, ico = s[last].split(co, bc)
oaxis = oco
paxis = ico
s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
s[last].parallel(oco)
return s
def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, last):
# no stride and padding info here
_, IH, IW, CI, IB = data_q.shape
KH, KW, _, CO, KB = kernel_q.shape
_, OH, OW, _ = output.shape
# Infer padding and stride
if data_pad is None:
padding = (0, 0)
TH, TW = IH, IW
else:
_, TH, TW, _, _ = data_pad.shape
hpad = get_const_int((TH - IH) // 2)
wpad = get_const_int((TW - IW) // 2)
padding = (hpad, wpad)
hstride = get_const_int((TH - KH) // (OH - 1))
wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride)
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
##### Schedule data packing
##### Schedule data padding and packing
if data_pad is not None:
s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)
##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico
s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)
##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis
s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg["reorder_0"].apply(s, conv_out, [n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2])
cfg["ann_reduce"].apply(s, conv_out, [b1, b2, dh, dw],
axis_lens=[get_const_int(b1.dom.extent),
get_const_int(b2.dom.extent),
get_const_int(dh.dom.extent),
get_const_int(dw.dom.extent)],
max_unroll=16,
cfg=cfg)
s[conv_out].unroll(b1)
s[conv_out].unroll(b2)
......@@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)
if bc == 1:
oaxis = oh
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho
s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
oho, iho = cfg["tile_oh"].apply(s, last, oh) # reuse parameter
s[last].parallel(oho)
return s
......@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype):
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, unipolar):
in_height = in_width = in_size
input_type = 'uint32'
input_dtype = 'uint32'
out_dtype = 'int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape)
......@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
@memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa:
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
if unipolar:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
......@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, unipolar):
in_height = in_width = in_size
input_type='uint32'
input_dtype='uint32'
out_dtype='int32'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
out_dtype=out_dtype, unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
a_shape = get_const_tuple(A.shape)
......@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
@memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
if dorefa:
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
if unipolar:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
......
......@@ -4,6 +4,7 @@ import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
def generate_quantized_np(shape, bits, out_dtype):
np.random.seed(0)
......@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype):
# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
activation_bits, weight_bits, dorefa):
activation_bits, weight_bits, unipolar):
in_height = in_width = in_size
input_type = 'uint32'
out_dtype = 'int32'
out_dtype = 'int16'
with tvm.target.arm_cpu('rasp3b'):
device = 'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon'
with tvm.target.create(device):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
layout="NHWC", dorefa=dorefa)
B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
pack_dtype='uint8', out_dtype='int16', unipolar=unipolar)
s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))
func = tvm.build(s, [A, W, B], device)
assembly = func.get_source('asm')
matches = re.findall("vpadal", assembly)
......@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
matches = re.findall("vpadd", assembly)
assert (len(matches) > 0)
ctx = tvm.context(device, 0)
if 'arm' not in os.uname()[4]:
print ("Skipped running code, not an arm device")
return
print("Running on target: %s" % device)
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
if unipolar:
w_ = np.copy(w_np).astype(out_dtype)
for x in np.nditer(w_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
else:
b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_bitserial_conv2d():
in_size = 56
ic, oc = 64, 64
......@@ -45,6 +74,9 @@ def test_bitserial_conv2d():
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
if __name__ == "__main__":
test_bitserial_conv2d()