Commit b405f68b by Meghan Cowan, committed by Tianqi Chen

[TOPI] Bitserial dense operators for CPU (#3051)

parent d8f547c7
...@@ -87,6 +87,7 @@ class TaskExtractEnv:
             topi.nn.dense: "topi_nn_dense",
             topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
+            topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
             topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
         }
...@@ -101,6 +102,7 @@ class TaskExtractEnv:
             topi.nn.dense: [topi.generic.schedule_dense],
             topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
+            topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
             topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
         }
...@@ -200,18 +202,25 @@ class TaskExtractEnv:
             args = deserialize_args(args)
             C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
             s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
-            data = args[0]
-            kernel = args[1]
-            return s, [data, kernel, C]
+            A, W = args[:2]
+            return s, [A, W, C]

         @register("topi_nn_bitserial_conv2d_nchw")
         def _topi_bitserial_conv2d_nchw(*args, **kwargs):
             args = deserialize_args(args)
             C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
             s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
-            data = args[0]
-            kernel = args[1]
-            return s, [data, kernel, C]
+            A, W = args[:2]
+            return s, [A, W, C]
@register("topi_nn_bitserial_dense")
def _topi_nn_bitserial_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.bitserial_dense(*args, **kwargs)
s = topi.generic.schedule_bitserial_dense([C])
return s, [A, W, C]
@register("topi_nn_deformable_conv2d_nchw") @register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs): def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
......
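For context on how these registrations are consumed: once the template above is registered, a standalone tuning task can be created from serialized placeholder arguments. A minimal sketch, assuming a TVM build of this vintage where `serialize_args` and `TaskExtractEnv` live in `tvm.autotvm.task.topi_integration`; the shapes and bit widths below are arbitrary examples, not part of the commit:

import tvm
import topi
from tvm import autotvm
from tvm.autotvm.task.topi_integration import TaskExtractEnv, serialize_args

TaskExtractEnv.get()  # registers the TOPI templates above, including topi_nn_bitserial_dense

# Hypothetical workload: batch 1, 1024 inputs, 1000 outputs, 1-bit data and weights.
A = tvm.placeholder((1, 1024), dtype='uint32', name='A')
W = tvm.placeholder((1000, 1024), dtype='uint32', name='W')
args = serialize_args((A, W, 1, 1, 'uint32', 'int16', True))

# Instantiates the search space defined by the compute/schedule pair.
task = autotvm.task.create("topi_nn_bitserial_dense", args=args, target='llvm')
print(task.config_space)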
...@@ -4,4 +4,5 @@ from . import conv2d
 from . import depthwise_conv2d
 from . import conv2d_transpose
 from . import bitserial_conv2d
+from . import bitserial_dense
 from . import injective
...@@ -21,7 +21,8 @@ import tvm
 from tvm import autotvm
 from .. import tag
 from ..nn.pad import pad
-from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc
+from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc
+from ..nn.bitserial_util import bitpack, binary_op_multiplier
 from ..nn.util import get_pad_tuple
 from ..util import get_const_int, get_const_tuple
 from .. import generic
...@@ -93,7 +94,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
                        policy='candidate', candidate=[
                            [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
                            [n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
-    cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
+    # binary ops
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
     # ====================

     VC = cfg["tile_co"].size[-1]
...@@ -310,7 +312,6 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
     s[conv_out].compute_at(s[last], co)
     s[last].parallel(oh)
-    s = s.normalize()
     return s

 @autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from topi.util import get_const_tuple
from .. import tag
from .. import generic
from .bitserial_conv2d import _intrin_popcount
from ..nn.pad import pad
from ..nn.bitserial_dense import bitserial_dense
from ..nn.bitserial_util import bitpack, binary_op_multiplier
@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct')
def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
unipolar):
"""The default implementation of bitserial dense in topi.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
batch, DB, in_dim = get_const_tuple(data_packed.shape)
out_dim, WB, in_dim = get_const_tuple(weight_packed.shape)
# Pad the weight matrix so that the 8-wide microkernel can be used:
# out_dim needs to be a multiple of 8
if out_dim % 8 != 0:
    out_dim_pad = 8 - out_dim % 8
    weight_packed = pad(weight_packed, [0, 0, 0], [out_dim_pad, 0, 0], name='PaddedWeight')
    out_dim += out_dim_pad
######## Search space
x, y = cfg.axis(batch), cfg.axis(out_dim)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2,
filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2,
filter=lambda xx: xx.size[-1] == 8)
cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
                   policy='candidate', candidate=[
                       [yo, xo, ko, xi, wb, db, yi, ki],
                       [yo, xo, xi, ko, wb, db, yi, ki]])
###### Compute rule
VY = cfg['tile_y'].size[-1]
VK = cfg['tile_k'].size[-1]
wvshape = (out_dim//VY, in_dim//VK, WB, VY, VK)
oshape = (batch, out_dim)
k = tvm.reduce_axis((0, in_dim), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')
# Tile data and weights
weight_vec = tvm.compute(wvshape, lambda yo, ko, wb, vy, vk:
weight_packed[yo*VY+vy][wb][ko*VK+vk], name='weight_vec')
matmul_unipolar = tvm.compute(oshape, lambda x, y: tvm.sum(
(tvm.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype)) -
tvm.popcount(~weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype)))
<< (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda x, y: tvm.sum(
tvm.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype))
<< (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
cfg.add_flop(batch * out_dim * in_dim * binary_op_multiplier(pack_dtype))
if unipolar:
return matmul_unipolar
return matmul
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for binary_dense.
Parameters
----------
outs: Array of Tensor
The computation graph description of the bitserial dense operator
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for bitserial_dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(cfg, s, data_vec, weight_vec, output, unipolar):
z, k, _, y, x = s[weight_vec].op.axis
s[weight_vec].parallel(z)
s[weight_vec].vectorize(x)
x, y = s[output].op.axis
wb, db, k = s[output].op.reduce_axis
_, DB, _ = get_const_tuple(data_vec.shape)
_, _, WB, _, _ = get_const_tuple(weight_vec.shape)
yo, yi = cfg["tile_y"].apply(s, output, y)
xo, xi = cfg["tile_x"].apply(s, output, x)
ko, ki = cfg["tile_k"].apply(s, output, k)
cfg["reorder_0"].apply(s, output, [yo, xo, ko, xi, wb, db, yi, ki])
fused = s[output].fuse(xo, yo)
s[output].parallel(fused)
nfactor = cfg['tile_y'].size[-1]
kfactor = cfg['tile_k'].size[-1]
if nfactor % 8 == 0:
pc = _intrin_popcount(nfactor, kfactor, WB, DB, unipolar)
s[output].tensorize(wb, pc)
return s
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif op.tag in ('bitserial_dense', 'bitserial_dense_unipolar'):
output = op.output(0)
weight_vec = op.input_tensors[0]
data_vec = op.input_tensors[1]
data = data_vec.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
data = data.op.input_tensors[0]
unipolar = (output.op.tag == 'bitserial_dense_unipolar')
_schedule(cfg, s, data_vec, weight_vec, output, unipolar)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
traverse(outs[0].op)
return s
...@@ -312,6 +312,22 @@ def schedule_bitserial_conv2d_nhwc(outs):
     return _default_schedule(outs, False)

+@tvm.target.generic_func
+def schedule_bitserial_dense(outs):
+    """Schedule for bitserial_dense
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of bitserial_dense
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
 @tvm.target.override_native_generic_func("schedule_reduce")
 def schedule_reduce(outs):
     """Schedule for reduction
...
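The `tvm.target.generic_func` fallback above is what target-specific schedules override; the CPU and ARM schedules later in this diff attach themselves through `autotvm.register_topi_schedule`, which builds on the same dispatch mechanism. For reference, a bare-bones manual override might look like this (a sketch only; the 'opencl' key and the naive schedule are placeholders, not part of the commit):

import tvm
import topi

@topi.generic.schedule_bitserial_dense.register(["opencl"])
def _schedule_bitserial_dense_opencl(outs):
    # Naive fallback: build a default schedule over all output ops.
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    return tvm.create_schedule([x.op for x in outs])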
...@@ -17,5 +17,6 @@ from .bnn import *
 from .upsampling import *
 from .local_response_norm import *
 from .bitserial_conv2d import *
+from .bitserial_dense import *
 from .l2_normalize import *
 from .batch_matmul import *
...@@ -14,16 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments
 """Bitserial Conv2D operators"""
 from __future__ import absolute_import as _abs
-import numpy as np
 import tvm
 from tvm import autotvm
-from topi.transform import concatenate
 from .pad import pad
 from .util import get_pad_tuple
-from ..util import get_const_tuple, get_const_int
+from .bitserial_util import bitpack, binary_op_multiplier
+from ..util import get_const_tuple
 @tvm.target.generic_func
 def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
...@@ -68,7 +67,7 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight
     Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
     Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
     batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
-    num_filter, channel, kernel_h, kernel_w, weight_bits = Filter_q.shape
+    num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape

     if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
         TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
...@@ -259,10 +258,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
                              filter=lambda x: max(x.size[1:]) <= 16)
     cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
-    re_axes = cfg.define_reorder("reorder_0",
-                                 [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
-                                 policy='interval_all', interval=(6, 11))
-    cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
+    cfg.define_reorder("reorder_0",
+                       [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
+                       policy='interval_all', interval=(6, 11))
+    # binary ops
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
     # ====================

     VC = cfg["tile_co"].size[-1]
...@@ -275,7 +275,7 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
     oshape = (1, CO, OH, OW)

     if (TPAD != 0 and RPAD != 0):
-        data_pad = pad(data_q, (0, 0, 0, TPAD, LPAD), (0, 0, 0, DPAD, RPAD), name="data_pad")
+        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data_q
...@@ -361,10 +361,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
     ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
                               filter=lambda x: max(x.size[1:]) <= 16)
     cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
-    re_axes = cfg.define_reorder("reorder_0",
-                                 [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
-                                 policy='interval_all', interval=(3, 7))
-    cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
+    cfg.define_reorder("reorder_0",
+                       [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
+                       policy='interval_all', interval=(3, 7))
+    # binary ops
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
     # ====================

     VC = cfg["tile_co"].size[-1]
...@@ -377,7 +378,7 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
     oshape = (1, OH, OW, CO)

     if (DPAD != 0 and RPAD != 0):
-        data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, 0), name="data_pad")
+        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data_q
...@@ -413,66 +414,3 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
     return tvm.compute(oshape, lambda n, h, w, co:
                        conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
                        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
-def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
-    """Packs data into format necessary for bitserial computation
-    pack_axis : int
-       index of the axis to pack in data
-    bit_axis : int
-       index of axis to place bit axis in resulting packed data"""
-    ishape = data.shape
-    n = len(ishape)
-    if pack_type == 'uint8':
-        data_width = 8
-    elif pack_type == 'uint16':
-        data_width = 16
-    elif pack_type == 'uint32':
-        data_width = 32
-    elif pack_type == 'uint64':
-        data_width = 64
-    # Data must be in multiples of the data_width
-    assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"
-    shape_vec = list(ishape)
-    shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
-    shape_vec.insert(bit_axis, 1)
-    bitserial_oshape = tuple(shape_vec)
-    masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
-    # pack axis shifts if bit axis comes before
-    if bit_axis <= pack_axis:
-        pack_axis += 1
-    def _bitpack(*indices):
-        packed_data = [tvm.const(0, pack_type)] * bits
-        for k in range(data_width):
-            # Translate indices for packed data back to original
-            idx = [0] * n
-            j = 0
-            for i in range(n+1):
-                if i == bit_axis:
-                    continue
-                elif i == pack_axis:
-                    idx[j] = indices[i] * data_width + k
-                else:
-                    idx[j] = indices[i]
-                j += 1
-            element = data(*idx)
-            for b in range(bits):
-                extracted_bit = ((element & tvm.const(masks[b], "int32")) >> b).astype(pack_type)
-                packed_data[b] = (packed_data[b] | extracted_bit)
-                if k < data_width - 1:
-                    packed_data[b] = packed_data[b] << 1
-            if k == data_width - 1:
-                return tuple(packed_data)
-        return tuple(packed_data)
-    output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')
-    if bits > 1:
-        return concatenate(output_tuple, axis=bit_axis)
-    return output_tuple
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Bitserial Dense operator."""
from __future__ import absolute_import
import tvm
from tvm import autotvm
from topi.util import get_const_tuple
from .bitserial_util import bitpack, binary_op_multiplier
@tvm.target.generic_func
def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
out_dtype='int16', unipolar=True):
"""The default implementation of bitserial dense in topi.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim] or
3-D with shape [out_dim, weight_bits, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
Y, DB, K = get_const_tuple(data_packed.shape)
X, WB, _ = get_const_tuple(weight_packed.shape)
oshape = (Y, X)
k = tvm.reduce_axis((0, K), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')
matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
(tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]) -
tvm.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k])).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]),
tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
if unipolar:
return matmul_unipolar
return matmul
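To see why this popcount expression is a dot product: bit plane db of the data and plane wb of the weights contribute popcount(w & a) positions where both bits are set, weighted by 2^(db+wb); in unipolar mode a weight bit of 0 stands for -1, which is exactly what the popcount(w & a) - popcount(~w & a) pair computes. A NumPy cross-check of this definition (an illustrative sketch, not part of the commit):

import numpy as np

def bitserial_dense_ref(a, w, a_bits, w_bits, unipolar=True):
    # a: (batch, in_dim), w: (out_dim, in_dim), both holding small unsigned ints
    out = np.zeros((a.shape[0], w.shape[0]), dtype=np.int64)
    for db in range(a_bits):
        a_plane = (a >> db) & 1                    # bit plane of the activations
        for wb in range(w_bits):
            w_plane = (w >> wb) & 1                # bit plane of the weights
            if unipolar:                           # weight bit 1 -> +1, 0 -> -1
                plane = a_plane @ (2 * w_plane - 1).T
            else:                                  # plain unsigned dot product
                plane = a_plane @ w_plane.T
            out += plane.astype(np.int64) << (db + wb)
    return out

# 1-bit unipolar case, matching the reference used by the test at the end of this diff:
a = np.random.randint(0, 2, (1, 64)).astype(np.int64)
w = np.random.randint(0, 2, (10, 64)).astype(np.int64)
assert np.array_equal(bitserial_dense_ref(a, w, 1, 1, True), a @ (2 * w - 1).T)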
@autotvm.register_topi_compute(bitserial_dense, ['cpu'], 'direct')
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
out_dtype='int16', unipolar=True):
"""Bitserial dense implementation. TODO: Why are these separate
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim] or
3-D with shape [out_dim, weight_bits, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
Y, DB, K = get_const_tuple(data_packed.shape)
X, WB, _ = get_const_tuple(weight_packed.shape)
######## Search space
x, y = cfg.axis(X), cfg.axis(Y)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
policy='candidate', candidate=[
[yo, xo, ko, yi, wb, db, ki, xi],
[yo, xo, yi, ko, wb, db, ki, xi]])
cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')
###### Compute rule
VX = cfg['tile_x'].size[-1]
wvshape = (X//VX, WB, VX, K)
oshape = (Y, X)
k = tvm.reduce_axis((0, K), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')
# Tile data and weights
weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
weight_packed[xo*VX+vx][wb][k], name='weight_vec')
matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
(tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) -
tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
# binary ops
cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
if unipolar:
return matmul_unipolar
return matmul
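The weight_vec stage above is a blocked relayout of weight_packed so the innermost x-tile is contiguous in memory. The equivalent NumPy reshape/transpose, with made-up sizes (a sketch, not part of the commit):

import numpy as np

X, WB, K, VX = 16, 2, 64, 8      # out_dim, weight bits, packed in_dim, x-tile width
w_packed = np.arange(X * WB * K).reshape(X, WB, K)

# (X, WB, K) -> (X//VX, WB, VX, K), matching wvshape above
w_vec = w_packed.reshape(X // VX, VX, WB, K).transpose(0, 2, 1, 3)

# Mirrors weight_vec[xo, wb, vx, k] = weight_packed[xo*VX + vx][wb][k]
xo, wb, vx, k = 1, 1, 3, 5
assert w_vec[xo, wb, vx, k] == w_packed[xo * VX + vx, wb, k]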
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Utility functions for bitserial operators"""
import numpy as np
import tvm
from topi.transform import concatenate
from ..util import get_const_int
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
"""Packs data into format necessary for bitserial computation
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data"""
ishape = data.shape
n = len(ishape)
if pack_type == 'uint8':
    data_width = 8
elif pack_type == 'uint16':
    data_width = 16
elif pack_type == 'uint32':
    data_width = 32
elif pack_type == 'uint64':
    data_width = 64
else:
    raise ValueError("Unsupported pack_type: %s" % pack_type)
# Data must be in multiples of the data_width
assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"
shape_vec = list(ishape)
shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
shape_vec.insert(bit_axis, 1)
bitserial_oshape = tuple(shape_vec)
masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
# pack axis shifts if bit axis comes before
if bit_axis <= pack_axis:
pack_axis += 1
def _bitpack(*indices):
packed_data = [tvm.const(0, pack_type)] * bits
for k in range(data_width):
# Translate indices for packed data back to original
idx = [0] * n
j = 0
for i in range(n+1):
if i == bit_axis:
continue
elif i == pack_axis:
idx[j] = indices[i] * data_width + k
else:
idx[j] = indices[i]
j += 1
element = data(*idx)
for b in range(bits):
extracted_bit = ((element & tvm.const(masks[b], "int32")) >> b).astype(pack_type)
packed_data[b] = (packed_data[b] | extracted_bit)
if k < data_width - 1:
packed_data[b] = packed_data[b] << 1
if k == data_width - 1:
return tuple(packed_data)
return tuple(packed_data)
output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')
if bits > 1:
return concatenate(output_tuple, axis=bit_axis)
return output_tuple
def binary_op_multiplier(pack_dtype):
    """Returns the number of bits packed into one word of pack_dtype.
    pack_dtype: string
        pack type for the operator (must be a uint)"""
    return int(pack_dtype[4:])
\ No newline at end of file
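As a sanity check on the layout bitpack produces in the dense case (pack_axis=1, bit_axis=1): a (batch, in_dim) tensor becomes (batch, bits, in_dim // data_width), one plane per bit, and the left shift in _bitpack puts the first element of each group into the most significant bit. A NumPy sketch of the same packing, assuming uint32 words (illustrative, not part of the commit):

import numpy as np

def bitpack_ref(data, bits, width=32):
    # data: (batch, in_dim) small unsigned ints -> (batch, bits, in_dim // width) uint32
    batch, in_dim = data.shape
    assert in_dim % width == 0, "Not a multiple of word size"
    out = np.zeros((batch, bits, in_dim // width), dtype=np.uint32)
    for b in range(bits):
        plane = ((data >> b) & 1).astype(np.uint32).reshape(batch, -1, width)
        for k in range(width):                     # first element lands in the MSB
            out[:, b, :] = (out[:, b, :] << np.uint32(1)) | plane[:, :, k]
    return out

packed = bitpack_ref(np.random.randint(0, 4, (2, 64)), bits=2)
print(packed.shape)   # (2, 2, 2)

binary_op_multiplier simply parses that word width back out of the dtype string, e.g. int('uint32'[4:]) == 32, so the FLOP accounting above counts one fused binary op per packed bit.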
...@@ -9,6 +9,7 @@ from .nn import *
 from .injective import *
 from .pooling import schedule_pool, schedule_global_pool
 from .bitserial_conv2d import schedule_bitserial_conv2d
+from .bitserial_dense import schedule_bitserial_dense
 from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
 from .dense import _schedule_dense, _schedule_dense_pack, _schedule_dense_nopack
 from .batch_matmul import schedule_batch_matmul
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from topi.util import get_const_int
from .. import tag
from .. import generic
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['cpu'], 'direct')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for bitserial_dense.
Parameters
----------
outs: Array of Tensor
The computation graph description of the bitserial dense operator
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for bitserial_dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(cfg, s, data_vec, weight_vec, output):
s[data_vec].parallel(s[data_vec].op.axis[0])
s[weight_vec].parallel(s[weight_vec].op.axis[0])
y, x = s[output].op.axis
wb, db, k = s[output].op.reduce_axis
yo, yi = cfg["tile_y"].apply(s, output, y)
xo, xi = cfg["tile_x"].apply(s, output, x)
ko, ki = cfg["tile_k"].apply(s, output, k)
cfg["reorder_0"].apply(s, output, [yo, xo, ko, yi, wb, db, ki, xi])
cfg["ann_reduce"].apply(s, output, [db, wb],
axis_lens=[get_const_int(db.dom.extent),
get_const_int(wb.dom.extent)],
max_unroll=8,
cfg=cfg)
cfg["ann_spatial"].apply(s, output, [yi, xi],
axis_lens=[cfg['tile_y'].size[-1],
cfg['tile_x'].size[-1]],
max_unroll=8,
cfg=cfg)
s[output].vectorize(xi)
s[output].parallel(yo)
return s
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif op.tag in ('bitserial_dense', 'bitserial_dense_unipolar'):
output = op.output(0)
weight_vec = op.input_tensors[0]
data_vec = op.input_tensors[1]
data = data_vec.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
data = data.op.input_tensors[0]
_schedule(cfg, s, data_vec, weight_vec, output)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
traverse(outs[0].op)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for bitserial_dense operator"""
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
def generate_quantized_np(shape, bits, out_dtype):
min_val = 0
max_val = 1 << bits
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits, unipolar):
input_dtype = 'uint32'
out_dtype = 'int16'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_dim), dtype=input_dtype, name='A')
B = tvm.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
C = topi.nn.bitserial_dense(A, B, activation_bits, weight_bits, out_dtype=out_dtype,
unipolar=unipolar)
s = topi.generic.schedule_bitserial_dense([C])
a_shape = get_const_tuple(A.shape)
b_shape = get_const_tuple(B.shape)
@memoize("topi.tests.test_topi_bitseral_dense")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
b_np = generate_quantized_np(get_const_tuple(b_shape), weight_bits, input_dtype)
if unipolar:
b_ = np.copy(b_np).astype(out_dtype)
for x in np.nditer(b_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
c_np = np.dot(a_np, b_.T)
else:
c_np = np.dot(a_np, b_np.T)
return a_np, b_np, c_np
a_np, b_np, c_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func = tvm.build(s, [A, B, C], "llvm")
func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
def test_bitserial_dense():
verify_bitserial_dense(1, 1024, 1000, 1, 1, True)
verify_bitserial_dense(1, 1024, 1000, 2, 1, True)
verify_bitserial_dense(1, 1024, 1000, 1, 1, False)
verify_bitserial_dense(1, 1024, 1000, 2, 1, False)
if __name__ == "__main__":
test_bitserial_dense()