wenyuanbo / tic · Commits

Commit 17351875
authored Apr 03, 2019 by Meghan Cowan, committed by Tianqi Chen, Apr 03, 2019

[TOPI] bitserial_conv2d move to autotvm template and updates (#2819)

parent cefe07e2

Showing 7 changed files with 609 additions and 572 deletions (+609 / −572)
Files changed:

- python/tvm/autotvm/task/task.py (+1 / −1)
- python/tvm/autotvm/task/topi_integration.py (+22 / −0)
- topi/python/topi/arm_cpu/bitserial_conv2d.py (+193 / −226)
- topi/python/topi/nn/bitserial_conv2d.py (+259 / −138)
- topi/python/topi/x86/bitserial_conv2d.py (+78 / −183)
- topi/tests/python/test_topi_bitserial_conv2d.py (+18 / −18)
- topi/tests/python/test_topi_bitserial_conv2d_rasp.py (+38 / −6)
python/tvm/autotvm/task/task.py

```diff
@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
         workload = tuple([args_to_workload(a) for a in x])
     elif isinstance(x, (str, int, float, np.int, np.float)):
         workload = x
-    elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
+    elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
         workload = x.value
     elif x is None:
         workload = 0
```
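This one-line fix matters because autotvm identifies a tuning task by serializing the compute arguments into a hashable "workload" tuple; bitserial kernels pass unsigned immediates (bit widths, uint8 packing constants) that arrive as `UIntImm` nodes and previously fell through every branch. A minimal pure-Python sketch of the same dispatch — the `Imm` classes below are stand-ins, not TVM's real `expr` nodes:

```python
# Sketch of the args_to_workload dispatch fixed above; IntImm/UIntImm here
# are hypothetical stand-ins for the tvm.expr node classes.
class IntImm:
    def __init__(self, value):
        self.value = value

class UIntImm:
    def __init__(self, value):
        self.value = value

def args_to_workload(x):
    if isinstance(x, (list, tuple)):
        return tuple(args_to_workload(a) for a in x)
    if isinstance(x, (str, int, float)):
        return x
    if isinstance(x, (IntImm, UIntImm)):   # UIntImm must be listed too
        return x.value
    if x is None:
        return 0
    raise TypeError("cannot serialize %r into a workload" % (x,))

assert args_to_workload((UIntImm(8), 'uint8', 3)) == (8, 'uint8', 3)
```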
python/tvm/autotvm/task/topi_integration.py

```diff
@@ -68,6 +68,8 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
             topi.nn.dense: "topi_nn_dense",
+            topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
+            topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
             topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
         }
@@ -79,6 +81,8 @@ class TaskExtractEnv:
             topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
             topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
             topi.nn.dense: [topi.generic.schedule_dense],
+            topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
+            topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
             topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
         }
@@ -174,6 +178,24 @@ class TaskExtractEnv:
                 return s, [data, weight, bias, C]
             return s, [data, weight, C]

+        @register("topi_nn_bitserial_conv2d_nhwc")
+        def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
+            args = deserialize_args(args)
+            C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
+            s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
+            data = args[0]
+            kernel = args[1]
+            return s, [data, kernel, C]
+
+        @register("topi_nn_bitserial_conv2d_nchw")
+        def _topi_bitserial_conv2d_nchw(*args, **kwargs):
+            args = deserialize_args(args)
+            C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
+            s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
+            data = args[0]
+            kernel = args[1]
+            return s, [data, kernel, C]
+
         @register("topi_nn_deformable_conv2d_nchw")
         def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
```
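`@register(name)` ties each string key in the tables above to a template function that can rebuild the compute and schedule from serialized arguments during tuning. A sketch of that decorator-registry pattern in plain Python (simplified and hypothetical; the real implementation lives in `tvm.autotvm.task`):

```python
# Minimal decorator registry mirroring the @register("topi_nn_...") usage above.
TASK_TEMPLATES = {}

def register(name):
    def _do_register(func):
        TASK_TEMPLATES[name] = func   # look the template up by string key later
        return func
    return _do_register

@register("topi_nn_bitserial_conv2d_nhwc")
def _template(*args, **kwargs):
    # A real template would rebuild the compute graph and schedule from args.
    return ("schedule", list(args))

sched, tensors = TASK_TEMPLATES["topi_nn_bitserial_conv2d_nhwc"](1, 2, 3)
assert tensors == [1, 2, 3]
```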
topi/python/topi/arm_cpu/bitserial_conv2d.py

```diff
 # pylint: disable=invalid-name,unused-variable,invalid-name
-"""Bitserial conv2d schedule on raspberry pi"""
+"""Bitserial conv2d schedule on arm cpu"""
 from __future__ import absolute_import as _abs
-from collections import namedtuple
 import tvm
+from tvm import autotvm
-from .. import tag
 from ..nn.pad import pad
-from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload, bitpack
-from ..nn.bitserial_conv2d import SpatialPackNCHW, _WORKLOADS, spatial_pack_nchw
+from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc
 from ..nn.util import get_pad_tuple
-from ..util import get_const_int
+from ..util import get_const_int, get_const_tuple
 from .. import generic

-RaspSpatialPack = namedtuple('SpatialPack',
-                             ['vh', 'vw', 'vc', 'ba', 'bc', 'split_ci', 'kfactor'])
-
-_QUANTIZED_SCHEDULES_NHWC = [
-    RaspSpatialPack(2, 2, 8, 1, 1, False, 8),
-    RaspSpatialPack(1, 4, 8, 4, 1, False, 8),
-    RaspSpatialPack(1, 4, 8, 1, 16, False, 8),
-    RaspSpatialPack(1, 4, 8, 4, 8, False, 8),
-    RaspSpatialPack(1, 7, 8, 3, 8, False, 16),
-    RaspSpatialPack(1, 2, 8, 1, 8, False, 16),
-    RaspSpatialPack(2, 1, 8, 1, 4, False, 16),
-    RaspSpatialPack(1, 7, 8, 1, 1, True, 16),
-    RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
-    RaspSpatialPack(1, 1, 8, 1, 8, True, 16),
-    RaspSpatialPack(1, 1, 8, 1, 16, True, 16),
-]
-
-_QUANTIZED_SCHEDULES_NCHW = [
-    # resnet
-    SpatialPackNCHW(2, 2, 8, 1, 1),
-    SpatialPackNCHW(1, 4, 8, 4, 1),
-    SpatialPackNCHW(1, 4, 8, 1, 16),
-    SpatialPackNCHW(1, 4, 8, 4, 8),
-    SpatialPackNCHW(1, 7, 8, 3, 8),
-    SpatialPackNCHW(1, 2, 8, 1, 8),
-    SpatialPackNCHW(2, 1, 8, 1, 4),
-    SpatialPackNCHW(1, 7, 8, 1, 1),
-    SpatialPackNCHW(1, 1, 8, 1, 16),
-    SpatialPackNCHW(1, 1, 8, 1, 8),
-    SpatialPackNCHW(1, 1, 8, 1, 16),
-]
-
-@_get_schedule.register("arm_cpu")
-def _get_schedule_bitserial_conv2d(wkl, layout):
-    if wkl not in _WORKLOADS:
-        raise ValueError("no schedule for such workload: {}".format(wkl))
-    idx = _WORKLOADS.index(wkl)
-    if layout == "NCHW":
-        sch = _QUANTIZED_SCHEDULES_NCHW[idx]
-    elif layout == "NHWC":
-        sch = _QUANTIZED_SCHEDULES_NHWC[idx]
-    return sch
-
-@bitserial_conv2d.register("arm_cpu")
-def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
-                                  layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
-    assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
-    if dorefa:
-        assert layout == "NCHW", "Cannot support dorea with NHWC layout yet"
-    wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
-    sch = _get_schedule(wkl, layout)
-    if layout == "NCHW":
-        return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
-                                 pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
-    return _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits,
-                              weight_bits, out_dtype)
-
-def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC):
-    kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
+def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True):
+    if use_bitpack:
+        kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8')
+    else:
+        kernel_q = kernel
     KH, KW, KB, CI, CO = kernel_q.shape
     kvshape = (CO//VC, KH, KW, KB, VC, CI)
     return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \
             kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec')
```
```diff
-def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, out_dtype):
+@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
+def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits,
+                      pack_dtype, out_dtype, unipolar):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
-    sch = _get_schedule(wkl, "NHWC")
-    VH = sch.vh
-    VW = sch.vw
-    VC = sch.vc
+    assert pack_dtype == 'uint8', "only support packing into uint8 bits"
+    assert out_dtype == 'int16', "only support output type of int16"

-    data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
-    kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC)
-    N, H, W, IB, CI = data_q.shape
-    OCO, KH, KW, KB, VC, _ = kernel_vec.shape
+    N, H, W, CI = get_const_tuple(data.shape)
+    if len(kernel.shape) == 4:
+        KH, KW, _, CO = get_const_tuple(kernel.shape)
+        CI_packed = CI // 8
+    else:
+        KH, KW, KB, CI_packed, CO = get_const_tuple(kernel.shape)

-    CO = OCO * VC
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding

     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
@@ -102,75 +46,151 @@ def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits
     else:
         HSTR, WSTR = stride, stride
     HCAT, WCAT = KH-1, KW-1

-    PAD_H = H + 2*HPAD
-    PAD_W = W + 2*WPAD
-    OH = (H + 2*HPAD - KH) // HSTR + 1
-    OW = (W + 2*WPAD - KW) // WSTR + 1
+    PAD_H = H + (TPAD + DPAD)
+    PAD_W = W + (LPAD + RPAD)
+    OH = (PAD_H - KH) // HSTR + 1
+    OW = (PAD_W - KW) // WSTR + 1
+    oshape = (1, OH, OW, CO)

+    # Pad input channels of weights and data when it is not a multiple of 8
+    if CI_packed % 8 != 0:
+        CI_PAD = CI_packed % 8
+        CI_packed += CI_PAD
+    else:
+        CI_PAD = 0
+
+    # ==================== define configuration space ====================
+    n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
+    ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+    ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)
+
+    co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
+                              filter=lambda x: x.size[-1] == 8)
+    oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
+                              filter=lambda x: x.size[-1] >= 2)
+    ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
+                              filter=lambda x: x.size[-1] >= 2)
+    ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
+                                  filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
+    re_axes = cfg.define_reorder("reorder_0",
+                                 [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
+                                 policy='candidate', candidate=[
+                                     [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i],
+                                     [n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],])
+    cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW)  # these are actually binary ops
+    # ====================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
```
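Each `cfg.define_split` knob enumerates factorizations of one axis, and the `filter` lambda prunes candidates — here `tile_co` is pinned to an innermost tile of 8 so the NEON popcount microkernel can be tensorized later. A pure-Python sketch of what a two-way split knob enumerates (assumed, simplified semantics; not the autotvm API):

```python
# Enumerate (outer, inner) factor pairs of an axis extent, the way a
# two-output define_split knob does, with the filters used above.
def two_way_splits(extent, keep=lambda sizes: True):
    return [(extent // inner, inner)
            for inner in range(1, extent + 1)
            if extent % inner == 0 and keep((extent // inner, inner))]

print(two_way_splits(64, keep=lambda s: s[-1] == 8))   # tile_co: [(8, 8)]
print(two_way_splits(16, keep=lambda s: s[-1] >= 2))   # tile_oh-style filter
```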
```diff
+    data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
+    kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
+    if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
+        kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])
+
+    N, H, W, IB, CI = data_q.shape
+    OCO, KH, KW, KB, VC, CI = kernel_vec.shape

     dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI)
     ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC)
-    oshape = (1, OH, OW, CO)

-    if (HPAD != 0 and WPAD != 0):
-        data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad")
+    if (TPAD != 0 and RPAD != 0):
+        data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, CI_PAD), name="data_pad")
+    elif CI_PAD != 0:
+        data_pad = pad(data_q, (0, 0, 0, 0, 0), (0, 0, 0, 0, CI_PAD), name="data_pad")
     else:
         data_pad = data_q

     data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \
         data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec')

     ci = tvm.reduce_axis((0, CI), name='ci')
     dh = tvm.reduce_axis((0, KH), name='dh')
     dw = tvm.reduce_axis((0, KW), name='dw')
     ib = tvm.reduce_axis((0, IB), name='ib')
     kb = tvm.reduce_axis((0, KB), name='kb')

-    def _conv(n, h, w, co, vh, vw, vc):
+    def _bipolar_conv(n, h, w, co, vh, vw, vc):
         return tvm.sum((tvm.popcount(
             kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') &
             data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16'))
                         << (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci])
+    def _unipolar_conv(n, h, w, co, vh, vw, vc):
+        return tvm.sum(
+            ((tvm.popcount(kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
+                           data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')) -
+              tvm.popcount(~kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') &
+                           data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci]).astype('int16'))
+             << (kb + ib).astype('int16')), axis=[dh, dw, kb, ib, ci])
+    if unipolar:
+        conv_vec = tvm.compute(ovshape, _unipolar_conv, name='conv_vec', tag='unipolar')
+    else:
+        conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')

-    conv = tvm.compute(ovshape, _conv, name='conv')
-
-    return tvm.compute(oshape, lambda n, h, w, co:
-                       conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
-                       name='output_vec', tag='spatial_bitserial_conv_nhwc')
+    conv = tvm.compute(oshape, lambda n, h, w, co:
+                       conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
+                       name='conv', tag='spatial_bitserial_conv_nhwc')
+
+    return conv
```
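The `_unipolar_conv` body relies on the identity `dot(a, w) = popcount(a & w) − popcount(a & ~w)` when activations are {0,1} bit planes and weight signs are encoded as bits (+1 → 1, −1 → 0). A small numpy check of that identity on one 8-element word (a sketch, not TVM code):

```python
import numpy as np

def popcount(v):
    return bin(int(v)).count("1")

rng = np.random.default_rng(0)
a_bits = rng.integers(0, 2, 8)                 # activation bit plane in {0,1}
w_sign = rng.choice([-1, 1], 8)                # bipolar weights in {-1,+1}
w_bits = ((w_sign + 1) // 2).astype(np.uint8)  # encode +1 -> 1, -1 -> 0

a_word = int("".join(map(str, a_bits)), 2)     # pack 8 bits into one word
w_word = int("".join(map(str, w_bits)), 2)

direct = int(np.dot(a_bits, w_sign))
via_popcount = popcount(a_word & w_word) - popcount(a_word & ~w_word & 0xFF)
assert direct == via_popcount                  # the two-popcount trick
```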
```diff
-def _intrin_popcount(m, k_i, w_b, x_b):
-    dtype = 'uint8'
-    w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w')
-    x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x')
+def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
+    pack_dtype = 'uint8'
+    w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
+    x = tvm.placeholder((x_b, k_i,), dtype=pack_dtype, name='x')
     k = tvm.reduce_axis((0, k_i), name='k')
     bw = tvm.reduce_axis((0, w_b), name='bw')
     bx = tvm.reduce_axis((0, x_b), name='bx')
-    z = tvm.compute((m,), lambda i: tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') &
-                                                         x[bx, k].astype('uint16'))
-                                            << (bw+bx).astype('uint16'), axis=[bw, bx, k]),
-                    name='z')
+    if unipolar:
+        dtype = 'int16'
+        z = tvm.compute((m,), lambda i:
+                        tvm.sum((tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) -
+                                 tvm.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)))
+                                << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
+    else:
+        dtype = 'uint16'
+        z = tvm.compute((m,), lambda i:
+                        tvm.sum(tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
+                                << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z')
     Wb = tvm.decl_buffer(w.shape, w.dtype,
                          name="W",
                          offset_factor=k_i,
-                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1])
+                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) # stride can be inferred
     Xb = tvm.decl_buffer(x.shape, x.dtype,
                          name="X",
                          offset_factor=k_i,
                          strides=[tvm.var('ldw'), 1])
+    Zb = tvm.decl_buffer(z.shape, z.dtype,
+                         name="Z",
+                         offset_factor=1,
+                         strides=[1])

     def _intrin_func(ins, outs):
         ww, xx = ins
         zz = outs[0]
-        vpadd = "llvm.arm.neon.vpadd.v8u8"
-        vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
         args_1 = tvm.const(1, 'uint32')
         args_2 = tvm.const(2, 'uint32')

+        if unipolar:
+            vpadd = "llvm.arm.neon.vpadd.v8i8"
+            vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16"
+            full_dtype = 'int8x16'
+            half_dtype = 'int8x8'
+            return_dtype = 'int16x8'
+        else:
+            vpadd = "llvm.arm.neon.vpadd.v8u8"
+            vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
+            full_dtype = 'uint8x16'
+            half_dtype = 'uint8x8'
+            return_dtype = 'uint16x8'
+
         def _instr(index):
             irb = tvm.ir_builder.create()
-            if index == 1:
-                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
+            if index == 1: # reduce reset
+                irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
                 return irb.get()
+            # body and reduce update
             cnts8 = [None] * 8
             cnts4 = [None] * 4
             cnts2 = [None] * 2
@@ -178,154 +198,108 @@ def _intrin_popcount(m, k_i, w_b, x_b):
             for bx in range(x_b):
                 if k_i == 16:
                     for i in range(m):
-                        ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
-                        cnts = tvm.popcount(ands)
-                        upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
-                        lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
+                        w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
+                        x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
+                        if unipolar:
+                            cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
+                        else:
+                            cnts = tvm.popcount(w_ & x_)
+                        upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
+                        lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                         cnts8[i] = upper_half + lower_half
                     for i in range(m//2):
-                        cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
+                        cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                         args_1, cnts8[i*2], cnts8[i*2+1])
                     for i in range(m//4):
-                        cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
+                        cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                         args_1, cnts4[i*2], cnts4[i*2+1])
-                    cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
-                    shifted_cnts = cnts << tvm.const(bw+bx, dtype)
-                    out = tvm.call_llvm_intrin('uint16x8', vpadalu,
-                                               args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
+                    cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
+                    shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
+                    out = tvm.call_llvm_intrin(return_dtype, vpadalu,
+                                               args_2, zz.vload(0, return_dtype), shifted_cnts)
                 else: # ki == 8
                     for i in range(m):
-                        ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
-                        cnts8[i] = tvm.popcount(ands)
+                        w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
+                        x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
+                        if unipolar:
+                            cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
+                        else:
+                            cnts8[i] = tvm.popcount(w_ & x_)
                     for i in range(m//2):
-                        cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
+                        cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                         args_1, cnts8[i*2], cnts8[i*2+1])
                     for i in range(m//4):
-                        cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
+                        cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                         args_1, cnts4[i*2], cnts4[i*2+1])
-                    cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
-                    shifted_cnts = cnts << tvm.const(bw+bx, dtype)
-                    out = tvm.call_llvm_intrin('uint16x8', vpadalu,
-                                               args_2, zz.vload(0, 'uint16x8'), shifted_cnts)
+                    cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
+                    shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
+                    out = tvm.call_llvm_intrin(return_dtype, vpadalu,
+                                               args_2, zz.vload(0, return_dtype), shifted_cnts)
                 irb.emit(zz.vstore(0, out))
             return irb.get()
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
     with tvm.build_config(offset_factor=1, partition_const_loop=True):
-        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x: Xb})
+        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x: Xb, z: Zb})
```
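The `cnts8 → cnts4 → cnts2` loops form a log2 pairwise-reduction tree: `vpadd` sums adjacent lanes of two half-vectors, and `vpadalu`/`vpadals` finally widen-accumulates into the 16-bit output register. A numpy sketch of the assumed lane semantics (illustrative only, not the intrinsic API):

```python
import numpy as np

def vpadd(a, b):
    """NEON-style pairwise add: adjacent-lane sums of a, then of b."""
    def pairs(v):
        return v.reshape(-1, 2).sum(axis=1)
    return np.concatenate([pairs(a), pairs(b)])

# Reduce eight 8-lane popcount vectors to two half-vectors, mirroring the
# cnts8 -> cnts4 -> cnts2 tree above; the real kernel then 'vectorcombine's
# the halves and widen-accumulates them into the 16-bit accumulator.
cnts8 = [np.full(8, i + 1, dtype=np.int64) for i in range(8)]
cnts4 = [vpadd(cnts8[2*i], cnts8[2*i + 1]) for i in range(4)]
cnts2 = [vpadd(cnts4[2*i], cnts4[2*i + 1]) for i in range(2)]
assert sum(v.sum() for v in cnts2) == sum(v.sum() for v in cnts8)  # totals preserved
```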
```diff
 # ARM specific schedule that using custom microkernel
-def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
-                                  kernel, kernel_q, kernel_vec,
-                                  conv_out, output, last):
-    # no stride and padding info here
-    _, H, W, IB, CI = data_q.shape
-    KH, KW, KB, _, CO = kernel_q.shape
+def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
+                                  conv_out, output, last, unipolar):
+    _, _, _, _, _, IB, CI = data_vec.shape
+    _, KH, KW, KB, _, _ = kernel_vec.shape
+    KB = get_const_int(KB)
+    IB = get_const_int(IB)

-    if data_pad is None:
-        padding = (0, 0)
-        _, in_h, in_w, _, _ = data_q.shape
-        kern_h, kern_w, _, _ = kernel.shape
-        _, out_h, out_w, _ = output.shape
-        hstride = (in_h - kern_h) // (out_h - 1)
-        wstride = (in_w - kern_w) // (out_w - 1)
-        stride = get_const_int(hstride), get_const_int(wstride)
-    else:
-        _, in_h, in_w, _, _ = data_q.shape
-        _, pad_h, pad_w, _, _ = data_pad.shape
-        hpad = (pad_h - in_h) // 2
-        wpad = (pad_w - in_w) // 2
-        padding = get_const_int(hpad), get_const_int(wpad)
-
-        _, in_h, in_w, _, _ = data_pad.shape
-        kern_h, kern_w, _, _ = kernel.shape
-        _, out_h, out_w, _ = output.shape
-        hstride = (in_h - kern_h) // (out_h - 1)
-        wstride = (in_w - kern_w) // (out_w - 1)
-        stride = get_const_int(hstride), get_const_int(wstride)
-
-    wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NHWC")
-    sch = _get_schedule(wkl, "NHWC")
-    VH = sch.vh
-    VW = sch.vw
-    VC = sch.vc
-    ba = sch.ba
-    bc = sch.bc
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]

-    ##### Schedule data packing
+    ##### Schedule data padding and packing
     if data_pad is not None:
         s[data_pad].compute_inline()

     _, h, _, _, _, _, _ = s[data_vec].op.axis
-    if ba == 1:
-        oaxis = h
-        paxis = h
-    else:
-        oh, ih = s[data_vec].split(h, ba)
-        oaxis = oh
-        paxis = ih
-    s[data_vec].parallel(paxis)
-    s[data_vec].pragma(oaxis, "parallel_launch_point")
-    s[data_vec].pragma(paxis, "parallel_stride_pattern")
-    s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
+    cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
+    oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
+    s[data_vec].parallel(oh)

-    ##### Schedule kernel packing
+    #### Schedule kernel packing
     co, _, _, _, _, _ = s[kernel_vec].op.axis
-    if bc == 1:
-        oaxis = co
-        paxis = co
-    else:
-        oco, ico = s[kernel_vec].split(co, bc)
-        oaxis = oco
-        paxis = ico
-    s[kernel_vec].parallel(paxis)
-    s[kernel_vec].pragma(oaxis, "parallel_launch_point")
-    s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
-    s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
+    cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
+    oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
+    s[kernel_vec].parallel(oco)

     ##### Schedule Convolution
     n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
-    dh, dw, kb, ib, ci = s[conv_out].op.reduce_axis
+    kh, kw, kb, ib, ci = s[conv_out].op.reduce_axis

-    kfactor = sch.kfactor
-    if sch.split_ci:
-        oci, ici = s[conv_out].split(ci, kfactor)
-        s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, oci, kb, ib, vc, ici)
-    else:
-        s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, kb, ib, vc, ci)
+    ci_o, ci_i = cfg['tile_ci'].apply(s, conv_out, ci)
+    re_axes = cfg["reorder_0"].apply(s, conv_out,
+                                     [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i])

-    pc = _intrin_popcount(8, kfactor, KB, IB)
-    s[conv_out].tensorize(kb, pc)
+    # Use microkernel
+    kfactor = cfg['tile_ci'].size[1]
+    if kfactor % 8 == 0:
+        pc = _intrin_popcount(VC, kfactor, KB, IB, unipolar)
+        s[conv_out].tensorize(kb, pc)

     n, h, w, co = s[last].op.axis
-    co, vc = s[last].split(co, VC)
-    oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
-    s[last].reorder(n, oh, ow, co, vc, vh, vw)
-    s[last].vectorize(vw)
+    co, vc = cfg['tile_co'].apply(s, last, co)
+    oh, vh = cfg['tile_oh'].apply(s, last, h)
+    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    s[last].reorder(n, oh, ow, co, vh, vw, vc)
+    s[last].vectorize(vc)
     if last != output:
         s[last].compute_inline()

-    s[conv_out].compute_at(s[last], ow)
-    if co == 1:
-        oaxis = oh
-        paxis = oh
-    else:
-        oho, iho = s[last].split(oh, bc)
-        oaxis = oho
-        paxis = iho
-
-    s[last].parallel(paxis)
+    s[conv_out].compute_at(s[last], co)
+    s[last].parallel(oh)

     s = s.normalize()
     return s
```
```diff
-@generic.schedule_bitserial_conv2d_nhwc.register(["arm_cpu"])
-def schedule_bitserial_conv2d_nhwc(outs):
-    """Raspverry pi schedule for bitserial conv2d"""
+@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
+def schedule_bitserial_conv2d_nhwc(cfg, outs):
+    """Arm cpu schedule for bitserial conv2d"""
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
@@ -344,10 +318,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[0]
-            kernel_q = kernel_vec.op.input_tensors[0]
-            kernel = kernel_q.op.input_tensors[0]
-            if "QuantizeInput" in kernel.op.name:
-                # Need to go up 1 further, from the combine in bitpack
-                kernel = kernel.op.input_tensors[0]
             data_vec = conv_out.op.input_tensors[1]
             data_q = data_vec.op.input_tensors[0]
             data = data_q.op.input_tensors[0]
@@ -355,13 +325,10 @@ def schedule_bitserial_conv2d_nhwc(outs):
             if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag:
                 data_pad = data_q
                 data_q = data
                 data = data_q.op.input_tensors[0]
-            if "QuantizeInput" in data.op.name:
-                # Need to go up 1 further, from the combine in bitpack
-                data = data.op.input_tensors[0]

-            _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
-                                          kernel, kernel_q, kernel_vec,
-                                          conv_out, output, outs[0])
+            unipolar = "unipolar" in conv_out.op.tag
+            _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
+                                          conv_out, output, outs[0], unipolar)
         scheduled_ops.append(op)

     traverse(outs[0].op)
```
topi/python/topi/nn/bitserial_conv2d.py

```diff
 # pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
 """Bitserial Conv2D operators"""
 from __future__ import absolute_import as _abs
-from collections import namedtuple
 import numpy as np
 import tvm
+from tvm import autotvm
 from topi.transform import concatenate
 from .pad import pad
 from .util import get_pad_tuple
 from ..util import get_const_tuple, get_const_int

-# workload description of conv2d
-Workload = namedtuple('Workload',
-                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
-
-SpatialPackNCHW = namedtuple('SpatialPack',
-                             ['vh', 'vw', 'vc', 'ba', 'bc'])
-
-SpatialPackNHWC = namedtuple('SpatialPack',
-                             ['vh', 'vw', 'vc', 'ba', 'bc'])
-
-_WORKLOADS = [
-    # workloads of resnet18 on imagenet
-    # input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride
-    Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-    Workload('uint32', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-    Workload('uint32', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-    Workload('uint32', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-    Workload('uint32', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-    Workload('uint32', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-    Workload('uint32', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-    Workload('uint32', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-    Workload('uint32', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-    Workload('uint32', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-    Workload('uint32', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
-    # workload of alexnet on cifar10
-    Workload('int32', 'int32', 27, 27, 96, 192, 5, 5, 2, 2, 1, 1),
-    Workload('int32', 'int32', 13, 13, 192, 384, 3, 3, 1, 1, 1, 1),
-    Workload('int32', 'int32', 13, 13, 384, 384, 3, 3, 1, 1, 1, 1),
-    Workload('int32', 'int32', 13, 13, 384, 256, 3, 3, 1, 1, 1, 1),
-]
```
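Each `Workload` row pins one layer's shape, and reading the first ResNet-18 entry back shows why the table is consistent: a pad-1, 3×3, stride-1 convolution preserves the 56×56 spatial extent. A small self-contained check mirroring the namedtuple removed above:

```python
from collections import namedtuple

# Same field layout as the Workload namedtuple in the diff above
Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width',
                       'in_filter', 'out_filter', 'hkernel', 'wkernel',
                       'hpad', 'wpad', 'hstride', 'wstride'])

wkl = Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
out_h = (wkl.height + 2 * wkl.hpad - wkl.hkernel) // wkl.hstride + 1
assert out_h == 56   # same-padded 3x3 stride-1 conv keeps the spatial size
```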
```diff
-@tvm.target.generic_func
-def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
-                     layout='NCHW', pack_dtype='uint32', out_dtype='int32', dorefa=True):
+def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
+                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
     """Bitserial Conv2D operator.

     Parameters
     ----------
     input : tvm.Tensor
-        4-D with shape [batch, in_channel, in_height, in_width] or
-        [batch, in_height, in_width, in_channel]
+        4-D with shape [batch, in_channel, in_height, in_width]

     filter : tvm.Tensor
-        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
-        [filter_height, filter_width, in_channel, num_filter]
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]

     stride : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]

-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
-
-    layout : str
-        layout of data
+    padding : int or a list/tuple of two or four ints
+        padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]

     activation_bits: int
         number of bits used for activations/input elements
@@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
     pack_dtype: str
         bit packing type

-    dorefa : bool
-        preform the bitserial dot-product using 2 popcounts (required for DoReFa-Net)
+    unipolar : bool
+        if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format

     Returns
     -------
     output : tvm.Tensor
-        4-D with shape [batch, out_channel, out_height, out_width] or
-        [batch, out_height, out_width, out_channel]
+        4-D with shape [batch, out_channel, out_height, out_width]
     """
-    # search platform specific declaration first
-    # default declaration
-    if layout == 'NCHW':
-        return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
-                                 pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
-    if layout == 'NHWC':
-        return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
-                                 pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
-    raise ValueError("not support this layout {} yet".format(layout))
-
-def _get_workload(data, kernel, stride, padding, out_dtype, layout):
-    """ Get the workload structure. """
-    assert layout in ("NCHW", "NHWC"), \
-        "Only support layouts NCHW and NHWC"
-    if layout == "NCHW":
-        _, CI, IH, IW = [x.value for x in data.shape]
-        CO, _, KH, KW = [x.value for x in kernel.shape]
-    else: # NHWC
-        IH, IW = data.shape[1].value, data.shape[2].value
-        KH, KW, CI, CO = [x for x in get_const_tuple(kernel.shape)]
-
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
-    if isinstance(stride, (tuple, list)):
-        HSTR, WSTR = stride
-    else:
-        HSTR, WSTR = stride, stride
-
-    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    assert isinstance(stride, int) or len(stride) == 2
+    Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
+    Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
+    batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
+    num_filter, channel, kernel_h, kernel_w, weight_bits = Filter_q.shape
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
+    pad_before = [0, 0, 0, TPAD, LPAD]
+    pad_after = [0, 0, 0, DPAD, RPAD]
+
+    PadInput_q = pad(Input_q, pad_before, pad_after, name="pad_temp")
+    # compute the output shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
+    out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
+
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    b1 = tvm.reduce_axis((0, activation_bits), name='b1')
+    b2 = tvm.reduce_axis((0, weight_bits), name='b2')
+
+    if unipolar:
+        def _conv(nn, ff, yy, xx):
+            b1b2 = (b1+b2).astype(out_dtype)
+            return tvm.sum(
+                ((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
+                               Filter_q[ff, rc, ry, rx, b2]) -
+                  tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
+                               ~Filter_q[ff, rc, ry, rx, b2]))
+                 << (b1b2)).astype(out_dtype),
+                axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
+    else:
+        def _conv(nn, ff, yy, xx):
+            b1b2 = (b1+b2).astype(out_dtype)
+            return tvm.sum((tvm.popcount(
+                PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] &
+                Filter_q[ff, rc, ry, rx, b2]) << (b1b2)).astype(out_dtype),
+                           axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
+
+    return tvm.compute((batch, out_channel, out_height, out_width), _conv,
+                       name="Conv2dOutput", tag="bitserial_conv2d_nchw")
```
```diff
-@tvm.target.generic_func
-def _get_schedule(wkl, layout):
-    # pylint: disable=unreachable
-    """ Get the platform specific schedule. """
-    target = tvm.target.current_target()
-    raise RuntimeError("No schedule for current target:{}".format(target))
-    # This return has no use, merely to supress pylint warning
-    return wkl
-
-def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
-                      pack_dtype, out_dtype, dorefa=False):
+def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
+                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+    """Bitserial Conv2D operator.
+
+    Parameters
+    ----------
+    input : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    filter : tvm.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+
+    padding : int or a list/tuple of two or four ints
+        padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
+
+    activation_bits: int
+        number of bits used for activations/input elements
+
+    weight_bits: int
+        number of bits used for weight elements
+
+    out_dtype: str
+        return type of convolution
+
+    pack_dtype: str
+        bit packing type
+
+    unipolar: bool
+        if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    Input_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
+    if len(kernel.shape) == 4:
+        Filter_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
+        kernel_h, kernel_w, _, num_filter, _ = get_const_tuple(Filter_q.shape)
+    else:
+        Filter_q = kernel
+        kernel_h, kernel_w, _, _, num_filter = get_const_tuple(Filter_q.shape)
+    batch, in_height, in_width, in_channel_q, _ = get_const_tuple(Input_q.shape)
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
+    pad_before = [0, TPAD, LPAD, 0, 0]
+    pad_after = [0, DPAD, RPAD, 0, 0]
+
+    # compute the output shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1
+    out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1
+    PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput")
+
+    rc = tvm.reduce_axis((0, in_channel_q), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    b1 = tvm.reduce_axis((0, activation_bits), name='b1')
+    b2 = tvm.reduce_axis((0, weight_bits), name='b2')
+
+    if unipolar:
+        def _conv(nn, yy, xx, ff):
+            b1b2 = (b1+b2).astype(out_dtype)
+            return tvm.sum(
+                ((tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
+                               Filter_q[ry, rx, rc, ff, b2]) -
+                  tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
+                               ~Filter_q[ry, rx, rc, ff, b2]))
+                 << b1b2).astype(out_dtype),
+                axis=[rc, ry, rx, b2, b1])
+    else:
+        def _conv(nn, yy, xx, ff):
+            b1b2 = (b1+b2).astype(out_dtype)
+            return tvm.sum((tvm.popcount(
+                PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] &
+                Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype),
+                           axis=[rc, ry, rx, b2, b1])
+
+    conv = tvm.compute((batch, out_height, out_width, out_channel), _conv,
+                       name="Conv2dOutput", tag="bitserial_conv2d_nhwc")
+
+    return conv
```
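The `<< b1b2` shift is the bitserial expansion: with a = Σ_b1 a_b1·2^b1 and w = Σ_b2 w_b2·2^b2, the dot product decomposes as Σ_b1 Σ_b2 2^(b1+b2)·dot(a_b1, w_b2), where each inner dot of bit planes is exactly a popcount of an AND once packed. A numpy check of the scalar identity for 2-bit operands (a sketch, not TVM code):

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.integers(0, 4, 16)   # 2-bit activations
w = rng.integers(0, 4, 16)   # 2-bit (unsigned) weights

def bit_plane(v, b):
    return (v >> b) & 1

direct = int(np.dot(a, w))
# dot of binary bit planes == popcount(packed_a & packed_w) once packed
bitserial = sum(int(np.dot(bit_plane(a, b1), bit_plane(w, b2))) << (b1 + b2)
                for b1 in range(2) for b2 in range(2))
assert direct == bitserial
```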
```diff
+@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct')
+def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
+                      pack_dtype='uint32', out_dtype='int16', unipolar=True):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
     data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
-    kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
-    IB, _, CI, H, W = data_q.shape
-    KB, CO, _, KH, KW = kernel_q.shape
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    # Check if kernel is already bitpacked
+    if len(kernel.shape) == 4:
+        kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
+        KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
+    else:
+        kernel_vec = kernel
+        OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape)
+        CO = OCO * VC
+
+    IB, N, CI, H, W = get_const_tuple(data_q.shape)
+    KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
+    pad_before = [0, 0, 0, TPAD, LPAD]
+    pad_after = [0, 0, 0, DPAD, RPAD]

     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
@@ -142,38 +225,50 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
     else:
         HSTR, WSTR = stride, stride
     HCAT, WCAT = KH-1, KW-1

-    wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NCHW")
-    sch = _get_schedule(wkl, "NCHW")
-    VH = sch.vh
-    VW = sch.vw
-    VC = sch.vc
-
-    TH = H + 2*HPAD
-    TW = W + 2*WPAD
-    OH = (H + 2*HPAD - KH) // HSTR + 1
-    OW = (W + 2*WPAD - KW) // WSTR + 1
+    TH = H + TPAD + DPAD
+    TW = W + LPAD + RPAD
+    OH = (H + TPAD + DPAD - KH) // HSTR + 1
+    OW = (W + LPAD + RPAD - KW) // WSTR + 1
+
+    # ==================== define configuration space ====================
+    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
+    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
+
+    co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
+
+    re_axes = cfg.define_reorder("reorder_0",
+                                 [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
+                                 policy='interval_all', interval=(6, 11))
+    cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW)  # these are actually binary ops
+    # ====================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]

-    dshape = (IB, 1, CI, H, W)
-    dpshape = (IB, 1, CI, TH, TW)
     dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
-    kshape = (KB, CO, CI, KH, KW)
     kvshape = (CO//VC, CI, KH, KW, KB, VC)
     ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
     oshape = (1, CO, OH, OW)

-    DOPAD = (HPAD != 0 and WPAD != 0)
-    if DOPAD:
-        data_pad = pad(data_q, (0, 0, 0, HPAD, WPAD), name="data_pad")
+    if (TPAD != 0 and RPAD != 0):
+        data_pad = pad(data_q, (0, 0, 0, TPAD, LPAD), (0, 0, 0, DPAD, RPAD), name="data_pad")
     else:
         data_pad = data_q

     data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
         data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')

-    kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
-        kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
+    if len(kernel.shape) == 4:
+        kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
+            kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')

     ci = tvm.reduce_axis((0, CI), name='ci')
     dh = tvm.reduce_axis((0, KH), name='dh')
@@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
     def _conv(n, co, h, w, vh, vw, vc):
         b1b2 = (b1+b2).astype(out_dtype)
-        if dorefa:
+        if unipolar:
             return tvm.sum((tvm.popcount(
                 data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
                 kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) -
@@ -203,15 +298,28 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
                        conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
                        name='conv_vec', tag='spatial_bitserial_conv_nchw')

-def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
-                      pack_dtype, out_dtype, dorefa=False):
+@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
+def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
+                      pack_dtype='uint32', out_dtype='int16', unipolar=True):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
     data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
-    kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
-    _, H, W, CI, IB = data_q.shape
-    KH, KW, _, CO, KB = kernel_q.shape
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    pack_kernel = len(kernel.shape) == 4
+
+    if pack_kernel:
+        kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
+    else:
+        kernel_q = kernel
+
+    KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape)
+    N, H, W, CI, IB = get_const_tuple(data_q.shape)
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
+    pad_before = [0, TPAD, LPAD, 0, 0]
+    pad_after = [0, DPAD, RPAD, 0, 0]

     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
@@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
     else:
         HSTR, WSTR = stride, stride
     HCAT, WCAT = KH-1, KW-1

-    wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC")
-    sch = _get_schedule(wkl, "NHWC")
-    VH = sch.vh
-    VW = sch.vw
-    VC = sch.vc
-
-    PAD_H = H + 2*HPAD
-    PAD_W = W + 2*WPAD
-    OH = (H + 2*HPAD - KH) // HSTR + 1
-    OW = (W + 2*WPAD - KW) // WSTR + 1
+    PAD_H = H + (TPAD + DPAD)
+    PAD_W = W + (LPAD + RPAD)
+    OH = (PAD_H - KH) // HSTR + 1
+    OW = (PAD_W - KW) // WSTR + 1
+    oshape = (1, OH, OW, CO)
+
+    # ==================== define configuration space ====================
+    n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
+    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
+
+    co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
+    re_axes = cfg.define_reorder("reorder_0",
+                                 [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
+                                 policy='interval_all', interval=(3, 7))
+    cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW)  # these are actually binary ops
+    # ====================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]

     dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
     kvshape = (CO, KH, KW, CI, VC, KB)
     ovshape = (1, OH, OW, CO, VH, VW, VC)
     oshape = (1, OH, OW, CO)

-    if (HPAD != 0 and WPAD != 0):
-        data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad")
+    if (DPAD != 0 and RPAD != 0):
+        data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, 0), name="data_pad")
     else:
         data_pad = data_q
@@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
     def _conv(n, h, w, co, vh, vw, vc):
         b1b2 = (b1+b2).astype(out_dtype)
-        if dorefa:
+        if unipolar:
             return tvm.sum(
-                (tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
-                              kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) -
-                 tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) &
-                              ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2,
+                ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
+                               kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
+                  tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
+                               ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
                 axis=[dh, dw, ci, b1, b2])

         return tvm.sum(tvm.popcount(
@@ -273,6 +398,7 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
                        conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
                        name='output_unpack', tag='spatial_bitserial_conv_nhwc')

 def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
     """Packs data into format necessary for bitserial computation
     pack_axis : int
@@ -334,8 +460,3 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
     if bits > 1:
         return concatenate(output_tuple, axis=bit_axis)
     return output_tuple
```
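`bitpack` slices the tensor into `bits` bit planes along `pack_axis`, packs runs of bits into `pack_type` words, and stacks the planes along a new bit axis via `concatenate`. A 1-D numpy sketch of the same packing into uint8 words (`bitpack_1d` is a hypothetical helper with an assumed MSB-first order; the real operator handles arbitrary axes):

```python
import numpy as np

def bitpack_1d(data, bits, pack_dtype=np.uint8):
    """Pack each bit plane of `data` (last axis) into pack_dtype words.
    Returns shape (bits, len(data) // lanes), lanes = 8 * itemsize."""
    lanes = 8 * np.dtype(pack_dtype).itemsize
    assert data.shape[-1] % lanes == 0
    planes = []
    for b in range(bits):
        plane = ((data >> b) & 1).astype(pack_dtype).reshape(-1, lanes)
        word = np.zeros(plane.shape[0], dtype=pack_dtype)
        for i in range(lanes):          # MSB-first packing of one word
            word = (word << 1) | plane[:, i]
        planes.append(word)
    return np.stack(planes)             # new bit axis in front

packed = bitpack_1d(np.arange(16, dtype=np.uint8), bits=2)
assert packed.shape == (2, 2)           # 2 bit planes x 2 uint8 words
```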
```diff
-
-_SCH_TO_DECL_FUNC_QUANT = {
-    SpatialPackNCHW: spatial_pack_nchw,
-    SpatialPackNHWC: spatial_pack_nhwc,
-}
```
topi/python/topi/x86/bitserial_conv2d.py

```diff
 # pylint: disable=invalid-name,unused-variable,invalid-name
 """Bitserial conv2d schedule on x86"""
 import tvm
+from tvm import autotvm
 from topi.util import get_const_int
 from .. import generic, tag
-from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload
-from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC
-from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT
-
-_QUANTIZED_SCHEDULES_NCHW = [
-    # resnet
-    SpatialPackNCHW(2, 2, 8, 1, 1),
-    SpatialPackNCHW(1, 4, 8, 4, 1),
-    SpatialPackNCHW(1, 4, 8, 1, 16),
-    SpatialPackNCHW(1, 4, 8, 4, 8),
-    SpatialPackNCHW(1, 7, 8, 3, 8),
-    SpatialPackNCHW(1, 2, 8, 1, 8),
-    SpatialPackNCHW(2, 1, 8, 1, 4),
-    SpatialPackNCHW(1, 7, 8, 1, 1),
-    SpatialPackNCHW(1, 1, 8, 1, 16),
-    SpatialPackNCHW(1, 1, 8, 1, 8),
-    SpatialPackNCHW(1, 1, 8, 1, 16),
-    SpatialPackNCHW(3, 3, 16, 3, 16),
-    SpatialPackNCHW(1, 1, 16, 2, 16),
-    SpatialPackNCHW(1, 1, 8, 1, 16),
-    SpatialPackNCHW(1, 1, 8, 1, 16),
-]
-
-_QUANTIZED_SCHEDULES_NHWC = [
-    # resnet
-    SpatialPackNHWC(2, 2, 8, 1, 1),
-    SpatialPackNHWC(1, 4, 8, 4, 1),
-    SpatialPackNHWC(1, 4, 8, 1, 16),
-    SpatialPackNHWC(1, 4, 8, 4, 8),
-    SpatialPackNHWC(1, 7, 8, 3, 8),
-    SpatialPackNHWC(1, 2, 8, 1, 8),
-    SpatialPackNHWC(2, 1, 8, 1, 4),
-    SpatialPackNHWC(1, 7, 8, 1, 1),
-    SpatialPackNHWC(1, 1, 8, 1, 16),
-    SpatialPackNHWC(1, 1, 8, 1, 8),
-    SpatialPackNHWC(1, 1, 8, 1, 16),
-]
-
-@_get_schedule.register("cpu")
-def _get_schedule_bitserial_conv2d(wkl, layout):
-    if wkl not in _WORKLOADS:
-        raise ValueError("no schedule for such workload: {}".format(wkl))
-    idx = _WORKLOADS.index(wkl)
-    if layout == "NCHW":
-        sch = _QUANTIZED_SCHEDULES_NCHW[idx]
-    elif layout == "NHWC":
-        sch = _QUANTIZED_SCHEDULES_NHWC[idx]
-    return sch
-
-@bitserial_conv2d.register("cpu")
-def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
-                                  layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
-    assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
-    wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
-    sch = _get_schedule(wkl, layout)
-    return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits,
-                                              weight_bits, pack_dtype, out_dtype, dorefa)
-
-@generic.schedule_bitserial_conv2d_nchw.register(["cpu"])
-@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"])
-def schedule_bitserial_conv2d(outs):
+@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nchw, ['cpu'], 'direct')
+@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, ['cpu'], 'direct')
+def schedule_bitserial_conv2d(cfg, outs):
     """CPU schedule for bitserial convolutions NCHW and NHWC"""
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
```
...
...
@@ -88,7 +27,6 @@ def schedule_bitserial_conv2d(outs):
conv_out
=
op
.
input_tensors
[
0
]
kernel_vec
=
conv_out
.
op
.
input_tensors
[
1
]
kernel_q
=
kernel_vec
.
op
.
input_tensors
[
0
]
kernel
=
kernel_q
.
op
.
input_tensors
[
0
]
data_vec
=
conv_out
.
op
.
input_tensors
[
0
]
data_q
=
data_vec
.
op
.
input_tensors
[
0
]
data
=
data_q
.
op
.
input_tensors
[
0
]
...
...
@@ -97,29 +35,27 @@ def schedule_bitserial_conv2d(outs):
data_pad
=
data_q
data_q
=
data
data
=
data_q
.
op
.
input_tensors
[
0
]
if
"QuantizeInput"
in
kernel
.
op
.
name
:
# Need to go up 1 further, from the combine in bitpack
kernel
=
kernel
.
op
.
input_tensors
[
0
]
if
"QuantizeInput"
in
data
.
op
.
name
:
# Need to go up 1 further, from the combine in bitpack
data
=
data
.
op
.
input_tensors
[
0
]
if
'spatial_bitserial_conv_nchw'
in
op
.
tag
:
_schedule_
spatial_conv2d_nchw
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
])
_schedule_
bitserial_conv2d_nchw
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
])
elif
'spatial_bitserial_conv_nhwc'
in
op
.
tag
:
_schedule_
spatial_conv2d_nhwc
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
])
_schedule_
bitserial_conv2d_nhwc
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
])
scheduled_ops
.
append
(
op
)
traverse
(
outs
[
0
]
.
op
)
return
s
def
_schedule_
spatial_conv2d_nchw
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
last
):
def
_schedule_
bitserial_conv2d_nchw
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
last
):
IB
,
_
,
CI
,
IH
,
IW
=
data_q
.
shape
KB
,
CO
,
_
,
KH
,
KW
=
kernel_q
.
shape
_
,
_
,
OH
,
OW
=
output
.
shape
...
...
@@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
wstride
=
get_const_int
((
TW
-
KW
)
//
(
OW
-
1
))
stride
=
(
hstride
,
wstride
)
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
output
.
dtype
,
"NCHW"
)
sch
=
_get_schedule
(
wkl
,
"NCHW"
)
VH
=
sch
.
vh
VW
=
sch
.
vw
VC
=
sch
.
vc
ba
=
sch
.
ba
bc
=
sch
.
bc
CC
=
s
.
cache_write
(
conv_out
,
"global"
)
n
,
co
,
oh
,
ow
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
s
[
conv_out
]
.
vectorize
(
vc
)
s
[
CC
]
.
compute_at
(
s
[
conv_out
],
ow
)
n
,
co
,
oh
,
ow
,
vh
,
vw
,
vc
=
s
[
CC
]
.
op
.
axis
ci
,
dh
,
dw
,
b1
,
b2
=
s
[
CC
]
.
op
.
reduce_axis
s
[
CC
]
.
reorder
(
ci
,
dh
,
vh
,
dw
,
vw
,
b1
,
b2
,
vc
)
s
[
CC
]
.
unroll
(
b1
)
s
[
CC
]
.
unroll
(
b2
)
s
[
CC
]
.
vectorize
(
vc
)
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
##### Schedule A
##### Schedule Data padding, and bitpacking
if
data_pad
is
not
None
:
s
[
data_pad
]
.
compute_inline
()
_
,
h
,
_
,
_
,
_
,
_
,
vw
=
s
[
data_vec
]
.
op
.
axis
s
[
data_vec
]
.
vectorize
(
vw
)
if
ba
==
1
:
oaxis
=
h
paxis
=
h
_
,
_
,
h
,
_
,
_
,
_
,
_
=
s
[
data_vec
]
.
op
.
axis
cfg
.
define_split
(
"tile_ah"
,
cfg
.
axis
(
h
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
oh
,
ih
=
cfg
[
"tile_ah"
]
.
apply
(
s
,
data_vec
,
h
)
if
cfg
[
"tile_ah"
]
.
size
[
1
]
==
1
:
oaxis
=
oh
paxis
=
oh
else
:
oh
,
ih
=
s
[
data_vec
]
.
split
(
h
,
ba
)
oaxis
=
oh
paxis
=
ih
...
...
@@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule B
co
,
_
,
_
,
_
,
_
,
vc
=
s
[
kernel_vec
]
.
op
.
axis
s
[
kernel_vec
]
.
vectorize
(
vc
)
if
bc
==
1
:
oaxis
=
co
paxis
=
co
##### Schedule Kenerl bitpacking
co
,
_
,
_
,
_
,
_
,
_
=
s
[
kernel_vec
]
.
op
.
axis
cfg
.
define_split
(
"tile_bco"
,
cfg
.
axis
(
co
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
oco
,
ico
=
cfg
[
"tile_bco"
]
.
apply
(
s
,
kernel_vec
,
co
)
if
cfg
[
"tile_bco"
]
.
size
[
1
]
==
1
:
oaxis
=
oco
paxis
=
oco
else
:
oco
,
ico
=
s
[
kernel_vec
]
.
split
(
co
,
bc
)
oaxis
=
oco
paxis
=
ico
...
...
@@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule C
##### Schedule Convolution
n
,
co
,
oh
,
ow
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
ci
,
dh
,
dw
,
ib
,
kb
=
s
[
conv_out
]
.
op
.
reduce_axis
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg
[
"reorder_0"
]
.
apply
(
s
,
conv_out
,
[
n
,
co
,
oh
,
ow
,
vc
,
vh
,
vw
,
dh
,
dw
,
kb
,
ib
,
ci
])
+    cfg["ann_reduce"].apply(s, conv_out, [kb, ib, dh, dw],
+                            axis_lens=[get_const_int(kb.dom.extent),
+                                       get_const_int(ib.dom.extent),
+                                       get_const_int(dh.dom.extent),
+                                       get_const_int(dw.dom.extent)],
+                            max_unroll=16, cfg=cfg)
    s[conv_out].vectorize(vc)
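Note that "reorder_0" and "ann_reduce" are transform-step knobs: they are declared once in the compute template that builds the config space, and the schedule above only replays the tuner's choice. A hedged sketch of what such declarations look like; the policy keywords follow autotvm's standard define_reorder/define_annotate API, and the exact arguments here are assumptions, not lines from this diff:

# Inside the template that declares the config space (axes come from
# the conv_out stage, as in the schedule above):
cfg.define_reorder("reorder_0",
                   [n, co, oh, ow, vc, vh, vw, dh, dw, kb, ib, ci],
                   policy="interval_all", interval=(3, 11))
# try_unroll lets the tuner pick unroll/no-unroll per reduction axis
cfg.define_annotate("ann_reduce", [kb, ib, dh, dw], policy="try_unroll")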
    ##### Schedule output
    n, co, h, w = s[last].op.axis
    co, vc = s[last].split(co, VC)
    oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
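The split/tile pair above is what unpacks the vectorized spatial-pack layout back into plain NCHW: co is split by the chosen VC, and (h, w) are tiled by VH x VW. As a standalone illustration of those two schedule primitives (generic compute, shapes made up):

import tvm

A = tvm.placeholder((1, 64, 56, 56), name='A')
B = tvm.compute(A.shape, lambda n, c, h, w: A[n, c, h, w] * 2, name='B')
s = tvm.create_schedule(B.op)

n, c, h, w = s[B].op.axis
co, vc = s[B].split(c, factor=8)        # like split(co, VC) with VC = 8
oh, ow, vh, vw = s[B].tile(h, w, 4, 4)  # like tile(h, w, VH, VW)
print(tvm.lower(s, [A, B], simple_mode=True))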
...
@@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
    s[output].compute_inline()
    s[conv_out].compute_at(s[last], ow)
-    if bc == 1:
-        oaxis = co
-        paxis = co
+    oco, ico = cfg["tile_oh"].apply(s, last, co)
+    if cfg["tile_oh"].size[1] == 1:
+        oaxis = oco
+        paxis = oco
    else:
-        oco, ico = s[last].split(co, bc)
        oaxis = oco
        paxis = ico
-    s[last].parallel(paxis)
-    s[last].pragma(oaxis, "parallel_launch_point")
-    s[last].pragma(paxis, "parallel_stride_pattern")
-    s[last].pragma(oaxis, "parallel_barrier_when_finish")
+    s[last].parallel(oco)
    return s
-def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
-                                  kernel, kernel_q, kernel_vec,
-                                  conv_out, output, last):
+def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
+                                    kernel_q, kernel_vec,
+                                    conv_out, output, last):
    # no stride and padding info here
    _, IH, IW, CI, IB = data_q.shape
    KH, KW, _, CO, KB = kernel_q.shape
    _, OH, OW, _ = output.shape
    # Infer padding and stride
    if data_pad is None:
        padding = (0, 0)
        TH, TW = IH, IW
    else:
        _, TH, TW, _, _ = data_pad.shape
        hpad = get_const_int((TH - IH) // 2)
        wpad = get_const_int((TW - IW) // 2)
        padding = (hpad, wpad)

    hstride = get_const_int((TH - KH) // (OH - 1))
    wstride = get_const_int((TW - KW) // (OW - 1))
    stride = (hstride, wstride)
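Since stride and padding are not stored on the op, they are recovered from the shapes: with OH = (TH - KH) // stride + 1, the stride falls out as (TH - KH) // (OH - 1). A quick numeric check of that inversion (illustrative values only):

TH, KH, OH = 58, 3, 28             # padded input, kernel, output heights
hstride = (TH - KH) // (OH - 1)    # 55 // 27 == 2
assert (TH - KH) // hstride + 1 == OH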
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
-    wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC")
-    sch = _get_schedule(wkl, "NHWC")
-    VH = sch.vh
-    VW = sch.vw
-    VC = sch.vc
-    ba = sch.ba
-    bc = sch.bc
-    ##### Schedule data packing
+    ##### Schedule data padding and packing
    if data_pad is not None:
        s[data_pad].compute_inline()
    _, h, _, _, _, _, _ = s[data_vec].op.axis
-    if ba == 1:
-        oaxis = h
-        paxis = h
-    else:
-        oh, ih = s[data_vec].split(h, ba)
-        oaxis = oh
-        paxis = ih
-    s[data_vec].parallel(paxis)
-    s[data_vec].pragma(oaxis, "parallel_launch_point")
-    s[data_vec].pragma(paxis, "parallel_stride_pattern")
-    s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")
+    cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
+    oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
+    s[data_vec].parallel(oh)
    ##### Schedule kernel packing
    co, _, _, _, _, _ = s[kernel_vec].op.axis
-    if bc == 1:
-        oaxis = co
-        paxis = co
-    else:
-        oco, ico = s[kernel_vec].split(co, bc)
-        oaxis = oco
-        paxis = ico
-    s[kernel_vec].parallel(paxis)
-    s[kernel_vec].pragma(oaxis, "parallel_launch_point")
-    s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
-    s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")
+    cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
+    oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
+    s[kernel_vec].parallel(oco)
    ##### Schedule Convolution
    n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
    dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis
-    s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
+    # s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
+    cfg["reorder_0"].apply(s, conv_out, [n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2])
+    cfg["ann_reduce"].apply(s, conv_out, [b1, b2, dh, dw],
+                            axis_lens=[get_const_int(b1.dom.extent),
+                                       get_const_int(b2.dom.extent),
+                                       get_const_int(dh.dom.extent),
+                                       get_const_int(dw.dom.extent)],
+                            max_unroll=16, cfg=cfg)
-    s[conv_out].unroll(b1)
-    s[conv_out].unroll(b2)
...
@@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
    s[output].compute_inline()
    s[conv_out].compute_at(s[last], ow)
-    if bc == 1:
-        oaxis = oh
-        paxis = oh
-    else:
-        oho, iho = s[last].split(oh, bc)
-        oaxis = oho
-        paxis = iho
-    s[last].parallel(paxis)
-    s[last].pragma(oaxis, "parallel_launch_point")
-    s[last].pragma(paxis, "parallel_stride_pattern")
-    s[last].pragma(oaxis, "parallel_barrier_when_finish")
+    oho, iho = cfg["tile_oh"].apply(s, last, oh)  # reuse parameter
+    s[last].parallel(oho)
    return s
topi/tests/python/test_topi_bitserial_conv2d.py View file @ 17351875
...
@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype):
    return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
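The helper is shown truncated by the hunk; a sketch of the whole function, assuming the natural bounds for bits-bit unsigned data (the min_val/max_val definitions are inferred, not visible in this hunk):

import numpy as np

def generate_quantized_np(shape, bits, out_dtype):
    np.random.seed(0)    # the rasp variant of this test seeds the generator too
    min_val = 0
    max_val = 1 << bits  # exclusive upper bound: values fit in `bits` bits
    return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)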
def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding,
-                                 activation_bits, weight_bits, dorefa):
+                                 activation_bits, weight_bits, unipolar):
    in_height = in_width = in_size
-    input_type = 'uint32'
+    input_dtype = 'uint32'
    out_dtype = 'int32'
    with tvm.target.create('llvm'):
-        A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
-        W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
-        B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
-                                     out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
+        A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
+        W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
+        B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
+                                          out_dtype=out_dtype, unipolar=unipolar)
        s = topi.generic.schedule_bitserial_conv2d_nchw([B])

    a_shape = get_const_tuple(A.shape)
...
@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
    @memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
    def get_ref_data():
-        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
-        w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
-        if dorefa:
+        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
+        w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
+        if unipolar:
            w_ = np.copy(w_np).astype(out_dtype)
            for x in np.nditer(w_, op_flags=['readwrite']):
                x[...] = 1 if x == 1 else -1
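The nditer loop builds the unipolar reference: packed {0, 1} weights are reinterpreted as bipolar {-1, +1} before running the reference convolution. The same mapping, vectorized (equivalent numpy, for illustration only):

import numpy as np

w_np = np.array([[0, 1], [1, 0]], dtype='uint32')
w_ = np.where(w_np == 1, 1, -1).astype('int32')
# w_ is [[-1, 1], [1, -1]] -- what the element-wise loop above produces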
...
@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
-                                 activation_bits, weight_bits, dorefa):
+                                 activation_bits, weight_bits, unipolar):
    in_height = in_width = in_size
-    input_type = 'uint32'
+    input_dtype = 'uint32'
    out_dtype = 'int32'
    with tvm.target.create('llvm'):
-        A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
-        W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
-        B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
-                                     out_dtype=out_dtype, layout="NHWC", dorefa=dorefa)
+        A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
+        W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
+        B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
+                                          out_dtype=out_dtype, unipolar=unipolar)
        s = topi.generic.schedule_bitserial_conv2d_nhwc([B])

    a_shape = get_const_tuple(A.shape)
...
@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
    @memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
    def get_ref_data():
-        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
-        w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
-        if dorefa:
+        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
+        w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype)
+        if unipolar:
            w_ = np.copy(w_np).astype(out_dtype)
            for x in np.nditer(w_, op_flags=['readwrite']):
                x[...] = 1 if x == 1 else -1
...
topi/tests/python/test_topi_bitserial_conv2d_rasp.py View file @ 17351875
...
@@ -4,6 +4,7 @@ import numpy as np
import tvm
import topi
+import topi.testing
from topi.util import get_const_tuple
def generate_quantized_np(shape, bits, out_dtype):
    np.random.seed(0)
...
@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype):
# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
-                                 activation_bits, weight_bits, dorefa):
+                                 activation_bits, weight_bits, unipolar):
    in_height = in_width = in_size
    input_type = 'uint32'
-    out_dtype = 'int32'
+    out_dtype = 'int16'
-    with tvm.target.arm_cpu('rasp3b'):
+    device = 'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon'
+    with tvm.target.create(device):
        A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
        W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
-        B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
-                                     out_dtype=out_dtype, layout="NHWC", dorefa=dorefa)
+        B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
+                                          pack_dtype='uint8', out_dtype='int16', unipolar=unipolar)
        s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
-    func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))
+    func = tvm.build(s, [A, W, B], device)
    assembly = func.get_source('asm')
    matches = re.findall("vpadal", assembly)
...
@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
    matches = re.findall("vpadd", assembly)
    assert (len(matches) > 0)
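These assertions implement the check promised by the comment at the top of the file: compile for the ARM target and grep the emitted assembly for the NEON pairwise-add instructions (vpadal/vpadd) that the tensorized popcount-accumulate kernel should produce. The same pattern as a small standalone helper (a sketch; `func` is assumed to be the module returned by tvm.build above):

import re

def assert_instructions(func, instrs=("vpadal", "vpadd")):
    # Dump the generated assembly and require each expected instruction.
    assembly = func.get_source('asm')
    for instr in instrs:
        assert re.findall(instr, assembly), \
            "expected %s in generated assembly" % instr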
+    ctx = tvm.context(device, 0)
+    if 'arm' not in os.uname()[4]:
+        print("Skipped running code, not an arm device")
+        return
+
+    print("Running on target: %s" % device)
+    def get_ref_data():
+        a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
+        w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
+        if unipolar:
+            w_ = np.copy(w_np).astype(out_dtype)
+            for x in np.nditer(w_, op_flags=['readwrite']):
+                x[...] = 1 if x == 1 else -1
+            b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype)
+        else:
+            b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype)
+        return a_np, w_np, b_np
+    a_np, w_np, b_np = get_ref_data()
+    a = tvm.nd.array(a_np, ctx)
+    w = tvm.nd.array(w_np, ctx)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+    func = tvm.build(s, [A, W, B], device)
+    func(a, w, b)
+    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_bitserial_conv2d():
    in_size = 56
    ic, oc = 64, 64
...
@@ -45,6 +74,9 @@ def test_bitserial_conv2d():
    verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
    verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
+    verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True)
+    verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True)
if __name__ == "__main__":
    test_bitserial_conv2d()