Commit b405f68b by Meghan Cowan, committed by Tianqi Chen

[TOPI] Bitserial dense operators for CPU (#3051)

parent d8f547c7
......@@ -87,6 +87,7 @@ class TaskExtractEnv:
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
}
......@@ -101,6 +102,7 @@ class TaskExtractEnv:
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
......@@ -200,18 +202,25 @@ class TaskExtractEnv:
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
A, W = args[:2]
return s, [A, W, C]
@register("topi_nn_bitserial_conv2d_nchw")
def _topi_bitserial_conv2d_nchw(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
A, W = args[:2]
return s, [A, W, C]
@register("topi_nn_bitserial_dense")
def _topi_nn_bitserial_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.bitserial_dense(*args, **kwargs)
s = topi.generic.schedule_bitserial_dense([C])
return s, [A, W, C]
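For reference, a minimal sketch of driving the template registered above (a usage assumption, not part of this commit; it relies on the autotvm.task.create API of this TVM generation and the same ('TENSOR', shape, dtype) argument serialization as the other templates in this file):

from tvm import autotvm

# Hypothetical 1-bit workload: shapes and dtypes are illustrative only.
task = autotvm.task.create(
    'topi_nn_bitserial_dense',
    args=(('TENSOR', (1, 1024), 'uint32'),     # data: [batch, in_dim]
          ('TENSOR', (1000, 1024), 'uint32'),  # weight: [out_dim, in_dim]
          1, 1,                                # data_bits, weight_bits
          'uint32', 'int16', True),            # pack_dtype, out_dtype, unipolar
    target='llvm')
print(task.config_space)                       # inspect the generated search space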
@register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
......
......@@ -4,4 +4,5 @@ from . import conv2d
from . import depthwise_conv2d
from . import conv2d_transpose
from . import bitserial_conv2d
from . import bitserial_dense
from . import injective
......@@ -21,7 +21,8 @@ import tvm
from tvm import autotvm
from .. import tag
from ..nn.pad import pad
from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc
from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc
from ..nn.bitserial_util import bitpack, binary_op_multiplier
from ..nn.util import get_pad_tuple
from ..util import get_const_int, get_const_tuple
from .. import generic
......@@ -93,7 +94,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
policy='candidate', candidate=[
[n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
[n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
# binary ops
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
# ====================
VC = cfg["tile_co"].size[-1]
......@@ -310,7 +312,6 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
s[conv_out].compute_at(s[last], co)
s[last].parallel(oh)
s = s.normalize()
return s
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from topi.util import get_const_tuple
from .. import tag
from .. import generic
from .bitserial_conv2d import _intrin_popcount
from ..nn.pad import pad
from ..nn.bitserial_dense import bitserial_dense
from ..nn.bitserial_util import bitpack, binary_op_multiplier
@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct')
def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
unipolar):
"""The default implementation of bitserial dense in topi.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
batch, DB, in_dim = get_const_tuple(data_packed.shape)
out_dim, WB, in_dim = get_const_tuple(weight_packed.shape)
# Pad weights so that the microkernel can be used:
# out_dim needs to be a multiple of 8
if out_dim % 8 != 0:
out_dim_pad = 8 - (out_dim % 8)
weight_packed = pad(weight_packed, [0, 0, 0], [out_dim_pad, 0, 0], name='PaddedWeight')
out_dim += out_dim_pad
######## Search space
x, y = cfg.axis(batch), cfg.axis(out_dim)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2,
filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2,
filter=lambda xx: xx.size[-1] == 8)
cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
policy='candidate', candidate=[
[yo, xo, ko, xi, wb, db, yi, ki],
[yo, xo, xi, ko, wb, db, yi, ki]])
###### Compute rule
VY = cfg['tile_y'].size[-1]
VK = cfg['tile_k'].size[-1]
wvshape = (out_dim//VY, in_dim//VK, WB, VY, VK)
oshape = (batch, out_dim)
k = tvm.reduce_axis((0, in_dim), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')
# Tile data and weights
weight_vec = tvm.compute(wvshape, lambda yo, ko, wb, vy, vk:
weight_packed[yo*VY+vy][wb][ko*VK+vk], name='weight_vec')
matmul_unipolar = tvm.compute(oshape, lambda x, y: tvm.sum(
(tvm.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype)) -
tvm.popcount(~weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype)))
<< (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda x, y: tvm.sum(
tvm.popcount(weight_vec[y//VY, k//VK, wb, y%VY, k%VK].astype(out_dtype) &
data_packed[x, db, k].astype(out_dtype))
<< (wb+db).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
cfg.add_flop(batch * out_dim * in_dim * binary_op_multiplier(pack_dtype))
if unipolar:
return matmul_unipolar
return matmul
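The splits above deliberately pin the innermost tile widths (tile_y ends in 8, tile_k in 8 or 16) to what the ARM popcount microkernel accepts, which is what lets the schedule below tensorize the reduction. A hypothetical shape trace, assuming uint8 packing, 1-bit data and weights, batch 1, in_dim 1024, out_dim 1000:

# data   (1, 1024)    -> bitpack -> data_packed   (1, DB=1, 128)     # 1024/8 uint8 words
# weight (1000, 1024) -> bitpack -> weight_packed (1000, WB=1, 128)
# with VY = 8, VK = 16: wvshape = (1000//8, 128//16, WB, VY, VK) = (125, 8, 1, 8, 16)
# oshape = (1, 1000); wb is later tensorized over the 8x16 popcount intrinsic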
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for binary_dense.
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for bitserial_dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(cfg, s, data_vec, weight_vec, output, unipolar):
z, k, _, y, x = s[weight_vec].op.axis
s[weight_vec].parallel(z)
s[weight_vec].vectorize(x)
x, y = s[output].op.axis
wb, db, k = s[output].op.reduce_axis
_, DB, _ = get_const_tuple(data_vec.shape)
_, _, WB, _, _ = get_const_tuple(weight_vec.shape)
yo, yi = cfg["tile_y"].apply(s, output, y)
xo, xi = cfg["tile_x"].apply(s, output, x)
ko, ki = cfg["tile_k"].apply(s, output, k)
cfg["reorder_0"].apply(s, output, [yo, xo, ko, xi, wb, db, yi, ki])
fused = s[output].fuse(xo, yo)
s[output].parallel(fused)
nfactor = cfg['tile_y'].size[-1]
kfactor = cfg['tile_k'].size[-1]
if nfactor % 8 == 0:
pc = _intrin_popcount(nfactor, kfactor, WB, DB, unipolar)
s[output].tensorize(wb, pc)
return s
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif op.tag in ('bitserial_dense', 'bitserial_dense_unipolar'):
output = op.output(0)
weight_vec = op.input_tensors[0]
data_vec = op.input_tensors[1]
data = data_vec.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
data = data.op.input_tensors[0]
unipolar = (output.op.tag == 'bitserial_dense_unipolar')
_schedule(cfg, s, data_vec, weight_vec, output, unipolar)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
traverse(outs[0].op)
return s
......@@ -312,6 +312,22 @@ def schedule_bitserial_conv2d_nhwc(outs):
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_dense(outs):
"""Schedule for bitserial_dense
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_dense
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
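As with the other generic schedules in this file, the target in scope selects the concrete implementation at call time. A minimal dispatch sketch, mirroring the test added by this commit:

import tvm
import topi

with tvm.target.create('llvm'):
    A = tvm.placeholder((1, 1024), dtype='uint32', name='A')
    W = tvm.placeholder((1000, 1024), dtype='uint32', name='W')
    C = topi.nn.bitserial_dense(A, W, 1, 1)         # dispatches to the 'cpu' compute
    s = topi.generic.schedule_bitserial_dense([C])  # dispatches to the 'cpu' schedule
func = tvm.build(s, [A, W, C], 'llvm')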
@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
"""Schedule for reduction
......
......@@ -17,5 +17,6 @@ from .bnn import *
from .upsampling import *
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
......@@ -14,16 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Bitserial Conv2D operators"""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from tvm import autotvm
from topi.transform import concatenate
from .pad import pad
from .util import get_pad_tuple
from ..util import get_const_tuple, get_const_int
from .bitserial_util import bitpack, binary_op_multiplier
from ..util import get_const_tuple
@tvm.target.generic_func
def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
......@@ -68,7 +67,7 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight
Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
num_filter, channel, kernel_h, kernel_w, weight_bits = Filter_q.shape
num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape
if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
......@@ -259,10 +258,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
re_axes = cfg.define_reorder("reorder_0",
[n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
policy='interval_all', interval=(6, 11))
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
cfg.define_reorder("reorder_0",
[n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
policy='interval_all', interval=(6, 11))
# binary ops
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
# ====================
VC = cfg["tile_co"].size[-1]
......@@ -275,7 +275,7 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
oshape = (1, CO, OH, OW)
if (TPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, 0, 0, TPAD, LPAD), (0, 0, 0, DPAD, RPAD), name="data_pad")
data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
else:
data_pad = data_q
......@@ -361,10 +361,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
re_axes = cfg.define_reorder("reorder_0",
[n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
policy='interval_all', interval=(3, 7))
cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops
cfg.define_reorder("reorder_0",
[n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
policy='interval_all', interval=(3, 7))
# binary ops
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
# ====================
VC = cfg["tile_co"].size[-1]
......@@ -377,7 +378,7 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
oshape = (1, OH, OW, CO)
if (DPAD != 0 and RPAD != 0):
data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, 0), name="data_pad")
data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
else:
data_pad = data_q
......@@ -413,66 +414,3 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
return tvm.compute(oshape, lambda n, h, w, co:
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
"""Packs data into format necessary for bitserial computation
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data"""
ishape = data.shape
n = len(ishape)
if pack_type == 'uint8':
data_width = 8
elif pack_type == 'uint16':
data_width = 16
elif pack_type == 'uint32':
data_width = 32
elif pack_type == 'uint64':
data_width = 64
# Data must be in multiples of the data_width
assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"
shape_vec = list(ishape)
shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
shape_vec.insert(bit_axis, 1)
bitserial_oshape = tuple(shape_vec)
masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
# pack axis shifts if bit axis comes before
if bit_axis <= pack_axis:
pack_axis += 1
def _bitpack(*indices):
packed_data = [tvm.const(0, pack_type)] * bits
for k in range(data_width):
# Translate indices for packed data back to original
idx = [0] * n
j = 0
for i in range(n+1):
if i == bit_axis:
continue
elif i == pack_axis:
idx[j] = indices[i] * data_width + k
else:
idx[j] = indices[i]
j += 1
element = data(*idx)
for b in range(bits):
extracted_bit = ((element & tvm.const(masks[b], "int32")) >> b).astype(pack_type)
packed_data[b] = (packed_data[b] | extracted_bit)
if k < data_width - 1:
packed_data[b] = packed_data[b] << 1
if k == data_width - 1:
return tuple(packed_data)
return tuple(packed_data)
output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')
if bits > 1:
return concatenate(output_tuple, axis=bit_axis)
return output_tuple
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Bitserial Dense operator."""
from __future__ import absolute_import
import tvm
from tvm import autotvm
from topi.util import get_const_tuple
from .bitserial_util import bitpack, binary_op_multiplier
@tvm.target.generic_func
def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
out_dtype='int16', unipolar=True):
"""The default implementation of bitserial dense in topi.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim] or
3-D with shape [out_dim, weight_bits, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
Y, DB, K = get_const_tuple(data_packed.shape)
X, WB, _ = get_const_tuple(weight_packed.shape)
oshape = (Y, X)
k = tvm.reduce_axis((0, K), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')
matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
(tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]) -
tvm.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k])).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]),
tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
if unipolar:
return matmul_unipolar
return matmul
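The unipolar variant treats weight bits as {-1, +1} values: per packed word, the signed dot product is popcount(w & d) - popcount(~w & d) over the word's lanes. A small self-contained numpy check of this identity for the single-bit, single-word case (illustrative only):

import numpy as np

w_bits = np.array([1, 0, 1, 1, 0, 0, 1, 0], dtype=np.uint8)  # weights: 1 -> +1, 0 -> -1
d_bits = np.array([1, 1, 0, 1, 0, 1, 1, 1], dtype=np.uint8)  # binary activations
w = int(np.packbits(w_bits)[0])                              # one packed uint8 word
d = int(np.packbits(d_bits)[0])
popcount = lambda v: bin(v).count('1')
bitserial = popcount(w & d) - popcount(~w & 0xFF & d)        # mask keeps the 8 valid lanes
reference = int(np.dot(np.where(w_bits == 1, 1, -1), d_bits))
assert bitserial == reference                                # both give the signed dot product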
@autotvm.register_topi_compute(bitserial_dense, ['cpu'], 'direct')
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
out_dtype='int16', unipolar=True):
"""Bitserial dense implementation. TODO: Why are these separate
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim] or
3-D with shape [out_dim, weight_bits, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
Y, DB, K = get_const_tuple(data_packed.shape)
X, WB, _ = get_const_tuple(weight_packed.shape)
######## Search space
x, y = cfg.axis(X), cfg.axis(Y)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
policy='candidate', candidate=[
[yo, xo, ko, yi, wb, db, ki, xi],
[yo, xo, yi, ko, wb, db, ki, xi]])
cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')
###### Compute rule
VX = cfg['tile_x'].size[-1]
wvshape = (X//VX, WB, VX, K)
oshape = (Y, X)
k = tvm.reduce_axis((0, K), name='k')
db = tvm.reduce_axis((0, DB), name='db')
wb = tvm.reduce_axis((0, WB), name='wb')
# Tile data and weights
weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
weight_packed[xo*VX+vx][wb][k], name='weight_vec')
matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
(tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) -
tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
# binary ops
cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
if unipolar:
return matmul_unipolar
return matmul
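A hypothetical shape trace through the compute rule above, assuming batch 1, in_dim 1024, out_dim 1000, 1-bit data and weights, uint32 packing, and a config that picks VX = 8:

# data   (1, 1024)    -> bitpack -> data_packed   (Y=1,    DB=1, K=32)   # 1024/32 words
# weight (1000, 1024) -> bitpack -> weight_packed (X=1000, WB=1, K=32)
# weight_vec: wvshape = (X//VX, WB, VX, K) = (125, 1, 8, 32)
# output:     oshape  = (1, 1000), reduced over wb, db, k with xi vectorized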
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Utility functions for bitserial operators"""
import numpy as np
import tvm
from topi.transform import concatenate
from ..util import get_const_int
def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
"""Packs data into format necessary for bitserial computation
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data"""
ishape = data.shape
n = len(ishape)
if pack_type == 'uint8':
data_width = 8
elif pack_type == 'uint16':
data_width = 16
elif pack_type == 'uint32':
data_width = 32
elif pack_type == 'uint64':
data_width = 64
else:
raise ValueError("Unsupported pack_type: %s" % pack_type)
# Data must be in multiples of the data_width
assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size"
shape_vec = list(ishape)
shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width)
shape_vec.insert(bit_axis, 1)
bitserial_oshape = tuple(shape_vec)
masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80])
# pack axis shifts if bit axis comes before
if bit_axis <= pack_axis:
pack_axis += 1
def _bitpack(*indices):
packed_data = [tvm.const(0, pack_type)] * bits
for k in range(data_width):
# Translate indices for packed data back to original
idx = [0] * n
j = 0
for i in range(n+1):
if i == bit_axis:
continue
elif i == pack_axis:
idx[j] = indices[i] * data_width + k
else:
idx[j] = indices[i]
j += 1
element = data(*idx)
for b in range(bits):
extracted_bit = ((element & tvm.const(masks[b], "int32")) >> b).astype(pack_type)
packed_data[b] = (packed_data[b] | extracted_bit)
if k < data_width - 1:
packed_data[b] = packed_data[b] << 1
if k == data_width - 1:
return tuple(packed_data)
return tuple(packed_data)
output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack')
if bits > 1:
return concatenate(output_tuple, axis=bit_axis)
return output_tuple
def binary_op_multiplier(pack_dtype):
""""Returns number of bits packed into
pack_dtype: string
pack type for the operator (must be a uint)"""
return int(pack_dtype[4:])
\ No newline at end of file
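A short sketch of what bitpack produces (assumed to run in a TVM environment matching this commit):

import tvm
from topi.nn.bitserial_util import bitpack, binary_op_multiplier
from topi.util import get_const_tuple

A = tvm.placeholder((4, 16), dtype='uint8', name='A')
P = bitpack(A, bits=2, pack_axis=1, bit_axis=1, pack_type='uint8')
print(get_const_tuple(P.shape))        # (4, 2, 2): two bit planes, 16 bits -> 2 uint8 words
print(binary_op_multiplier('uint32'))  # 32: bits per packed word, used for flop accounting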
......@@ -9,6 +9,7 @@ from .nn import *
from .injective import *
from .pooling import schedule_pool, schedule_global_pool
from .bitserial_conv2d import schedule_bitserial_conv2d
from .bitserial_dense import schedule_bitserial_dense
from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
from .dense import _schedule_dense, _schedule_dense_pack, _schedule_dense_nopack
from .batch_matmul import schedule_batch_matmul
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from topi.util import get_const_int
from .. import tag
from .. import generic
@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['cpu'], 'direct')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for bitserial_dense.
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for bitserial_dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(cfg, s, data_vec, weight_vec, output):
s[data_vec].parallel(s[data_vec].op.axis[0])
s[weight_vec].parallel(s[weight_vec].op.axis[0])
y, x = s[output].op.axis
wb, db, k = s[output].op.reduce_axis
yo, yi = cfg["tile_y"].apply(s, output, y)
xo, xi = cfg["tile_x"].apply(s, output, x)
ko, ki = cfg["tile_k"].apply(s, output, k)
cfg["reorder_0"].apply(s, output, [yo, xo, ko, yi, wb, db, ki, xi])
cfg["ann_reduce"].apply(s, output, [db, wb],
axis_lens=[get_const_int(db.dom.extent),
get_const_int(wb.dom.extent)],
max_unroll=8,
cfg=cfg)
cfg["ann_spatial"].apply(s, output, [yi, xi],
axis_lens=[cfg['tile_y'].size[-1],
cfg['tile_x'].size[-1]],
max_unroll=8,
cfg=cfg)
s[output].vectorize(xi)
s[output].parallel(yo)
return s
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif op.tag in ('bitserial_dense', 'bitserial_dense_unipolar'):
output = op.output(0)
weight_vec = op.input_tensors[0]
data_vec = op.input_tensors[1]
data = data_vec.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
data = data.op.input_tensors[0]
_schedule(cfg, s, data_vec, weight_vec, output)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
traverse(outs[0].op)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for bitserial_dense operator"""
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
def generate_quantized_np(shape, bits, out_dtype):
min_val = 0
max_val = 1 << bits
return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits, unipolar):
input_dtype = 'uint32'
out_dtype = 'int16'
with tvm.target.create('llvm'):
A = tvm.placeholder((batch, in_dim), dtype=input_dtype, name='A')
B = tvm.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
C = topi.nn.bitserial_dense(A, B, activation_bits, weight_bits, out_dtype=out_dtype,
unipolar=unipolar)
s = topi.generic.schedule_bitserial_dense([C])
a_shape = get_const_tuple(A.shape)
b_shape = get_const_tuple(B.shape)
@memoize("topi.tests.test_topi_bitseral_dense")
def get_ref_data():
a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
b_np = generate_quantized_np(get_const_tuple(b_shape), weight_bits, input_dtype)
if unipolar:
b_ = np.copy(b_np).astype(out_dtype)
for x in np.nditer(b_, op_flags=['readwrite']):
x[...] = 1 if x == 1 else -1
c_np = np.dot(a_np, b_.T)
else:
c_np = np.dot(a_np, b_np.T)
return a_np, b_np, c_np
a_np, b_np, c_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func = tvm.build(s, [A, B, C], "llvm")
func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
def test_bitserial_dense():
verify_bitserial_dense(1, 1024, 1000, 1, 1, True)
verify_bitserial_dense(1, 1024, 1000, 2, 1, True)
verify_bitserial_dense(1, 1024, 1000, 1, 1, False)
verify_bitserial_dense(1, 1024, 1000, 2, 1, False)
if __name__ == "__main__":
test_bitserial_dense()