Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
17351875
Commit
17351875
authored
Apr 03, 2019
by
Meghan Cowan
Committed by
Tianqi Chen
Apr 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI] bitserial_conv2d move to autotvm template and updates (#2819)
parent
cefe07e2
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
598 additions
and
561 deletions
+598
-561
python/tvm/autotvm/task/task.py
+1
-1
python/tvm/autotvm/task/topi_integration.py
+22
-0
topi/python/topi/arm_cpu/bitserial_conv2d.py
+190
-223
topi/python/topi/nn/bitserial_conv2d.py
+255
-134
topi/python/topi/x86/bitserial_conv2d.py
+74
-179
topi/tests/python/test_topi_bitserial_conv2d.py
+18
-18
topi/tests/python/test_topi_bitserial_conv2d_rasp.py
+38
-6
No files found.
python/tvm/autotvm/task/task.py
View file @
17351875
...
@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
...
@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
workload
=
tuple
([
args_to_workload
(
a
)
for
a
in
x
])
workload
=
tuple
([
args_to_workload
(
a
)
for
a
in
x
])
elif
isinstance
(
x
,
(
str
,
int
,
float
,
np
.
int
,
np
.
float
)):
elif
isinstance
(
x
,
(
str
,
int
,
float
,
np
.
int
,
np
.
float
)):
workload
=
x
workload
=
x
elif
isinstance
(
x
,
(
expr
.
StringImm
,
expr
.
IntImm
,
expr
.
FloatImm
)):
elif
isinstance
(
x
,
(
expr
.
StringImm
,
expr
.
UIntImm
,
expr
.
IntImm
,
expr
.
FloatImm
)):
workload
=
x
.
value
workload
=
x
.
value
elif
x
is
None
:
elif
x
is
None
:
workload
=
0
workload
=
0
...
...
python/tvm/autotvm/task/topi_integration.py
View file @
17351875
...
@@ -68,6 +68,8 @@ class TaskExtractEnv:
...
@@ -68,6 +68,8 @@ class TaskExtractEnv:
topi
.
nn
.
group_conv2d_nchw
:
"topi_nn_group_conv2d_nchw"
,
topi
.
nn
.
group_conv2d_nchw
:
"topi_nn_group_conv2d_nchw"
,
topi
.
nn
.
conv2d_transpose_nchw
:
"topi_nn_conv2d_transpose_nchw"
,
topi
.
nn
.
conv2d_transpose_nchw
:
"topi_nn_conv2d_transpose_nchw"
,
topi
.
nn
.
dense
:
"topi_nn_dense"
,
topi
.
nn
.
dense
:
"topi_nn_dense"
,
topi
.
nn
.
bitserial_conv2d_nchw
:
"topi_nn_bitserial_conv2d_nchw"
,
topi
.
nn
.
bitserial_conv2d_nhwc
:
"topi_nn_bitserial_conv2d_nhwc"
,
topi
.
nn
.
deformable_conv2d_nchw
:
"topi_nn_deformable_conv2d_nchw"
,
topi
.
nn
.
deformable_conv2d_nchw
:
"topi_nn_deformable_conv2d_nchw"
,
}
}
...
@@ -79,6 +81,8 @@ class TaskExtractEnv:
...
@@ -79,6 +81,8 @@ class TaskExtractEnv:
topi
.
nn
.
group_conv2d_nchw
:
[
topi
.
generic
.
schedule_group_conv2d_nchw
],
topi
.
nn
.
group_conv2d_nchw
:
[
topi
.
generic
.
schedule_group_conv2d_nchw
],
topi
.
nn
.
conv2d_transpose_nchw
:
[
topi
.
generic
.
schedule_conv2d_transpose_nchw
],
topi
.
nn
.
conv2d_transpose_nchw
:
[
topi
.
generic
.
schedule_conv2d_transpose_nchw
],
topi
.
nn
.
dense
:
[
topi
.
generic
.
schedule_dense
],
topi
.
nn
.
dense
:
[
topi
.
generic
.
schedule_dense
],
topi
.
nn
.
bitserial_conv2d_nchw
:
[
topi
.
generic
.
schedule_bitserial_conv2d_nchw
],
topi
.
nn
.
bitserial_conv2d_nhwc
:
[
topi
.
generic
.
schedule_bitserial_conv2d_nhwc
],
topi
.
nn
.
deformable_conv2d_nchw
:
[
topi
.
generic
.
schedule_deformable_conv2d_nchw
],
topi
.
nn
.
deformable_conv2d_nchw
:
[
topi
.
generic
.
schedule_deformable_conv2d_nchw
],
}
}
...
@@ -174,6 +178,24 @@ class TaskExtractEnv:
...
@@ -174,6 +178,24 @@ class TaskExtractEnv:
return
s
,
[
data
,
weight
,
bias
,
C
]
return
s
,
[
data
,
weight
,
bias
,
C
]
return
s
,
[
data
,
weight
,
C
]
return
s
,
[
data
,
weight
,
C
]
@register
(
"topi_nn_bitserial_conv2d_nhwc"
)
def
_topi_bitserial_conv2d_nhwc
(
*
args
,
**
kwargs
):
args
=
deserialize_args
(
args
)
C
=
topi
.
nn
.
bitserial_conv2d_nhwc
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
nn
.
schedule_bitserial_conv2d_nhwc
([
C
])
data
=
args
[
0
]
kernel
=
args
[
1
]
return
s
,
[
data
,
kernel
,
C
]
@register
(
"topi_nn_bitserial_conv2d_nchw"
)
def
_topi_bitserial_conv2d_nchw
(
*
args
,
**
kwargs
):
args
=
deserialize_args
(
args
)
C
=
topi
.
nn
.
bitserial_conv2d_nchw
(
*
args
,
**
kwargs
)
s
=
topi
.
generic
.
nn
.
schedule_bitserial_conv2d_nchw
([
C
])
data
=
args
[
0
]
kernel
=
args
[
1
]
return
s
,
[
data
,
kernel
,
C
]
@register
(
"topi_nn_deformable_conv2d_nchw"
)
@register
(
"topi_nn_deformable_conv2d_nchw"
)
def
_topi_nn_deformable_conv2d_nchw
(
*
args
,
**
kwargs
):
def
_topi_nn_deformable_conv2d_nchw
(
*
args
,
**
kwargs
):
assert
not
kwargs
,
"Do not support kwargs in template function call"
assert
not
kwargs
,
"Do not support kwargs in template function call"
...
...
topi/python/topi/arm_cpu/bitserial_conv2d.py
View file @
17351875
# pylint: disable=invalid-name,unused-variable,invalid-name
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Bitserial conv2d schedule on
raspberry pi
"""
"""Bitserial conv2d schedule on
arm cpu
"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
collections
import
namedtuple
import
tvm
import
tvm
from
tvm
import
autotvm
from
..
import
tag
from
..
import
tag
from
..nn.pad
import
pad
from
..nn.pad
import
pad
from
..nn.bitserial_conv2d
import
bitserial_conv2d
,
_get_schedule
,
_get_workload
,
bitpack
from
..nn.bitserial_conv2d
import
bitpack
,
bitserial_conv2d_nhwc
from
..nn.bitserial_conv2d
import
SpatialPackNCHW
,
_WORKLOADS
,
spatial_pack_nchw
from
..nn.util
import
get_pad_tuple
from
..nn.util
import
get_pad_tuple
from
..util
import
get_const_int
from
..util
import
get_const_int
,
get_const_tuple
from
..
import
generic
from
..
import
generic
RaspSpatialPack
=
namedtuple
(
'SpatialPack'
,
def
_kernel_vec_spatial_pack_nhwc
(
kernel
,
kernel_bits
,
VC
,
use_bitpack
=
True
):
[
'vh'
,
'vw'
,
'vc'
,
'ba'
,
'bc'
,
'split_ci'
,
'kfactor'
])
if
use_bitpack
:
_QUANTIZED_SCHEDULES_NHWC
=
[
RaspSpatialPack
(
2
,
2
,
8
,
1
,
1
,
False
,
8
),
RaspSpatialPack
(
1
,
4
,
8
,
4
,
1
,
False
,
8
),
RaspSpatialPack
(
1
,
4
,
8
,
1
,
16
,
False
,
8
),
RaspSpatialPack
(
1
,
4
,
8
,
4
,
8
,
False
,
8
),
RaspSpatialPack
(
1
,
7
,
8
,
3
,
8
,
False
,
16
),
RaspSpatialPack
(
1
,
2
,
8
,
1
,
8
,
False
,
16
),
RaspSpatialPack
(
2
,
1
,
8
,
1
,
4
,
False
,
16
),
RaspSpatialPack
(
1
,
7
,
8
,
1
,
1
,
True
,
16
),
RaspSpatialPack
(
1
,
1
,
8
,
1
,
16
,
True
,
16
),
RaspSpatialPack
(
1
,
1
,
8
,
1
,
8
,
True
,
16
),
RaspSpatialPack
(
1
,
1
,
8
,
1
,
16
,
True
,
16
),
]
_QUANTIZED_SCHEDULES_NCHW
=
[
# resnet
SpatialPackNCHW
(
2
,
2
,
8
,
1
,
1
),
SpatialPackNCHW
(
1
,
4
,
8
,
4
,
1
),
SpatialPackNCHW
(
1
,
4
,
8
,
1
,
16
),
SpatialPackNCHW
(
1
,
4
,
8
,
4
,
8
),
SpatialPackNCHW
(
1
,
7
,
8
,
3
,
8
),
SpatialPackNCHW
(
1
,
2
,
8
,
1
,
8
),
SpatialPackNCHW
(
2
,
1
,
8
,
1
,
4
),
SpatialPackNCHW
(
1
,
7
,
8
,
1
,
1
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
16
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
8
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
16
),
]
@_get_schedule.register
(
"arm_cpu"
)
def
_get_schedule_bitserial_conv2d
(
wkl
,
layout
):
if
wkl
not
in
_WORKLOADS
:
raise
ValueError
(
"no schedule for such workload: {}"
.
format
(
wkl
))
idx
=
_WORKLOADS
.
index
(
wkl
)
if
layout
==
"NCHW"
:
sch
=
_QUANTIZED_SCHEDULES_NCHW
[
idx
]
elif
layout
==
"NHWC"
:
sch
=
_QUANTIZED_SCHEDULES_NHWC
[
idx
]
return
sch
@bitserial_conv2d.register
(
"arm_cpu"
)
def
_declaration_bitserial_conv2d
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
layout
=
'NCHW'
,
pack_dtype
=
None
,
out_dtype
=
None
,
dorefa
=
False
):
if
out_dtype
is
None
:
out_dtype
=
data
.
dtype
assert
data
.
shape
[
0
]
.
value
==
1
,
"only support batch size=1 convolution on rasp"
assert
layout
in
(
"NCHW"
,
"NHWC"
),
"only support layouts NCHW and NHWC"
if
dorefa
:
assert
layout
==
"NCHW"
,
"Cannot support dorea with NHWC layout yet"
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
)
sch
=
_get_schedule
(
wkl
,
layout
)
if
layout
==
"NCHW"
:
return
spatial_pack_nchw
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
return
_spatial_pack_nhwc
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
out_dtype
)
def
_kernel_vec_spatial_pack_nhwc
(
kernel
,
kernel_bits
,
VC
):
kernel_q
=
bitpack
(
kernel
,
kernel_bits
,
pack_axis
=
2
,
bit_axis
=
2
,
pack_type
=
'uint8'
)
kernel_q
=
bitpack
(
kernel
,
kernel_bits
,
pack_axis
=
2
,
bit_axis
=
2
,
pack_type
=
'uint8'
)
else
:
kernel_q
=
kernel
KH
,
KW
,
KB
,
CI
,
CO
=
kernel_q
.
shape
KH
,
KW
,
KB
,
CI
,
CO
=
kernel_q
.
shape
kvshape
=
(
CO
//
VC
,
KH
,
KW
,
KB
,
VC
,
CI
)
kvshape
=
(
CO
//
VC
,
KH
,
KW
,
KB
,
VC
,
CI
)
return
tvm
.
compute
(
kvshape
,
lambda
co
,
dh
,
dw
,
b
,
vc
,
ci
:
\
return
tvm
.
compute
(
kvshape
,
lambda
co
,
dh
,
dw
,
b
,
vc
,
ci
:
\
kernel_q
[
dh
][
dw
][
b
][
ci
][
co
*
VC
+
vc
],
name
=
'kernel_vec'
)
kernel_q
[
dh
][
dw
][
b
][
ci
][
co
*
VC
+
vc
],
name
=
'kernel_vec'
)
def
_spatial_pack_nhwc
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
out_dtype
):
@autotvm.register_topi_compute
(
bitserial_conv2d_nhwc
,
'arm_cpu'
,
'direct'
)
def
spatial_pack_nhwc
(
cfg
,
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
pack_dtype
,
out_dtype
,
unipolar
):
""" Compute convolution with pack on spatial axes. """
""" Compute convolution with pack on spatial axes. """
assert
data
.
shape
[
0
]
.
value
==
1
,
"spatial pack convolution only support batch size=1"
assert
data
.
shape
[
0
]
.
value
==
1
,
"spatial pack convolution only support batch size=1"
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
"NHWC"
)
assert
pack_dtype
==
'uint8'
,
"only support packing into uint8 bits"
sch
=
_get_schedule
(
wkl
,
"NHWC"
)
assert
out_dtype
==
'int16'
,
"only support output type of int16"
VH
=
sch
.
vh
VW
=
sch
.
vw
VC
=
sch
.
vc
data_q
=
bitpack
(
data
,
activation_bits
,
pack_axis
=
3
,
bit_axis
=
3
,
pack_type
=
'uint8'
)
N
,
H
,
W
,
CI
=
get_const_tuple
(
data
.
shape
)
kernel_vec
=
_kernel_vec_spatial_pack_nhwc
(
kernel
,
weight_bits
,
VC
)
if
len
(
kernel
.
shape
)
==
4
:
N
,
H
,
W
,
IB
,
CI
=
data_q
.
shape
KH
,
KW
,
_
,
CO
=
get_const_tuple
(
kernel
.
shape
)
OCO
,
KH
,
KW
,
KB
,
VC
,
_
=
kernel_vec
.
shape
CI_packed
=
CI
//
8
else
:
KH
,
KW
,
KB
,
CI_packed
,
CO
=
get_const_tuple
(
kernel
.
shape
)
CO
=
OCO
*
VC
if
isinstance
(
padding
,
int
)
or
(
isinstance
(
padding
,
(
tuple
,
list
))
and
len
(
padding
)
==
2
):
HPAD
,
WPAD
,
_
,
_
=
get_pad_tuple
(
padding
,
kernel
)
TPAD
,
LPAD
,
DPAD
,
RPAD
=
get_pad_tuple
(
padding
,
kernel
)
else
:
TPAD
,
LPAD
,
DPAD
,
RPAD
=
padding
if
isinstance
(
stride
,
(
tuple
,
list
)):
if
isinstance
(
stride
,
(
tuple
,
list
)):
HSTR
,
WSTR
=
stride
HSTR
,
WSTR
=
stride
...
@@ -102,75 +46,151 @@ def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bi
...
@@ -102,75 +46,151 @@ def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bi
HSTR
,
WSTR
=
stride
,
stride
HSTR
,
WSTR
=
stride
,
stride
HCAT
,
WCAT
=
KH
-
1
,
KW
-
1
HCAT
,
WCAT
=
KH
-
1
,
KW
-
1
PAD_H
=
H
+
2
*
HPAD
PAD_H
=
H
+
(
TPAD
+
DPAD
)
PAD_W
=
W
+
2
*
WPAD
PAD_W
=
W
+
(
LPAD
+
RPAD
)
OH
=
(
H
+
2
*
HPAD
-
KH
)
//
HSTR
+
1
OH
=
(
PAD_H
-
KH
)
//
HSTR
+
1
OW
=
(
W
+
2
*
WPAD
-
KW
)
//
WSTR
+
1
OW
=
(
PAD_W
-
KW
)
//
WSTR
+
1
oshape
=
(
1
,
OH
,
OW
,
CO
)
# Pad input channels of weights and data when it is not a multiple of 8
if
CI_packed
%
8
!=
0
:
CI_PAD
=
CI_packed
%
8
CI_packed
+=
CI_PAD
else
:
CI_PAD
=
0
# ==================== define configuration space ====================
n
,
oh
,
ow
,
co
=
cfg
.
axis
(
N
),
cfg
.
axis
(
OH
),
cfg
.
axis
(
OW
),
cfg
.
axis
(
CO
)
ci
,
kh
,
kw
=
cfg
.
reduce_axis
(
CI_packed
),
cfg
.
reduce_axis
(
KH
),
cfg
.
reduce_axis
(
KW
)
ib
,
kb
=
cfg
.
reduce_axis
(
activation_bits
),
cfg
.
reduce_axis
(
weight_bits
)
co
,
vc
=
cfg
.
define_split
(
'tile_co'
,
co
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
x
.
size
[
-
1
]
==
8
)
oh
,
vh
=
cfg
.
define_split
(
'tile_oh'
,
oh
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
x
.
size
[
-
1
]
>=
2
)
ow
,
vw
=
cfg
.
define_split
(
'tile_ow'
,
ow
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
x
.
size
[
-
1
]
>=
2
)
ci_o
,
ci_i
=
cfg
.
define_split
(
"tile_ci"
,
ci
,
num_outputs
=
2
,
filter
=
lambda
x
:
x
.
size
[
-
1
]
==
8
or
x
.
size
[
-
1
]
==
16
)
re_axes
=
cfg
.
define_reorder
(
"reorder_0"
,
[
n
,
oh
,
ow
,
co
,
vh
,
vw
,
kh
,
kw
,
ci_o
,
kb
,
ib
,
vc
,
ci_i
],
policy
=
'candidate'
,
candidate
=
[
[
n
,
oh
,
ow
,
co
,
vh
,
vw
,
kh
,
kw
,
ci_o
,
kb
,
ib
,
vc
,
ci_i
],
[
n
,
oh
,
ow
,
co
,
vh
,
vw
,
kw
,
kh
,
ci_o
,
kb
,
ib
,
vc
,
ci_i
],])
cfg
.
add_flop
(
2
*
N
*
OH
*
OW
*
CO
*
CI
*
8
*
KH
*
KW
)
# these are actually binary ops
# ====================
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
data_q
=
bitpack
(
data
,
activation_bits
,
pack_axis
=
3
,
bit_axis
=
3
,
pack_type
=
'uint8'
)
kernel_vec
=
_kernel_vec_spatial_pack_nhwc
(
kernel
,
weight_bits
,
VC
,
len
(
kernel
.
shape
)
==
4
)
if
kernel_vec
.
shape
[
-
1
]
%
8
!=
0
and
CI_PAD
!=
0
:
kernel_vec
=
pad
(
kernel_vec
,
[
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
CI_PAD
])
N
,
H
,
W
,
IB
,
CI
=
data_q
.
shape
OCO
,
KH
,
KW
,
KB
,
VC
,
CI
=
kernel_vec
.
shape
dvshape
=
(
N
,
PAD_H
//
(
VH
*
HSTR
),
PAD_W
//
(
VW
*
WSTR
),
VH
*
HSTR
+
HCAT
,
VW
*
WSTR
+
WCAT
,
IB
,
CI
)
dvshape
=
(
N
,
PAD_H
//
(
VH
*
HSTR
),
PAD_W
//
(
VW
*
WSTR
),
VH
*
HSTR
+
HCAT
,
VW
*
WSTR
+
WCAT
,
IB
,
CI
)
ovshape
=
(
1
,
OH
//
VH
,
OW
//
VW
,
CO
//
VC
,
VH
,
VW
,
VC
)
ovshape
=
(
1
,
OH
//
VH
,
OW
//
VW
,
CO
//
VC
,
VH
,
VW
,
VC
)
oshape
=
(
1
,
OH
,
OW
,
CO
)
if
(
HPAD
!=
0
and
WPAD
!=
0
):
if
(
TPAD
!=
0
and
RPAD
!=
0
):
data_pad
=
pad
(
data_q
,
(
0
,
HPAD
,
WPAD
,
0
,
0
),
name
=
"data_pad"
)
data_pad
=
pad
(
data_q
,
(
0
,
TPAD
,
LPAD
,
0
,
0
),
(
0
,
DPAD
,
RPAD
,
0
,
CI_PAD
),
name
=
"data_pad"
)
elif
CI_PAD
!=
0
:
data_pad
=
pad
(
data_q
,
(
0
,
0
,
0
,
0
,
0
),
(
0
,
0
,
0
,
0
,
CI_PAD
),
name
=
"data_pad"
)
else
:
else
:
data_pad
=
data_q
data_pad
=
data_q
data_vec
=
tvm
.
compute
(
dvshape
,
lambda
n
,
h
,
w
,
vh
,
vw
,
b
,
ci
:
\
data_vec
=
tvm
.
compute
(
dvshape
,
lambda
n
,
h
,
w
,
vh
,
vw
,
b
,
ci
:
\
data_pad
[
n
][
h
*
VH
*
HSTR
+
vh
][
w
*
VW
*
WSTR
+
vw
][
b
][
ci
],
name
=
'data_vec'
)
data_pad
[
n
][
h
*
VH
*
HSTR
+
vh
][
w
*
VW
*
WSTR
+
vw
][
b
][
ci
],
name
=
'data_vec'
)
ci
=
tvm
.
reduce_axis
((
0
,
CI
),
name
=
'ci'
)
ci
=
tvm
.
reduce_axis
((
0
,
CI
),
name
=
'ci'
)
dh
=
tvm
.
reduce_axis
((
0
,
KH
),
name
=
'dh'
)
dh
=
tvm
.
reduce_axis
((
0
,
KH
),
name
=
'dh'
)
dw
=
tvm
.
reduce_axis
((
0
,
KW
),
name
=
'dw'
)
dw
=
tvm
.
reduce_axis
((
0
,
KW
),
name
=
'dw'
)
ib
=
tvm
.
reduce_axis
((
0
,
IB
),
name
=
'ib'
)
ib
=
tvm
.
reduce_axis
((
0
,
IB
),
name
=
'ib'
)
kb
=
tvm
.
reduce_axis
((
0
,
KB
),
name
=
'kb'
)
kb
=
tvm
.
reduce_axis
((
0
,
KB
),
name
=
'kb'
)
def
_conv
(
n
,
h
,
w
,
co
,
vh
,
vw
,
vc
):
def
_
bipolar_
conv
(
n
,
h
,
w
,
co
,
vh
,
vw
,
vc
):
return
tvm
.
sum
((
tvm
.
popcount
(
return
tvm
.
sum
((
tvm
.
popcount
(
kernel_vec
[
co
,
dh
,
dw
,
kb
,
vc
,
ci
]
.
astype
(
'uint16'
)
&
kernel_vec
[
co
,
dh
,
dw
,
kb
,
vc
,
ci
]
.
astype
(
'uint16'
)
&
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ib
,
ci
]
.
astype
(
'uint16'
))
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ib
,
ci
]
.
astype
(
'uint16'
))
<<
(
kb
+
ib
)
.
astype
(
'uint16'
)),
axis
=
[
dh
,
dw
,
kb
,
ib
,
ci
])
<<
(
kb
+
ib
)
.
astype
(
'uint16'
)),
axis
=
[
dh
,
dw
,
kb
,
ib
,
ci
])
def
_unipolar_conv
(
n
,
h
,
w
,
co
,
vh
,
vw
,
vc
):
return
tvm
.
sum
(
((
tvm
.
popcount
(
kernel_vec
[
co
,
dh
,
dw
,
kb
,
vc
,
ci
]
.
astype
(
'int16'
)
&
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ib
,
ci
]
.
astype
(
'int16'
))
-
tvm
.
popcount
(
~
kernel_vec
[
co
,
dh
,
dw
,
kb
,
vc
,
ci
]
.
astype
(
'int16'
)
&
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ib
,
ci
])
.
astype
(
'int16'
))
<<
(
kb
+
ib
)
.
astype
(
'int16'
)),
axis
=
[
dh
,
dw
,
kb
,
ib
,
ci
])
if
unipolar
:
conv_vec
=
tvm
.
compute
(
ovshape
,
_unipolar_conv
,
name
=
'conv_vec'
,
tag
=
'unipolar'
)
else
:
conv_vec
=
tvm
.
compute
(
ovshape
,
_bipolar_conv
,
name
=
'conv_vec'
,
tag
=
'bipolar'
)
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv'
)
conv
=
tvm
.
compute
(
oshape
,
lambda
n
,
h
,
w
,
co
:
conv_vec
[
n
][
h
//
VH
][
w
//
VW
][
co
//
VC
][
h
%
VH
][
w
%
VW
][
co
%
VC
]
.
astype
(
out_dtype
),
name
=
'conv'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
h
,
w
,
co
:
return
conv
conv
[
n
][
h
//
VH
][
w
//
VW
][
co
//
VC
][
h
%
VH
][
w
%
VW
][
co
%
VC
]
.
astype
(
out_dtype
),
name
=
'output_vec'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
def
_intrin_popcount
(
m
,
k_i
,
w_b
,
x_b
):
def
_intrin_popcount
(
m
,
k_i
,
w_b
,
x_b
,
unipolar
):
dtype
=
'uint8'
pack_
dtype
=
'uint8'
w
=
tvm
.
placeholder
((
w_b
,
m
,
k_i
),
dtype
=
dtype
,
name
=
'w'
)
w
=
tvm
.
placeholder
((
w_b
,
m
,
k_i
),
dtype
=
pack_
dtype
,
name
=
'w'
)
x
=
tvm
.
placeholder
((
x_b
,
k_i
,),
dtype
=
dtype
,
name
=
'x'
)
x
=
tvm
.
placeholder
((
x_b
,
k_i
,),
dtype
=
pack_
dtype
,
name
=
'x'
)
k
=
tvm
.
reduce_axis
((
0
,
k_i
),
name
=
'k'
)
k
=
tvm
.
reduce_axis
((
0
,
k_i
),
name
=
'k'
)
bw
=
tvm
.
reduce_axis
((
0
,
w_b
),
name
=
'bw'
)
bw
=
tvm
.
reduce_axis
((
0
,
w_b
),
name
=
'bw'
)
bx
=
tvm
.
reduce_axis
((
0
,
x_b
),
name
=
'bx'
)
bx
=
tvm
.
reduce_axis
((
0
,
x_b
),
name
=
'bx'
)
if
unipolar
:
dtype
=
'int16'
z
=
tvm
.
compute
((
m
,),
lambda
i
:
z
=
tvm
.
compute
((
m
,),
lambda
i
:
tvm
.
sum
(
tvm
.
popcount
(
w
[
bw
,
i
,
k
]
.
astype
(
'uint16'
)
&
tvm
.
sum
((
tvm
.
popcount
(
w
[
bw
,
i
,
k
]
.
astype
(
dtype
)
&
x
[
bx
,
k
]
.
astype
(
dtype
))
-
x
[
bx
,
k
]
.
astype
(
'uint16'
))
tvm
.
popcount
(
~
w
[
bw
,
i
,
k
]
.
astype
(
dtype
)
&
x
[
bx
,
k
]
.
astype
(
dtype
)))
<<
(
bw
+
bx
)
.
astype
(
'uint16'
),
axis
=
[
bw
,
bx
,
k
]),
name
=
'z'
)
<<
(
bw
+
bx
)
.
astype
(
dtype
),
axis
=
[
bw
,
bx
,
k
]),
name
=
'z'
)
else
:
dtype
=
'uint16'
z
=
tvm
.
compute
((
m
,),
lambda
i
:
tvm
.
sum
(
tvm
.
popcount
(
w
[
bw
,
i
,
k
]
.
astype
(
dtype
)
&
x
[
bx
,
k
]
.
astype
(
dtype
))
<<
(
bw
+
bx
)
.
astype
(
dtype
),
axis
=
[
bw
,
bx
,
k
]),
name
=
'z'
)
Wb
=
tvm
.
decl_buffer
(
w
.
shape
,
w
.
dtype
,
Wb
=
tvm
.
decl_buffer
(
w
.
shape
,
w
.
dtype
,
name
=
"W"
,
name
=
"W"
,
offset_factor
=
k_i
,
offset_factor
=
k_i
,
strides
=
[
tvm
.
var
(
'ldw'
),
tvm
.
var
(
'ldw'
),
1
])
strides
=
[
tvm
.
var
(
'ldw'
),
tvm
.
var
(
'ldw'
),
1
])
# stride can be inferred
Xb
=
tvm
.
decl_buffer
(
x
.
shape
,
x
.
dtype
,
Xb
=
tvm
.
decl_buffer
(
x
.
shape
,
x
.
dtype
,
name
=
"X"
,
name
=
"X"
,
offset_factor
=
k_i
,
offset_factor
=
k_i
,
strides
=
[
tvm
.
var
(
'ldw'
),
1
])
strides
=
[
tvm
.
var
(
'ldw'
),
1
])
Zb
=
tvm
.
decl_buffer
(
z
.
shape
,
z
.
dtype
,
name
=
"Z"
,
offset_factor
=
1
,
strides
=
[
1
])
def
_intrin_func
(
ins
,
outs
):
def
_intrin_func
(
ins
,
outs
):
ww
,
xx
=
ins
ww
,
xx
=
ins
zz
=
outs
[
0
]
zz
=
outs
[
0
]
vpadd
=
"llvm.arm.neon.vpadd.v8u8"
vpadalu
=
"llvm.arm.neon.vpadalu.v16u8.v8u16"
args_1
=
tvm
.
const
(
1
,
'uint32'
)
args_1
=
tvm
.
const
(
1
,
'uint32'
)
args_2
=
tvm
.
const
(
2
,
'uint32'
)
args_2
=
tvm
.
const
(
2
,
'uint32'
)
if
unipolar
:
vpadd
=
"llvm.arm.neon.vpadd.v8i8"
vpadalu
=
"llvm.arm.neon.vpadals.v16i8.v8i16"
full_dtype
=
'int8x16'
half_dtype
=
'int8x8'
return_dtype
=
'int16x8'
else
:
vpadd
=
"llvm.arm.neon.vpadd.v8u8"
vpadalu
=
"llvm.arm.neon.vpadalu.v16u8.v8u16"
full_dtype
=
'uint8x16'
half_dtype
=
'uint8x8'
return_dtype
=
'uint16x8'
def
_instr
(
index
):
def
_instr
(
index
):
irb
=
tvm
.
ir_builder
.
create
()
irb
=
tvm
.
ir_builder
.
create
()
if
index
==
1
:
if
index
==
1
:
# reduce reset
irb
.
emit
(
zz
.
vstore
(
0
,
tvm
.
const
(
0
,
'uint16x8'
)))
irb
.
emit
(
zz
.
vstore
(
0
,
tvm
.
const
(
0
,
return_dtype
)))
return
irb
.
get
()
return
irb
.
get
()
# body and reduce update
cnts8
=
[
None
]
*
8
cnts8
=
[
None
]
*
8
cnts4
=
[
None
]
*
4
cnts4
=
[
None
]
*
4
cnts2
=
[
None
]
*
2
cnts2
=
[
None
]
*
2
...
@@ -178,154 +198,108 @@ def _intrin_popcount(m, k_i, w_b, x_b):
...
@@ -178,154 +198,108 @@ def _intrin_popcount(m, k_i, w_b, x_b):
for
bx
in
range
(
x_b
):
for
bx
in
range
(
x_b
):
if
k_i
==
16
:
if
k_i
==
16
:
for
i
in
range
(
m
):
for
i
in
range
(
m
):
ands
=
ww
.
vload
([
bw
,
i
,
0
],
'uint8x16'
)
&
xx
.
vload
([
bx
,
0
],
'uint8x16'
)
w_
=
ww
.
vload
([
bw
,
i
,
0
],
'uint8x16'
)
.
astype
(
full_dtype
)
cnts
=
tvm
.
popcount
(
ands
)
x_
=
xx
.
vload
([
bx
,
0
],
'uint8x16'
)
.
astype
(
full_dtype
)
upper_half
=
tvm
.
call_pure_intrin
(
'uint8x8'
,
'vectorhigh'
,
cnts
)
if
unipolar
:
lower_half
=
tvm
.
call_pure_intrin
(
'uint8x8'
,
'vectorlow'
,
cnts
)
cnts
=
tvm
.
popcount
(
w_
&
x_
)
-
tvm
.
popcount
(
~
w_
&
x_
)
else
:
cnts
=
tvm
.
popcount
(
w_
&
x_
)
upper_half
=
tvm
.
call_pure_intrin
(
half_dtype
,
'vectorhigh'
,
cnts
)
lower_half
=
tvm
.
call_pure_intrin
(
half_dtype
,
'vectorlow'
,
cnts
)
cnts8
[
i
]
=
upper_half
+
lower_half
cnts8
[
i
]
=
upper_half
+
lower_half
for
i
in
range
(
m
//
2
):
for
i
in
range
(
m
//
2
):
cnts4
[
i
]
=
tvm
.
call_llvm_intrin
(
'uint8x8'
,
vpadd
,
cnts4
[
i
]
=
tvm
.
call_llvm_intrin
(
half_dtype
,
vpadd
,
args_1
,
cnts8
[
i
*
2
],
cnts8
[
i
*
2
+
1
])
args_1
,
cnts8
[
i
*
2
],
cnts8
[
i
*
2
+
1
])
for
i
in
range
(
m
//
4
):
for
i
in
range
(
m
//
4
):
cnts2
[
i
]
=
tvm
.
call_llvm_intrin
(
'uint8x8'
,
vpadd
,
cnts2
[
i
]
=
tvm
.
call_llvm_intrin
(
half_dtype
,
vpadd
,
args_1
,
cnts4
[
i
*
2
],
cnts4
[
i
*
2
+
1
])
args_1
,
cnts4
[
i
*
2
],
cnts4
[
i
*
2
+
1
])
cnts
=
tvm
.
call_pure_intrin
(
'uint8x16'
,
'vectorcombine'
,
cnts2
[
0
],
cnts2
[
1
])
cnts
=
tvm
.
call_pure_intrin
(
full_dtype
,
'vectorcombine'
,
cnts2
[
0
],
cnts2
[
1
])
shifted_cnts
=
cnts
<<
tvm
.
const
(
bw
+
bx
,
dtype
)
shifted_cnts
=
cnts
<<
tvm
.
const
(
bw
+
bx
,
pack_
dtype
)
out
=
tvm
.
call_llvm_intrin
(
'uint16x8'
,
vpadalu
,
out
=
tvm
.
call_llvm_intrin
(
return_dtype
,
vpadalu
,
args_2
,
zz
.
vload
(
0
,
'uint16x8'
),
shifted_cnts
)
args_2
,
zz
.
vload
(
0
,
return_dtype
),
shifted_cnts
)
else
:
# ki == 8
else
:
# ki == 8
for
i
in
range
(
m
):
for
i
in
range
(
m
):
ands
=
ww
.
vload
([
bw
,
i
,
0
],
'uint8x8'
)
&
xx
.
vload
([
bx
,
0
],
'uint8x8'
)
w_
=
ww
.
vload
([
bw
,
i
,
0
],
'uint8x8'
)
.
astype
(
half_dtype
)
cnts8
[
i
]
=
tvm
.
popcount
(
ands
)
x_
=
xx
.
vload
([
bx
,
0
],
'uint8x8'
)
.
astype
(
half_dtype
)
if
unipolar
:
cnts8
[
i
]
=
tvm
.
popcount
(
w_
&
x_
)
-
tvm
.
popcount
(
~
w_
&
x_
)
else
:
cnts8
[
i
]
=
tvm
.
popcount
(
w_
&
x_
)
for
i
in
range
(
m
//
2
):
for
i
in
range
(
m
//
2
):
cnts4
[
i
]
=
tvm
.
call_llvm_intrin
(
'uint8x8'
,
vpadd
,
cnts4
[
i
]
=
tvm
.
call_llvm_intrin
(
half_dtype
,
vpadd
,
args_1
,
cnts8
[
i
*
2
],
cnts8
[
i
*
2
+
1
])
args_1
,
cnts8
[
i
*
2
],
cnts8
[
i
*
2
+
1
])
for
i
in
range
(
m
//
4
):
for
i
in
range
(
m
//
4
):
cnts2
[
i
]
=
tvm
.
call_llvm_intrin
(
'uint8x8'
,
vpadd
,
cnts2
[
i
]
=
tvm
.
call_llvm_intrin
(
half_dtype
,
vpadd
,
args_1
,
cnts4
[
i
*
2
],
cnts4
[
i
*
2
+
1
])
args_1
,
cnts4
[
i
*
2
],
cnts4
[
i
*
2
+
1
])
cnts
=
tvm
.
call_pure_intrin
(
'uint8x16'
,
'vectorcombine'
,
cnts2
[
0
],
cnts2
[
1
])
cnts
=
tvm
.
call_pure_intrin
(
full_dtype
,
'vectorcombine'
,
cnts2
[
0
],
cnts2
[
1
])
shifted_cnts
=
cnts
<<
tvm
.
const
(
bw
+
bx
,
dtype
)
shifted_cnts
=
cnts
<<
tvm
.
const
(
bw
+
bx
,
pack_
dtype
)
out
=
tvm
.
call_llvm_intrin
(
'uint16x8'
,
vpadalu
,
out
=
tvm
.
call_llvm_intrin
(
return_dtype
,
vpadalu
,
args_2
,
zz
.
vload
(
0
,
'uint16x8'
),
shifted_cnts
)
args_2
,
zz
.
vload
(
0
,
return_dtype
),
shifted_cnts
)
irb
.
emit
(
zz
.
vstore
(
0
,
out
))
irb
.
emit
(
zz
.
vstore
(
0
,
out
))
return
irb
.
get
()
return
irb
.
get
()
# body, reset, update
# body, reset, update
return
_instr
(
0
),
_instr
(
1
),
_instr
(
2
)
return
_instr
(
0
),
_instr
(
1
),
_instr
(
2
)
with
tvm
.
build_config
(
offset_factor
=
1
,
partition_const_loop
=
True
):
with
tvm
.
build_config
(
offset_factor
=
1
,
partition_const_loop
=
True
):
return
tvm
.
decl_tensor_intrin
(
z
.
op
,
_intrin_func
,
binds
=
{
w
:
Wb
,
x
:
Xb
})
return
tvm
.
decl_tensor_intrin
(
z
.
op
,
_intrin_func
,
binds
=
{
w
:
Wb
,
x
:
Xb
,
z
:
Zb
})
# ARM specific schedule that using custom microkernel
# ARM specific schedule that using custom microkernel
def
_schedule_spatial_conv2d_nhwc
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
def
_schedule_spatial_conv2d_nhwc
(
cfg
,
s
,
data_pad
,
data_vec
,
kernel_vec
,
kernel
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
last
,
unipolar
):
conv_out
,
output
,
last
):
_
,
_
,
_
,
_
,
_
,
IB
,
CI
=
data_vec
.
shape
# no stride and padding info here
_
,
KH
,
KW
,
KB
,
_
,
_
=
kernel_vec
.
shape
_
,
H
,
W
,
IB
,
CI
=
data_q
.
shape
KH
,
KW
,
KB
,
_
,
CO
=
kernel_q
.
shape
KB
=
get_const_int
(
KB
)
KB
=
get_const_int
(
KB
)
IB
=
get_const_int
(
IB
)
IB
=
get_const_int
(
IB
)
if
data_pad
is
None
:
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
padding
=
(
0
,
0
)
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
_
,
in_h
,
in_w
,
_
,
_
=
data_q
.
shape
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
kern_h
,
kern_w
,
_
,
_
=
kernel
.
shape
_
,
out_h
,
out_w
,
_
=
output
.
shape
##### Schedule data padding and packing
hstride
=
(
in_h
-
kern_h
)
//
(
out_h
-
1
)
wstride
=
(
in_w
-
kern_w
)
//
(
out_w
-
1
)
stride
=
get_const_int
(
hstride
),
get_const_int
(
wstride
)
else
:
_
,
in_h
,
in_w
,
_
,
_
=
data_q
.
shape
_
,
pad_h
,
pad_w
,
_
,
_
=
data_pad
.
shape
hpad
=
(
pad_h
-
in_h
)
//
2
wpad
=
(
pad_w
-
in_w
)
//
2
padding
=
get_const_int
(
hpad
),
get_const_int
(
wpad
)
_
,
in_h
,
in_w
,
_
,
_
=
data_pad
.
shape
kern_h
,
kern_w
,
_
,
_
=
kernel
.
shape
_
,
out_h
,
out_w
,
_
=
output
.
shape
hstride
=
(
in_h
-
kern_h
)
//
(
out_h
-
1
)
wstride
=
(
in_w
-
kern_w
)
//
(
out_w
-
1
)
stride
=
get_const_int
(
hstride
),
get_const_int
(
wstride
)
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
output
.
dtype
,
"NHWC"
)
sch
=
_get_schedule
(
wkl
,
"NHWC"
)
VH
=
sch
.
vh
VW
=
sch
.
vw
VC
=
sch
.
vc
ba
=
sch
.
ba
bc
=
sch
.
bc
##### Schedule data packing
if
data_pad
is
not
None
:
if
data_pad
is
not
None
:
s
[
data_pad
]
.
compute_inline
()
s
[
data_pad
]
.
compute_inline
()
_
,
h
,
_
,
_
,
_
,
_
,
_
=
s
[
data_vec
]
.
op
.
axis
_
,
h
,
_
,
_
,
_
,
_
,
_
=
s
[
data_vec
]
.
op
.
axis
if
ba
==
1
:
cfg
.
define_split
(
"tile_ah"
,
cfg
.
axis
(
h
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
oaxis
=
h
oh
,
ih
=
cfg
[
"tile_ah"
]
.
apply
(
s
,
data_vec
,
h
)
paxis
=
h
s
[
data_vec
]
.
parallel
(
oh
)
else
:
oh
,
ih
=
s
[
data_vec
]
.
split
(
h
,
ba
)
oaxis
=
oh
paxis
=
ih
s
[
data_vec
]
.
parallel
(
paxis
)
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_launch_point"
)
s
[
data_vec
]
.
pragma
(
paxis
,
"parallel_stride_pattern"
)
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
####
#
Schedule kernel packing
#### Schedule kernel packing
co
,
_
,
_
,
_
,
_
,
_
=
s
[
kernel_vec
]
.
op
.
axis
co
,
_
,
_
,
_
,
_
,
_
=
s
[
kernel_vec
]
.
op
.
axis
if
bc
==
1
:
cfg
.
define_split
(
"tile_bco"
,
cfg
.
axis
(
co
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
oaxis
=
co
oco
,
ico
=
cfg
[
"tile_bco"
]
.
apply
(
s
,
kernel_vec
,
co
)
paxis
=
co
s
[
kernel_vec
]
.
parallel
(
oco
)
else
:
oco
,
ico
=
s
[
kernel_vec
]
.
split
(
co
,
bc
)
oaxis
=
oco
paxis
=
ico
s
[
kernel_vec
]
.
parallel
(
paxis
)
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_launch_point"
)
s
[
kernel_vec
]
.
pragma
(
paxis
,
"parallel_stride_pattern"
)
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule Convolution
##### Schedule Convolution
n
,
oh
,
ow
,
co
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
n
,
oh
,
ow
,
co
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
dh
,
d
w
,
kb
,
ib
,
ci
=
s
[
conv_out
]
.
op
.
reduce_axis
kh
,
k
w
,
kb
,
ib
,
ci
=
s
[
conv_out
]
.
op
.
reduce_axis
kfactor
=
sch
.
kfactor
ci_o
,
ci_i
=
cfg
[
'tile_ci'
]
.
apply
(
s
,
conv_out
,
ci
)
if
sch
.
split_ci
:
re_axes
=
cfg
[
"reorder_0"
]
.
apply
(
s
,
conv_out
,
oci
,
ici
=
s
[
conv_out
]
.
split
(
ci
,
kfactor
)
[
n
,
oh
,
ow
,
co
,
vh
,
vw
,
kh
,
kw
,
ci_o
,
kb
,
ib
,
vc
,
ci_i
])
s
[
conv_out
]
.
reorder
(
n
,
oh
,
ow
,
co
,
vh
,
vw
,
dh
,
dw
,
oci
,
kb
,
ib
,
vc
,
ici
)
else
:
s
[
conv_out
]
.
reorder
(
n
,
oh
,
ow
,
co
,
vh
,
vw
,
dh
,
dw
,
kb
,
ib
,
vc
,
ci
)
pc
=
_intrin_popcount
(
8
,
kfactor
,
KB
,
IB
)
# Use microkernel
kfactor
=
cfg
[
'tile_ci'
]
.
size
[
1
]
if
kfactor
%
8
==
0
:
pc
=
_intrin_popcount
(
VC
,
kfactor
,
KB
,
IB
,
unipolar
)
s
[
conv_out
]
.
tensorize
(
kb
,
pc
)
s
[
conv_out
]
.
tensorize
(
kb
,
pc
)
n
,
h
,
w
,
co
=
s
[
last
]
.
op
.
axis
n
,
h
,
w
,
co
=
s
[
last
]
.
op
.
axis
co
,
vc
=
s
[
last
]
.
split
(
co
,
VC
)
co
,
vc
=
cfg
[
'tile_co'
]
.
apply
(
s
,
last
,
co
)
oh
,
ow
,
vh
,
vw
=
s
[
last
]
.
tile
(
h
,
w
,
VH
,
VW
)
oh
,
vh
=
cfg
[
'tile_oh'
]
.
apply
(
s
,
last
,
h
)
s
[
last
]
.
reorder
(
n
,
oh
,
ow
,
co
,
vc
,
vh
,
vw
)
ow
,
vw
=
cfg
[
'tile_ow'
]
.
apply
(
s
,
last
,
w
)
s
[
last
]
.
vectorize
(
vw
)
s
[
last
]
.
reorder
(
n
,
oh
,
ow
,
co
,
vh
,
vw
,
vc
)
s
[
last
]
.
vectorize
(
vc
)
if
last
!=
output
:
if
last
!=
output
:
s
[
last
]
.
compute_inline
()
s
[
last
]
.
compute_inline
()
s
[
conv_out
]
.
compute_at
(
s
[
last
],
ow
)
s
[
conv_out
]
.
compute_at
(
s
[
last
],
co
)
if
co
==
1
:
s
[
last
]
.
parallel
(
oh
)
oaxis
=
oh
paxis
=
oh
else
:
oho
,
iho
=
s
[
last
]
.
split
(
oh
,
bc
)
oaxis
=
oho
paxis
=
iho
s
[
last
]
.
parallel
(
paxis
)
s
=
s
.
normalize
()
s
=
s
.
normalize
()
return
s
return
s
@
generic.schedule_bitserial_conv2d_nhwc.register
([
"arm_cpu"
]
)
@
autotvm.register_topi_schedule
(
generic
.
nn
.
schedule_bitserial_conv2d_nhwc
,
'arm_cpu'
,
'direct'
)
def
schedule_bitserial_conv2d_nhwc
(
outs
):
def
schedule_bitserial_conv2d_nhwc
(
cfg
,
outs
):
"""
Raspverry pi
schedule for bitserial conv2d"""
"""
Arm cpu
schedule for bitserial conv2d"""
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
scheduled_ops
=
[]
scheduled_ops
=
[]
...
@@ -344,10 +318,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
...
@@ -344,10 +318,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
conv_out
=
op
.
input_tensors
[
0
]
conv_out
=
op
.
input_tensors
[
0
]
kernel_vec
=
conv_out
.
op
.
input_tensors
[
0
]
kernel_vec
=
conv_out
.
op
.
input_tensors
[
0
]
kernel_q
=
kernel_vec
.
op
.
input_tensors
[
0
]
kernel_q
=
kernel_vec
.
op
.
input_tensors
[
0
]
kernel
=
kernel_q
.
op
.
input_tensors
[
0
]
if
"QuantizeInput"
in
kernel
.
op
.
name
:
# Need to go up 1 further, from the combine in bitpack
kernel
=
kernel
.
op
.
input_tensors
[
0
]
data_vec
=
conv_out
.
op
.
input_tensors
[
1
]
data_vec
=
conv_out
.
op
.
input_tensors
[
1
]
data_q
=
data_vec
.
op
.
input_tensors
[
0
]
data_q
=
data_vec
.
op
.
input_tensors
[
0
]
data
=
data_q
.
op
.
input_tensors
[
0
]
data
=
data_q
.
op
.
input_tensors
[
0
]
...
@@ -355,13 +325,10 @@ def schedule_bitserial_conv2d_nhwc(outs):
...
@@ -355,13 +325,10 @@ def schedule_bitserial_conv2d_nhwc(outs):
if
isinstance
(
data_q
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
in
data_q
.
op
.
tag
:
if
isinstance
(
data_q
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
in
data_q
.
op
.
tag
:
data_pad
=
data_q
data_pad
=
data_q
data_q
=
data
data_q
=
data
data
=
data_q
.
op
.
input_tensors
[
0
]
if
"QuantizeInput"
in
data
.
op
.
name
:
# Need to go up 1 further, from the combine in bitpack
data
=
data
.
op
.
input_tensors
[
0
]
data
=
data
.
op
.
input_tensors
[
0
]
unipolar
=
"unipolar"
in
conv_out
.
op
.
tag
_schedule_spatial_conv2d_nhwc
(
s
,
data
,
data_q
,
data_pad
,
data
_vec
,
_schedule_spatial_conv2d_nhwc
(
cfg
,
s
,
data_pad
,
data_vec
,
kernel
_vec
,
kernel
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
]
)
conv_out
,
output
,
outs
[
0
],
unipolar
)
scheduled_ops
.
append
(
op
)
scheduled_ops
.
append
(
op
)
traverse
(
outs
[
0
]
.
op
)
traverse
(
outs
[
0
]
.
op
)
...
...
topi/python/topi/nn/bitserial_conv2d.py
View file @
17351875
# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument
"""Bitserial Conv2D operators"""
"""Bitserial Conv2D operators"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
collections
import
namedtuple
import
numpy
as
np
import
numpy
as
np
import
tvm
import
tvm
from
tvm
import
autotvm
from
topi.transform
import
concatenate
from
topi.transform
import
concatenate
from
.pad
import
pad
from
.pad
import
pad
from
.util
import
get_pad_tuple
from
.util
import
get_pad_tuple
from
..util
import
get_const_tuple
,
get_const_int
from
..util
import
get_const_tuple
,
get_const_int
# workload description of conv2d
Workload
=
namedtuple
(
'Workload'
,
[
'in_dtype'
,
'out_dtype'
,
'height'
,
'width'
,
'in_filter'
,
'out_filter'
,
'hkernel'
,
'wkernel'
,
'hpad'
,
'wpad'
,
'hstride'
,
'wstride'
])
SpatialPackNCHW
=
namedtuple
(
'SpatialPack'
,
[
'vh'
,
'vw'
,
'vc'
,
'ba'
,
'bc'
])
SpatialPackNHWC
=
namedtuple
(
'SpatialPack'
,
[
'vh'
,
'vw'
,
'vc'
,
'ba'
,
'bc'
])
_WORKLOADS
=
[
# workloads of resnet18 on imagenet
# input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride
Workload
(
'uint32'
,
'int32'
,
56
,
56
,
64
,
64
,
3
,
3
,
1
,
1
,
1
,
1
),
Workload
(
'uint32'
,
'int32'
,
56
,
56
,
64
,
64
,
1
,
1
,
0
,
0
,
1
,
1
),
Workload
(
'uint32'
,
'int32'
,
56
,
56
,
64
,
128
,
3
,
3
,
1
,
1
,
2
,
2
),
Workload
(
'uint32'
,
'int32'
,
56
,
56
,
64
,
128
,
1
,
1
,
0
,
0
,
2
,
2
),
Workload
(
'uint32'
,
'int32'
,
28
,
28
,
128
,
128
,
3
,
3
,
1
,
1
,
1
,
1
),
Workload
(
'uint32'
,
'int32'
,
28
,
28
,
128
,
256
,
3
,
3
,
1
,
1
,
2
,
2
),
Workload
(
'uint32'
,
'int32'
,
28
,
28
,
128
,
256
,
1
,
1
,
0
,
0
,
2
,
2
),
Workload
(
'uint32'
,
'int32'
,
14
,
14
,
256
,
256
,
3
,
3
,
1
,
1
,
1
,
1
),
Workload
(
'uint32'
,
'int32'
,
14
,
14
,
256
,
512
,
3
,
3
,
1
,
1
,
2
,
2
),
Workload
(
'uint32'
,
'int32'
,
14
,
14
,
256
,
512
,
1
,
1
,
0
,
0
,
2
,
2
),
Workload
(
'uint32'
,
'int32'
,
7
,
7
,
512
,
512
,
3
,
3
,
1
,
1
,
1
,
1
),
# workload of alexnet on cifar10
Workload
(
'int32'
,
'int32'
,
27
,
27
,
96
,
192
,
5
,
5
,
2
,
2
,
1
,
1
),
Workload
(
'int32'
,
'int32'
,
13
,
13
,
192
,
384
,
3
,
3
,
1
,
1
,
1
,
1
),
Workload
(
'int32'
,
'int32'
,
13
,
13
,
384
,
384
,
3
,
3
,
1
,
1
,
1
,
1
),
Workload
(
'int32'
,
'int32'
,
13
,
13
,
384
,
256
,
3
,
3
,
1
,
1
,
1
,
1
),
]
@tvm.target.generic_func
@tvm.target.generic_func
def
bitserial_conv2d
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
def
bitserial_conv2d
_nchw
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
layout
=
'NCHW'
,
pack_dtype
=
'uint32'
,
out_dtype
=
'int32'
,
dorefa
=
True
):
pack_dtype
=
'uint32'
,
out_dtype
=
'int16'
,
unipolar
=
True
):
"""Bitserial Conv2D operator.
"""Bitserial Conv2D operator.
Parameters
Parameters
----------
----------
input : tvm.Tensor
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width] or
4-D with shape [batch, in_channel, in_height, in_width]
[batch, in_height, in_width, in_channel]
filter : tvm.Tensor
filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] or
4-D with shape [num_filter, in_channel, filter_height, filter_width]
[filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding : int or a list/tuple of two or four ints
padding size, or [pad_height, pad_width]
padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
layout : str
layout of data
activation_bits: int
activation_bits: int
number of bits used for activations/input elements
number of bits used for activations/input elements
...
@@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
...
@@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
pack_dtype: str
pack_dtype: str
bit packing type
bit packing type
dorefa
: bool
unipolar
: bool
preform the bitserial dot-product using 2 popcounts (required for DoReFa-Net)
if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
Returns
Returns
-------
-------
output : tvm.Tensor
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] or
4-D with shape [batch, out_channel, out_height, out_width]
[batch, out_height, out_width, out_channel]
"""
"""
# search platform specific declaration first
assert
isinstance
(
stride
,
int
)
or
len
(
stride
)
==
2
# default declaration
Input_q
=
bitpack
(
data
,
activation_bits
,
pack_axis
=
1
,
bit_axis
=
2
,
pack_type
=
pack_dtype
)
if
layout
==
'NCHW'
:
Filter_q
=
bitpack
(
filter
,
weight_bits
,
pack_axis
=
1
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
return
spatial_pack_nchw
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
batch
,
in_channel
,
activation_bits
,
in_height
,
in_width
=
Input_q
.
shape
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
num_filter
,
channel
,
kernel_h
,
kernel_w
,
weight_bits
=
Filter_q
.
shape
if
layout
==
'NHWC'
:
return
spatial_pack_nhwc
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
if
isinstance
(
padding
,
int
)
or
(
isinstance
(
padding
,
(
tuple
,
list
))
and
len
(
padding
)
==
2
):
pack_dtype
=
pack_dtype
,
out_dtype
=
out_dtype
,
dorefa
=
dorefa
)
TPAD
,
LPAD
,
DPAD
,
RPAD
=
get_pad_tuple
(
padding
,
kernel
)
raise
ValueError
(
"not support this layout {} yet"
.
format
(
layout
))
def
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
):
""" Get the workload structure. """
assert
layout
in
(
"NCHW"
,
"NHWC"
),
\
"Only support layouts NCHW and NHWC"
if
layout
==
"NCHW"
:
_
,
CI
,
IH
,
IW
=
[
x
.
value
for
x
in
data
.
shape
]
CO
,
_
,
KH
,
KW
=
[
x
.
value
for
x
in
kernel
.
shape
]
else
:
# NHWC
IH
,
IW
=
data
.
shape
[
1
]
.
value
,
data
.
shape
[
2
]
.
value
KH
,
KW
,
CI
,
CO
=
[
x
for
x
in
get_const_tuple
(
kernel
.
shape
)]
HPAD
,
WPAD
,
_
,
_
=
get_pad_tuple
(
padding
,
kernel
)
if
isinstance
(
stride
,
(
tuple
,
list
)):
HSTR
,
WSTR
=
stride
else
:
else
:
HSTR
,
WSTR
=
stride
,
stride
TPAD
,
LPAD
,
DPAD
,
RPAD
=
padding
pad_before
=
[
0
,
0
,
0
,
TPAD
,
LPAD
]
pad_after
=
[
0
,
0
,
0
,
DPAD
,
RPAD
]
PadInput_q
=
pad
(
Input_q
,
pad_before
,
pad_after
,
name
=
"pad_temp"
)
# compute the output shape
if
isinstance
(
stride
,
int
):
stride_h
=
stride_w
=
stride
else
:
stride_h
,
stride_w
=
stride
out_channel
=
num_filter
out_height
=
(
in_height
-
kernel_h
+
TPAD
+
DPAD
)
//
stride_h
+
1
out_width
=
(
in_width
-
kernel_w
+
LPAD
+
RPAD
)
//
stride_w
+
1
rc
=
tvm
.
reduce_axis
((
0
,
in_channel
),
name
=
'rc'
)
ry
=
tvm
.
reduce_axis
((
0
,
kernel_h
),
name
=
'ry'
)
rx
=
tvm
.
reduce_axis
((
0
,
kernel_w
),
name
=
'rx'
)
b1
=
tvm
.
reduce_axis
((
0
,
activation_bits
),
name
=
'b1'
)
b2
=
tvm
.
reduce_axis
((
0
,
weight_bits
),
name
=
'b2'
)
if
unipolar
:
def
_conv
(
nn
,
ff
,
yy
,
xx
):
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
return
tvm
.
sum
(
((
tvm
.
popcount
(
PadInput_q
[
nn
,
rc
,
b1
,
yy
*
stride_h
+
ry
,
xx
*
stride_w
+
rx
]
&
Filter_q
[
ff
,
rc
,
ry
,
rx
,
b2
])
-
tvm
.
popcount
(
PadInput_q
[
nn
,
rc
,
b1
,
yy
*
stride_h
+
ry
,
xx
*
stride_w
+
rx
]
&
~
Filter_q
[
ff
,
rc
,
ry
,
rx
,
b2
]))
<<
(
b1b2
))
.
astype
(
out_dtype
),
axis
=
[
rc
,
ry
,
rx
,
b2
,
b1
])
.
astype
(
out_dtype
)
else
:
def
_conv
(
nn
,
ff
,
yy
,
xx
):
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
return
tvm
.
sum
((
tvm
.
popcount
(
PadInput_q
[
nn
,
rc
,
b1
,
yy
*
stride_h
+
ry
,
xx
*
stride_w
+
rx
]
&
Filter_q
[
ff
,
rc
,
ry
,
rx
,
b2
])
<<
(
b1b2
))
.
astype
(
out_dtype
),
axis
=
[
rc
,
ry
,
rx
,
b2
,
b1
])
.
astype
(
out_dtype
)
return
Workload
(
data
.
dtype
,
out_dtype
,
IH
,
IW
,
CI
,
CO
,
KH
,
KW
,
HPAD
,
WPAD
,
HSTR
,
WSTR
)
return
tvm
.
compute
((
batch
,
out_channel
,
out_height
,
out_width
),
_conv
,
name
=
"Conv2dOutput"
,
tag
=
"bitserial_conv2d_nchw"
)
@tvm.target.generic_func
@tvm.target.generic_func
def
_get_schedule
(
wkl
,
layout
):
def
bitserial_conv2d_nhwc
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
# pylint: disable=unreachable
pack_dtype
=
'uint32'
,
out_dtype
=
'int16'
,
unipolar
=
True
):
""" Get the platform specific schedule. """
"""Bitserial Conv2D operator.
target
=
tvm
.
target
.
current_target
()
raise
RuntimeError
(
Parameters
"No schedule for current target:{}"
.
format
(
target
))
----------
# This return has no use, merely to supress pylint warning
input : tvm.Tensor
return
wkl
4-D with shape [batch, in_height, in_width, in_channel]
def
spatial_pack_nchw
(
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
filter : tvm.Tensor
pack_dtype
,
out_dtype
,
dorefa
=
False
):
4-D with shape [filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two or four ints
padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right]
activation_bits: int
number of bits used for activations/input elements
weight_bits: int
number of bits used for weight elements
out_dtype: str
return type of convolution
pack_dtype: str
bit packing type
unipolar: bool
if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
assert
isinstance
(
stride
,
int
)
or
len
(
stride
)
==
2
Input_q
=
bitpack
(
data
,
activation_bits
,
pack_axis
=
3
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
if
len
(
kernel
.
shape
)
==
4
:
Filter_q
=
bitpack
(
kernel
,
weight_bits
,
pack_axis
=
2
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
kernel_h
,
kernel_w
,
_
,
num_filter
,
_
=
get_const_tuple
(
Filter_q
.
shape
)
else
:
Filter_q
=
kernel
kernel_h
,
kernel_w
,
_
,
_
,
num_filter
=
get_const_tuple
(
Filter_q
.
shape
)
batch
,
in_height
,
in_width
,
in_channel_q
,
_
=
get_const_tuple
(
Input_q
.
shape
)
if
isinstance
(
padding
,
int
)
or
(
isinstance
(
padding
,
(
tuple
,
list
))
and
len
(
padding
)
==
2
):
TPAD
,
LPAD
,
DPAD
,
RPAD
=
get_pad_tuple
(
padding
,
kernel
)
else
:
TPAD
,
LPAD
,
DPAD
,
RPAD
=
padding
pad_before
=
[
0
,
TPAD
,
LPAD
,
0
,
0
]
pad_after
=
[
0
,
DPAD
,
RPAD
,
0
,
0
]
# compute the output shape
if
isinstance
(
stride
,
int
):
stride_h
=
stride_w
=
stride
else
:
stride_h
,
stride_w
=
stride
out_channel
=
num_filter
out_height
=
(
in_height
-
kernel_h
+
TPAD
+
DPAD
)
//
stride_h
+
1
out_width
=
(
in_width
-
kernel_w
+
LPAD
+
RPAD
)
//
stride_w
+
1
PadInput_q
=
pad
(
Input_q
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
rc
=
tvm
.
reduce_axis
((
0
,
in_channel_q
),
name
=
'rc'
)
ry
=
tvm
.
reduce_axis
((
0
,
kernel_h
),
name
=
'ry'
)
rx
=
tvm
.
reduce_axis
((
0
,
kernel_w
),
name
=
'rx'
)
b1
=
tvm
.
reduce_axis
((
0
,
activation_bits
),
name
=
'b1'
)
b2
=
tvm
.
reduce_axis
((
0
,
weight_bits
),
name
=
'b2'
)
if
unipolar
:
def
_conv
(
nn
,
yy
,
xx
,
ff
):
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
return
tvm
.
sum
(
((
tvm
.
popcount
(
PadInput_q
[
nn
,
yy
*
stride_h
+
ry
,
xx
*
stride_w
+
rx
,
rc
,
b1
]
&
Filter_q
[
ry
,
rx
,
rc
,
ff
,
b2
])
-
tvm
.
popcount
(
PadInput_q
[
nn
,
yy
*
stride_h
+
ry
,
xx
*
stride_w
+
rx
,
rc
,
b1
]
&
~
Filter_q
[
ry
,
rx
,
rc
,
ff
,
b2
]))
<<
b1b2
)
.
astype
(
out_dtype
),
axis
=
[
rc
,
ry
,
rx
,
b2
,
b1
])
else
:
def
_conv
(
nn
,
yy
,
xx
,
ff
):
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
return
tvm
.
sum
((
tvm
.
popcount
(
PadInput_q
[
nn
,
yy
*
stride_h
+
ry
,
xx
*
stride_w
+
rx
,
rc
,
b1
]
&
Filter_q
[
ry
,
rx
,
rc
,
ff
,
b2
])
<<
b1b2
)
.
astype
(
out_dtype
),
axis
=
[
rc
,
ry
,
rx
,
b2
,
b1
])
conv
=
tvm
.
compute
((
batch
,
out_height
,
out_width
,
out_channel
),
_conv
,
name
=
"Conv2dOutput"
,
tag
=
"bitserial_conv2d_nhwc"
)
return
conv
@autotvm.register_topi_compute
(
bitserial_conv2d_nchw
,
[
'cpu'
,
'arm_cpu'
],
'direct'
)
def
spatial_pack_nchw
(
cfg
,
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
pack_dtype
=
'uint32'
,
out_dtype
=
'int16'
,
unipolar
=
True
):
""" Compute convolution with pack on spatial axes. """
""" Compute convolution with pack on spatial axes. """
assert
data
.
shape
[
0
]
.
value
==
1
,
"spatial pack convolution only support batch size=1"
assert
data
.
shape
[
0
]
.
value
==
1
,
"spatial pack convolution only support batch size=1"
data_q
=
bitpack
(
data
,
in_bits
,
pack_axis
=
1
,
bit_axis
=
0
,
pack_type
=
pack_dtype
)
data_q
=
bitpack
(
data
,
in_bits
,
pack_axis
=
1
,
bit_axis
=
0
,
pack_type
=
pack_dtype
)
# Check if kernel is already bitpacked
if
len
(
kernel
.
shape
)
==
4
:
kernel_q
=
bitpack
(
kernel
,
weight_bits
,
pack_axis
=
1
,
bit_axis
=
0
,
pack_type
=
pack_dtype
)
kernel_q
=
bitpack
(
kernel
,
weight_bits
,
pack_axis
=
1
,
bit_axis
=
0
,
pack_type
=
pack_dtype
)
IB
,
_
,
CI
,
H
,
W
=
data_q
.
shape
KB
,
CO
,
_
,
KH
,
KW
=
get_const_tuple
(
kernel_q
.
shape
)
KB
,
CO
,
_
,
KH
,
KW
=
kernel_q
.
shape
else
:
HPAD
,
WPAD
,
_
,
_
=
get_pad_tuple
(
padding
,
kernel
)
kernel_vec
=
kernel
OCO
,
_
,
KH
,
KW
,
KB
,
VC
=
get_const_tuple
(
kernel_vec
.
shape
)
CO
=
OCO
*
VC
IB
,
N
,
CI
,
H
,
W
=
get_const_tuple
(
data_q
.
shape
)
KB
,
CO
,
_
,
KH
,
KW
=
get_const_tuple
(
kernel_q
.
shape
)
if
isinstance
(
padding
,
int
)
or
(
isinstance
(
padding
,
(
tuple
,
list
))
and
len
(
padding
)
==
2
):
TPAD
,
LPAD
,
DPAD
,
RPAD
=
get_pad_tuple
(
padding
,
kernel
)
else
:
TPAD
,
LPAD
,
DPAD
,
RPAD
=
padding
pad_before
=
[
0
,
0
,
0
,
TPAD
,
LPAD
]
pad_after
=
[
0
,
0
,
0
,
DPAD
,
RPAD
]
if
isinstance
(
stride
,
(
tuple
,
list
)):
if
isinstance
(
stride
,
(
tuple
,
list
)):
HSTR
,
WSTR
=
stride
HSTR
,
WSTR
=
stride
...
@@ -142,36 +225,48 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -142,36 +225,48 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
HSTR
,
WSTR
=
stride
,
stride
HSTR
,
WSTR
=
stride
,
stride
HCAT
,
WCAT
=
KH
-
1
,
KW
-
1
HCAT
,
WCAT
=
KH
-
1
,
KW
-
1
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
"NCHW"
)
TH
=
H
+
TPAD
+
DPAD
sch
=
_get_schedule
(
wkl
,
"NCHW"
)
TW
=
W
+
LPAD
+
RPAD
VH
=
sch
.
vh
OH
=
(
H
+
TPAD
+
DPAD
-
KH
)
//
HSTR
+
1
VW
=
sch
.
vw
OW
=
(
W
+
LPAD
+
RPAD
-
KW
)
//
WSTR
+
1
VC
=
sch
.
vc
# ==================== define configuration space ====================
TH
=
H
+
2
*
HPAD
n
,
co
,
oh
,
ow
=
cfg
.
axis
(
N
),
cfg
.
axis
(
CO
),
cfg
.
axis
(
OH
),
cfg
.
axis
(
OW
)
TW
=
W
+
2
*
WPAD
ci
,
kh
,
kw
=
cfg
.
reduce_axis
(
CI
),
cfg
.
reduce_axis
(
KH
),
cfg
.
reduce_axis
(
KW
)
OH
=
(
H
+
2
*
HPAD
-
KH
)
//
HSTR
+
1
ib
,
kb
=
cfg
.
reduce_axis
(
in_bits
),
cfg
.
reduce_axis
(
weight_bits
)
OW
=
(
W
+
2
*
WPAD
-
KW
)
//
WSTR
+
1
co
,
vc
=
cfg
.
define_split
(
'tile_co'
,
co
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
max
(
x
.
size
[
1
:])
<=
16
)
oh
,
vh
=
cfg
.
define_split
(
'tile_oh'
,
oh
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
max
(
x
.
size
[
1
:])
<=
16
)
ow
,
vw
=
cfg
.
define_split
(
'tile_ow'
,
ow
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
max
(
x
.
size
[
1
:])
<=
16
)
cfg
.
define_annotate
(
'ann_reduce'
,
[
ib
,
kb
,
kh
,
kw
],
policy
=
'try_unroll'
)
re_axes
=
cfg
.
define_reorder
(
"reorder_0"
,
[
n
,
co
,
oh
,
ow
,
vc
,
vh
,
vw
,
kh
,
kw
,
kb
,
ib
,
ci
],
policy
=
'interval_all'
,
interval
=
(
6
,
11
))
cfg
.
add_flop
(
2
*
N
*
OH
*
OW
*
CO
*
CI
*
8
*
KH
*
KW
)
# these are actually binary ops
# ====================
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
dshape
=
(
IB
,
1
,
CI
,
H
,
W
)
dpshape
=
(
IB
,
1
,
CI
,
TH
,
TW
)
dvshape
=
(
1
,
TH
//
(
VH
*
HSTR
),
TW
//
(
VW
*
WSTR
),
CI
,
VH
*
HSTR
+
HCAT
,
VW
*
WSTR
+
WCAT
,
IB
)
dvshape
=
(
1
,
TH
//
(
VH
*
HSTR
),
TW
//
(
VW
*
WSTR
),
CI
,
VH
*
HSTR
+
HCAT
,
VW
*
WSTR
+
WCAT
,
IB
)
kshape
=
(
KB
,
CO
,
CI
,
KH
,
KW
)
kvshape
=
(
CO
//
VC
,
CI
,
KH
,
KW
,
KB
,
VC
)
kvshape
=
(
CO
//
VC
,
CI
,
KH
,
KW
,
KB
,
VC
)
ovshape
=
(
1
,
CO
//
VC
,
OH
//
VH
,
OW
//
VW
,
VH
,
VW
,
VC
)
ovshape
=
(
1
,
CO
//
VC
,
OH
//
VH
,
OW
//
VW
,
VH
,
VW
,
VC
)
oshape
=
(
1
,
CO
,
OH
,
OW
)
oshape
=
(
1
,
CO
,
OH
,
OW
)
DOPAD
=
(
HPAD
!=
0
and
WPAD
!=
0
)
if
(
TPAD
!=
0
and
RPAD
!=
0
):
if
DOPAD
:
data_pad
=
pad
(
data_q
,
(
0
,
0
,
0
,
TPAD
,
LPAD
),
(
0
,
0
,
0
,
DPAD
,
RPAD
),
name
=
"data_pad"
)
data_pad
=
pad
(
data_q
,
(
0
,
0
,
0
,
HPAD
,
WPAD
),
name
=
"data_pad"
)
else
:
else
:
data_pad
=
data_q
data_pad
=
data_q
data_vec
=
tvm
.
compute
(
dvshape
,
lambda
n
,
h
,
w
,
ci
,
vh
,
vw
,
b
:
\
data_vec
=
tvm
.
compute
(
dvshape
,
lambda
n
,
h
,
w
,
ci
,
vh
,
vw
,
b
:
\
data_pad
[
b
][
n
][
ci
][
h
*
VH
*
HSTR
+
vh
][
w
*
VW
*
WSTR
+
vw
],
name
=
'data_vec'
)
data_pad
[
b
][
n
][
ci
][
h
*
VH
*
HSTR
+
vh
][
w
*
VW
*
WSTR
+
vw
],
name
=
'data_vec'
)
if
len
(
kernel
.
shape
)
==
4
:
kernel_vec
=
tvm
.
compute
(
kvshape
,
lambda
co
,
ci
,
dh
,
dw
,
b
,
vc
:
\
kernel_vec
=
tvm
.
compute
(
kvshape
,
lambda
co
,
ci
,
dh
,
dw
,
b
,
vc
:
\
kernel_q
[
b
][
co
*
VC
+
vc
][
ci
][
dh
][
dw
],
name
=
'kernel_vec'
)
kernel_q
[
b
][
co
*
VC
+
vc
][
ci
][
dh
][
dw
],
name
=
'kernel_vec'
)
...
@@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
def
_conv
(
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
):
def
_conv
(
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
):
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
if
dorefa
:
if
unipolar
:
return
tvm
.
sum
((
tvm
.
popcount
(
return
tvm
.
sum
((
tvm
.
popcount
(
data_vec
[
n
,
h
,
w
,
ci
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
b1
]
.
astype
(
out_dtype
)
&
data_vec
[
n
,
h
,
w
,
ci
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
b1
]
.
astype
(
out_dtype
)
&
kernel_vec
[
co
,
ci
,
dh
,
dw
,
b2
,
vc
]
.
astype
(
out_dtype
))
-
kernel_vec
[
co
,
ci
,
dh
,
dw
,
b2
,
vc
]
.
astype
(
out_dtype
))
-
...
@@ -203,15 +298,28 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -203,15 +298,28 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits,
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
name
=
'conv_vec'
,
tag
=
'spatial_bitserial_conv_nchw'
)
name
=
'conv_vec'
,
tag
=
'spatial_bitserial_conv_nchw'
)
def
spatial_pack_nhwc
(
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
@autotvm.register_topi_compute
(
bitserial_conv2d_nhwc
,
'cpu'
,
'direct'
)
pack_dtype
,
out_dtype
,
dorefa
=
False
):
def
spatial_pack_nhwc
(
cfg
,
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
pack_dtype
=
'uint32'
,
out_dtype
=
'int16'
,
unipolar
=
True
):
""" Compute convolution with pack on spatial axes. """
""" Compute convolution with pack on spatial axes. """
assert
data
.
shape
[
0
]
.
value
==
1
,
"spatial pack convolution only support batch size=1"
assert
data
.
shape
[
0
]
.
value
==
1
,
"spatial pack convolution only support batch size=1"
data_q
=
bitpack
(
data
,
in_bits
,
pack_axis
=
3
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
data_q
=
bitpack
(
data
,
in_bits
,
pack_axis
=
3
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
pack_kernel
=
len
(
kernel
.
shape
)
==
4
if
pack_kernel
:
kernel_q
=
bitpack
(
kernel
,
weight_bits
,
pack_axis
=
2
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
kernel_q
=
bitpack
(
kernel
,
weight_bits
,
pack_axis
=
2
,
bit_axis
=
4
,
pack_type
=
pack_dtype
)
_
,
H
,
W
,
CI
,
IB
=
data_q
.
shape
else
:
KH
,
KW
,
_
,
CO
,
KB
=
kernel_q
.
shape
kernel_q
=
kernel
HPAD
,
WPAD
,
_
,
_
=
get_pad_tuple
(
padding
,
kernel
)
KH
,
KW
,
_
,
CO
,
KB
=
get_const_tuple
(
kernel_q
.
shape
)
N
,
H
,
W
,
CI
,
IB
=
get_const_tuple
(
data_q
.
shape
)
if
isinstance
(
padding
,
int
)
or
(
isinstance
(
padding
,
(
tuple
,
list
))
and
len
(
padding
)
==
2
):
TPAD
,
LPAD
,
DPAD
,
RPAD
=
get_pad_tuple
(
padding
,
kernel
)
else
:
TPAD
,
LPAD
,
DPAD
,
RPAD
=
padding
pad_before
=
[
0
,
TPAD
,
LPAD
,
0
,
0
]
pad_after
=
[
0
,
DPAD
,
RPAD
,
0
,
0
]
if
isinstance
(
stride
,
(
tuple
,
list
)):
if
isinstance
(
stride
,
(
tuple
,
list
)):
HSTR
,
WSTR
=
stride
HSTR
,
WSTR
=
stride
...
@@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
HSTR
,
WSTR
=
stride
,
stride
HSTR
,
WSTR
=
stride
,
stride
HCAT
,
WCAT
=
KH
-
1
,
KW
-
1
HCAT
,
WCAT
=
KH
-
1
,
KW
-
1
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
"NHWC"
)
PAD_H
=
H
+
(
TPAD
+
DPAD
)
sch
=
_get_schedule
(
wkl
,
"NHWC"
)
PAD_W
=
W
+
(
LPAD
+
RPAD
)
VH
=
sch
.
vh
OH
=
(
PAD_H
-
KH
)
//
HSTR
+
1
VW
=
sch
.
vw
OW
=
(
PAD_W
-
KW
)
//
WSTR
+
1
VC
=
sch
.
vc
oshape
=
(
1
,
OH
,
OW
,
CO
)
PAD_H
=
H
+
2
*
HPAD
# ==================== define configuration space ====================
PAD_W
=
W
+
2
*
WPAD
n
,
oh
,
ow
,
co
=
cfg
.
axis
(
N
),
cfg
.
axis
(
OH
),
cfg
.
axis
(
OW
),
cfg
.
axis
(
CO
)
OH
=
(
H
+
2
*
HPAD
-
KH
)
//
HSTR
+
1
ci
,
kh
,
kw
=
cfg
.
reduce_axis
(
CI
),
cfg
.
reduce_axis
(
KH
),
cfg
.
reduce_axis
(
KW
)
OW
=
(
W
+
2
*
WPAD
-
KW
)
//
WSTR
+
1
ib
,
kb
=
cfg
.
reduce_axis
(
in_bits
),
cfg
.
reduce_axis
(
weight_bits
)
co
,
vc
=
cfg
.
define_split
(
'tile_co'
,
co
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
max
(
x
.
size
[
1
:])
<=
16
)
oh
,
vh
=
cfg
.
define_split
(
'tile_oh'
,
oh
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
max
(
x
.
size
[
1
:])
<=
16
)
ow
,
vw
=
cfg
.
define_split
(
'tile_ow'
,
ow
,
policy
=
'all'
,
num_outputs
=
2
,
filter
=
lambda
x
:
max
(
x
.
size
[
1
:])
<=
16
)
cfg
.
define_annotate
(
'ann_reduce'
,
[
ib
,
kb
,
kh
,
kw
],
policy
=
'try_unroll'
)
re_axes
=
cfg
.
define_reorder
(
"reorder_0"
,
[
n
,
oh
,
ow
,
co
,
vh
,
vw
,
kh
,
kw
,
kb
,
ib
,
vc
,
ci
],
policy
=
'interval_all'
,
interval
=
(
3
,
7
))
cfg
.
add_flop
(
2
*
N
*
OH
*
OW
*
CO
*
CI
*
8
*
KH
*
KW
)
# these are actually binary ops
# ====================
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
dvshape
=
(
1
,
PAD_H
//
(
VH
*
HSTR
),
PAD_W
//
(
VW
*
WSTR
),
VH
*
HSTR
+
HCAT
,
VW
*
WSTR
+
WCAT
,
CI
,
IB
)
dvshape
=
(
1
,
PAD_H
//
(
VH
*
HSTR
),
PAD_W
//
(
VW
*
WSTR
),
VH
*
HSTR
+
HCAT
,
VW
*
WSTR
+
WCAT
,
CI
,
IB
)
kvshape
=
(
CO
,
KH
,
KW
,
CI
,
VC
,
KB
)
kvshape
=
(
CO
,
KH
,
KW
,
CI
,
VC
,
KB
)
ovshape
=
(
1
,
OH
,
OW
,
CO
,
VH
,
VW
,
VC
)
ovshape
=
(
1
,
OH
,
OW
,
CO
,
VH
,
VW
,
VC
)
oshape
=
(
1
,
OH
,
OW
,
CO
)
oshape
=
(
1
,
OH
,
OW
,
CO
)
if
(
HPAD
!=
0
and
W
PAD
!=
0
):
if
(
DPAD
!=
0
and
R
PAD
!=
0
):
data_pad
=
pad
(
data_q
,
(
0
,
HPAD
,
W
PAD
,
0
,
0
),
name
=
"data_pad"
)
data_pad
=
pad
(
data_q
,
(
0
,
TPAD
,
LPAD
,
0
,
0
),
(
0
,
DPAD
,
R
PAD
,
0
,
0
),
name
=
"data_pad"
)
else
:
else
:
data_pad
=
data_q
data_pad
=
data_q
...
@@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
def
_conv
(
n
,
h
,
w
,
co
,
vh
,
vw
,
vc
):
def
_conv
(
n
,
h
,
w
,
co
,
vh
,
vw
,
vc
):
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
b1b2
=
(
b1
+
b2
)
.
astype
(
out_dtype
)
if
dorefa
:
if
unipolar
:
return
tvm
.
sum
(
return
tvm
.
sum
(
(
tvm
.
popcount
(
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ci
,
b1
]
.
astype
(
out_dtype
)
&
(
(
tvm
.
popcount
(
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ci
,
b1
]
&
kernel_vec
[
co
,
dh
,
dw
,
ci
,
vc
,
b2
]
.
astype
(
out_dtype
)
)
-
kernel_vec
[
co
,
dh
,
dw
,
ci
,
vc
,
b2
])
.
astype
(
out_dtype
)
-
tvm
.
popcount
(
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ci
,
b1
]
.
astype
(
out_dtype
)
&
tvm
.
popcount
(
data_vec
[
n
,
h
,
w
,
vh
*
HSTR
+
dh
,
vw
*
WSTR
+
dw
,
ci
,
b1
]
&
~
kernel_vec
[
co
,
dh
,
dw
,
ci
,
vc
,
b2
])
.
astype
(
out_dtype
))
<<
b1b2
,
~
kernel_vec
[
co
,
dh
,
dw
,
ci
,
vc
,
b2
])
.
astype
(
out_dtype
))
<<
b1b2
)
,
axis
=
[
dh
,
dw
,
ci
,
b1
,
b2
])
axis
=
[
dh
,
dw
,
ci
,
b1
,
b2
])
return
tvm
.
sum
(
tvm
.
popcount
(
return
tvm
.
sum
(
tvm
.
popcount
(
...
@@ -273,6 +398,7 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -273,6 +398,7 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits,
conv
[
n
][
h
//
VH
][
w
//
VW
][
co
//
VC
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
][
h
//
VH
][
w
//
VW
][
co
//
VC
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
name
=
'output_unpack'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
name
=
'output_unpack'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
def
bitpack
(
data
,
bits
,
pack_axis
,
bit_axis
,
pack_type
,
name
=
"QuantizeInput"
):
def
bitpack
(
data
,
bits
,
pack_axis
,
bit_axis
,
pack_type
,
name
=
"QuantizeInput"
):
"""Packs data into format necessary for bitserial computation
"""Packs data into format necessary for bitserial computation
pack_axis : int
pack_axis : int
...
@@ -334,8 +460,3 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
...
@@ -334,8 +460,3 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
if
bits
>
1
:
if
bits
>
1
:
return
concatenate
(
output_tuple
,
axis
=
bit_axis
)
return
concatenate
(
output_tuple
,
axis
=
bit_axis
)
return
output_tuple
return
output_tuple
_SCH_TO_DECL_FUNC_QUANT
=
{
SpatialPackNCHW
:
spatial_pack_nchw
,
SpatialPackNHWC
:
spatial_pack_nhwc
,
}
topi/python/topi/x86/bitserial_conv2d.py
View file @
17351875
# pylint: disable=invalid-name,unused-variable,invalid-name
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Bitserial conv2d schedule on x86"""
"""Bitserial conv2d schedule on x86"""
import
tvm
import
tvm
from
tvm
import
autotvm
from
topi.util
import
get_const_int
from
topi.util
import
get_const_int
from
..
import
generic
,
tag
from
..
import
generic
,
tag
from
..nn.bitserial_conv2d
import
bitserial_conv2d
,
_get_schedule
,
_get_workload
from
..nn.bitserial_conv2d
import
SpatialPackNCHW
,
SpatialPackNHWC
@autotvm.register_topi_schedule
(
generic
.
nn
.
schedule_bitserial_conv2d_nchw
,
[
'cpu'
],
'direct'
)
from
..nn.bitserial_conv2d
import
_WORKLOADS
,
_SCH_TO_DECL_FUNC_QUANT
@autotvm.register_topi_schedule
(
generic
.
nn
.
schedule_bitserial_conv2d_nhwc
,
[
'cpu'
],
'direct'
)
def
schedule_bitserial_conv2d
(
cfg
,
outs
):
_QUANTIZED_SCHEDULES_NCHW
=
[
# resnet
SpatialPackNCHW
(
2
,
2
,
8
,
1
,
1
),
SpatialPackNCHW
(
1
,
4
,
8
,
4
,
1
),
SpatialPackNCHW
(
1
,
4
,
8
,
1
,
16
),
SpatialPackNCHW
(
1
,
4
,
8
,
4
,
8
),
SpatialPackNCHW
(
1
,
7
,
8
,
3
,
8
),
SpatialPackNCHW
(
1
,
2
,
8
,
1
,
8
),
SpatialPackNCHW
(
2
,
1
,
8
,
1
,
4
),
SpatialPackNCHW
(
1
,
7
,
8
,
1
,
1
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
16
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
8
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
16
),
SpatialPackNCHW
(
3
,
3
,
16
,
3
,
16
),
SpatialPackNCHW
(
1
,
1
,
16
,
2
,
16
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
16
),
SpatialPackNCHW
(
1
,
1
,
8
,
1
,
16
),
]
_QUANTIZED_SCHEDULES_NHWC
=
[
# resnet
SpatialPackNHWC
(
2
,
2
,
8
,
1
,
1
),
SpatialPackNHWC
(
1
,
4
,
8
,
4
,
1
),
SpatialPackNHWC
(
1
,
4
,
8
,
1
,
16
),
SpatialPackNHWC
(
1
,
4
,
8
,
4
,
8
),
SpatialPackNHWC
(
1
,
7
,
8
,
3
,
8
),
SpatialPackNHWC
(
1
,
2
,
8
,
1
,
8
),
SpatialPackNHWC
(
2
,
1
,
8
,
1
,
4
),
SpatialPackNHWC
(
1
,
7
,
8
,
1
,
1
),
SpatialPackNHWC
(
1
,
1
,
8
,
1
,
16
),
SpatialPackNHWC
(
1
,
1
,
8
,
1
,
8
),
SpatialPackNHWC
(
1
,
1
,
8
,
1
,
16
),
]
@_get_schedule.register
(
"cpu"
)
def
_get_schedule_bitserial_conv2d
(
wkl
,
layout
):
if
wkl
not
in
_WORKLOADS
:
raise
ValueError
(
"no schedule for such workload: {}"
.
format
(
wkl
))
idx
=
_WORKLOADS
.
index
(
wkl
)
if
layout
==
"NCHW"
:
sch
=
_QUANTIZED_SCHEDULES_NCHW
[
idx
]
elif
layout
==
"NHWC"
:
sch
=
_QUANTIZED_SCHEDULES_NHWC
[
idx
]
return
sch
@bitserial_conv2d.register
(
"cpu"
)
def
_declaration_bitserial_conv2d
(
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
layout
=
'NCHW'
,
pack_dtype
=
None
,
out_dtype
=
None
,
dorefa
=
False
):
if
out_dtype
is
None
:
out_dtype
=
data
.
dtype
assert
data
.
shape
[
0
]
.
value
==
1
,
"only support batch size=1 convolution on rasp"
assert
layout
in
(
"NCHW"
,
"NHWC"
),
"only support layouts NCHW and NHWC"
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
out_dtype
,
layout
)
sch
=
_get_schedule
(
wkl
,
layout
)
return
_SCH_TO_DECL_FUNC_QUANT
[
type
(
sch
)](
data
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
pack_dtype
,
out_dtype
,
dorefa
)
@generic.schedule_bitserial_conv2d_nchw.register
([
"cpu"
])
@generic.schedule_bitserial_conv2d_nhwc.register
([
"cpu"
])
def
schedule_bitserial_conv2d
(
outs
):
"""CPU schedule for bitserial convolutions NCHW and NHWC"""
"""CPU schedule for bitserial convolutions NCHW and NHWC"""
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
scheduled_ops
=
[]
scheduled_ops
=
[]
...
@@ -88,7 +27,6 @@ def schedule_bitserial_conv2d(outs):
...
@@ -88,7 +27,6 @@ def schedule_bitserial_conv2d(outs):
conv_out
=
op
.
input_tensors
[
0
]
conv_out
=
op
.
input_tensors
[
0
]
kernel_vec
=
conv_out
.
op
.
input_tensors
[
1
]
kernel_vec
=
conv_out
.
op
.
input_tensors
[
1
]
kernel_q
=
kernel_vec
.
op
.
input_tensors
[
0
]
kernel_q
=
kernel_vec
.
op
.
input_tensors
[
0
]
kernel
=
kernel_q
.
op
.
input_tensors
[
0
]
data_vec
=
conv_out
.
op
.
input_tensors
[
0
]
data_vec
=
conv_out
.
op
.
input_tensors
[
0
]
data_q
=
data_vec
.
op
.
input_tensors
[
0
]
data_q
=
data_vec
.
op
.
input_tensors
[
0
]
data
=
data_q
.
op
.
input_tensors
[
0
]
data
=
data_q
.
op
.
input_tensors
[
0
]
...
@@ -97,28 +35,26 @@ def schedule_bitserial_conv2d(outs):
...
@@ -97,28 +35,26 @@ def schedule_bitserial_conv2d(outs):
data_pad
=
data_q
data_pad
=
data_q
data_q
=
data
data_q
=
data
data
=
data_q
.
op
.
input_tensors
[
0
]
data
=
data_q
.
op
.
input_tensors
[
0
]
if
"QuantizeInput"
in
kernel
.
op
.
name
:
# Need to go up 1 further, from the combine in bitpack
kernel
=
kernel
.
op
.
input_tensors
[
0
]
if
"QuantizeInput"
in
data
.
op
.
name
:
if
"QuantizeInput"
in
data
.
op
.
name
:
# Need to go up 1 further, from the combine in bitpack
# Need to go up 1 further, from the combine in bitpack
data
=
data
.
op
.
input_tensors
[
0
]
data
=
data
.
op
.
input_tensors
[
0
]
if
'spatial_bitserial_conv_nchw'
in
op
.
tag
:
if
'spatial_bitserial_conv_nchw'
in
op
.
tag
:
_schedule_
spatial_conv2d_nchw
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
_schedule_
bitserial_conv2d_nchw
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
])
conv_out
,
output
,
outs
[
0
])
elif
'spatial_bitserial_conv_nhwc'
in
op
.
tag
:
elif
'spatial_bitserial_conv_nhwc'
in
op
.
tag
:
_schedule_
spatial_conv2d_nhwc
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
_schedule_
bitserial_conv2d_nhwc
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
outs
[
0
])
conv_out
,
output
,
outs
[
0
])
scheduled_ops
.
append
(
op
)
scheduled_ops
.
append
(
op
)
traverse
(
outs
[
0
]
.
op
)
traverse
(
outs
[
0
]
.
op
)
return
s
return
s
def
_schedule_
spatial_conv2d_nchw
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
def
_schedule_
bitserial_conv2d_nchw
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
last
):
conv_out
,
output
,
last
):
IB
,
_
,
CI
,
IH
,
IW
=
data_q
.
shape
IB
,
_
,
CI
,
IH
,
IW
=
data_q
.
shape
KB
,
CO
,
_
,
KH
,
KW
=
kernel_q
.
shape
KB
,
CO
,
_
,
KH
,
KW
=
kernel_q
.
shape
...
@@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
...
@@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
wstride
=
get_const_int
((
TW
-
KW
)
//
(
OW
-
1
))
wstride
=
get_const_int
((
TW
-
KW
)
//
(
OW
-
1
))
stride
=
(
hstride
,
wstride
)
stride
=
(
hstride
,
wstride
)
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
output
.
dtype
,
"NCHW"
)
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
sch
=
_get_schedule
(
wkl
,
"NCHW"
)
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
VH
=
sch
.
vh
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
VW
=
sch
.
vw
VC
=
sch
.
vc
ba
=
sch
.
ba
bc
=
sch
.
bc
CC
=
s
.
cache_write
(
conv_out
,
"global"
)
n
,
co
,
oh
,
ow
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
s
[
conv_out
]
.
vectorize
(
vc
)
s
[
CC
]
.
compute_at
(
s
[
conv_out
],
ow
)
n
,
co
,
oh
,
ow
,
vh
,
vw
,
vc
=
s
[
CC
]
.
op
.
axis
ci
,
dh
,
dw
,
b1
,
b2
=
s
[
CC
]
.
op
.
reduce_axis
s
[
CC
]
.
reorder
(
ci
,
dh
,
vh
,
dw
,
vw
,
b1
,
b2
,
vc
)
s
[
CC
]
.
unroll
(
b1
)
s
[
CC
]
.
unroll
(
b2
)
s
[
CC
]
.
vectorize
(
vc
)
##### Schedule A
##### Schedule Data padding, and bitpacking
if
data_pad
is
not
None
:
if
data_pad
is
not
None
:
s
[
data_pad
]
.
compute_inline
()
s
[
data_pad
]
.
compute_inline
()
_
,
h
,
_
,
_
,
_
,
_
,
vw
=
s
[
data_vec
]
.
op
.
axis
_
,
_
,
h
,
_
,
_
,
_
,
_
=
s
[
data_vec
]
.
op
.
axis
s
[
data_vec
]
.
vectorize
(
vw
)
cfg
.
define_split
(
"tile_ah"
,
cfg
.
axis
(
h
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
if
ba
==
1
:
oh
,
ih
=
cfg
[
"tile_ah"
]
.
apply
(
s
,
data_vec
,
h
)
oaxis
=
h
if
cfg
[
"tile_ah"
]
.
size
[
1
]
==
1
:
paxis
=
h
oaxis
=
oh
paxis
=
oh
else
:
else
:
oh
,
ih
=
s
[
data_vec
]
.
split
(
h
,
ba
)
oaxis
=
oh
oaxis
=
oh
paxis
=
ih
paxis
=
ih
...
@@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
...
@@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule B
##### Schedule Kenerl bitpacking
co
,
_
,
_
,
_
,
_
,
vc
=
s
[
kernel_vec
]
.
op
.
axis
co
,
_
,
_
,
_
,
_
,
_
=
s
[
kernel_vec
]
.
op
.
axis
s
[
kernel_vec
]
.
vectorize
(
vc
)
cfg
.
define_split
(
"tile_bco"
,
cfg
.
axis
(
co
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
if
bc
==
1
:
oco
,
ico
=
cfg
[
"tile_bco"
]
.
apply
(
s
,
kernel_vec
,
co
)
oaxis
=
co
if
cfg
[
"tile_bco"
]
.
size
[
1
]
==
1
:
paxis
=
co
oaxis
=
oco
paxis
=
oco
else
:
else
:
oco
,
ico
=
s
[
kernel_vec
]
.
split
(
co
,
bc
)
oaxis
=
oco
oaxis
=
oco
paxis
=
ico
paxis
=
ico
...
@@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
...
@@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule C
##### Schedule Convolution
n
,
co
,
oh
,
ow
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
ci
,
dh
,
dw
,
ib
,
kb
=
s
[
conv_out
]
.
op
.
reduce_axis
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg
[
"reorder_0"
]
.
apply
(
s
,
conv_out
,
[
n
,
co
,
oh
,
ow
,
vc
,
vh
,
vw
,
dh
,
dw
,
kb
,
ib
,
ci
])
cfg
[
"ann_reduce"
]
.
apply
(
s
,
conv_out
,
[
kb
,
ib
,
dh
,
dw
],
axis_lens
=
[
get_const_int
(
kb
.
dom
.
extent
),
get_const_int
(
ib
.
dom
.
extent
),
get_const_int
(
dh
.
dom
.
extent
),
get_const_int
(
dw
.
dom
.
extent
)],
max_unroll
=
16
,
cfg
=
cfg
)
s
[
conv_out
]
.
vectorize
(
vc
)
# # Schedule output
n
,
co
,
h
,
w
=
s
[
last
]
.
op
.
axis
n
,
co
,
h
,
w
=
s
[
last
]
.
op
.
axis
co
,
vc
=
s
[
last
]
.
split
(
co
,
VC
)
co
,
vc
=
s
[
last
]
.
split
(
co
,
VC
)
oh
,
ow
,
vh
,
vw
=
s
[
last
]
.
tile
(
h
,
w
,
VH
,
VW
)
oh
,
ow
,
vh
,
vw
=
s
[
last
]
.
tile
(
h
,
w
,
VH
,
VW
)
...
@@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
...
@@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s
[
output
]
.
compute_inline
()
s
[
output
]
.
compute_inline
()
s
[
conv_out
]
.
compute_at
(
s
[
last
],
ow
)
s
[
conv_out
]
.
compute_at
(
s
[
last
],
ow
)
if
bc
==
1
:
oco
,
ico
=
cfg
[
"tile_oh"
]
.
apply
(
s
,
last
,
co
)
oaxis
=
co
if
cfg
[
"tile_oh"
]
.
size
[
1
]
==
1
:
paxis
=
co
oaxis
=
oco
paxis
=
oco
else
:
else
:
oco
,
ico
=
s
[
last
]
.
split
(
co
,
bc
)
oco
,
ico
=
s
[
last
]
.
split
(
co
,
bc
)
oaxis
=
oco
oaxis
=
oco
paxis
=
ico
paxis
=
ico
s
[
last
]
.
parallel
(
paxis
)
s
[
last
]
.
parallel
(
oco
)
s
[
last
]
.
pragma
(
oaxis
,
"parallel_launch_point"
)
s
[
last
]
.
pragma
(
paxis
,
"parallel_stride_pattern"
)
s
[
last
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
return
s
return
s
def
_schedule_
spatial_conv2d_nhwc
(
s
,
data
,
data_q
,
data_pad
,
data_vec
,
def
_schedule_
bitserial_conv2d_nhwc
(
cfg
,
s
,
data_q
,
data_pad
,
data_vec
,
kernel
,
kernel_q
,
kernel_vec
,
kernel_q
,
kernel_vec
,
conv_out
,
output
,
last
):
conv_out
,
output
,
last
):
# no stride and padding info here
# no stride and padding info here
_
,
IH
,
IW
,
CI
,
IB
=
data_q
.
shape
_
,
IH
,
IW
,
CI
,
IB
=
data_q
.
shape
KH
,
KW
,
_
,
CO
,
KB
=
kernel_q
.
shape
KH
,
KW
,
_
,
CO
,
KB
=
kernel_q
.
shape
_
,
OH
,
OW
,
_
=
output
.
shape
_
,
OH
,
OW
,
_
=
output
.
shape
# Infer padding and stride
if
data_pad
is
None
:
padding
=
(
0
,
0
)
TH
,
TW
=
IH
,
IW
else
:
_
,
TH
,
TW
,
_
,
_
=
data_pad
.
shape
hpad
=
get_const_int
((
TH
-
IH
)
//
2
)
wpad
=
get_const_int
((
TW
-
IW
)
//
2
)
padding
=
(
hpad
,
wpad
)
hstride
=
get_const_int
((
TH
-
KH
)
//
(
OH
-
1
))
VC
=
cfg
[
"tile_co"
]
.
size
[
-
1
]
wstride
=
get_const_int
((
TW
-
KW
)
//
(
OW
-
1
))
VH
=
cfg
[
"tile_oh"
]
.
size
[
-
1
]
stride
=
(
hstride
,
wstride
)
VW
=
cfg
[
"tile_ow"
]
.
size
[
-
1
]
wkl
=
_get_workload
(
data
,
kernel
,
stride
,
padding
,
last
.
dtype
,
"NHWC"
)
##### Schedule data padding and packing
sch
=
_get_schedule
(
wkl
,
"NHWC"
)
VH
=
sch
.
vh
VW
=
sch
.
vw
VC
=
sch
.
vc
ba
=
sch
.
ba
bc
=
sch
.
bc
##### Schedule data packing
if
data_pad
is
not
None
:
if
data_pad
is
not
None
:
s
[
data_pad
]
.
compute_inline
()
s
[
data_pad
]
.
compute_inline
()
_
,
h
,
_
,
_
,
_
,
_
,
_
=
s
[
data_vec
]
.
op
.
axis
_
,
h
,
_
,
_
,
_
,
_
,
_
=
s
[
data_vec
]
.
op
.
axis
if
ba
==
1
:
cfg
.
define_split
(
"tile_ah"
,
cfg
.
axis
(
h
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
oaxis
=
h
oh
,
ih
=
cfg
[
"tile_ah"
]
.
apply
(
s
,
data_vec
,
h
)
paxis
=
h
s
[
data_vec
]
.
parallel
(
oh
)
else
:
oh
,
ih
=
s
[
data_vec
]
.
split
(
h
,
ba
)
oaxis
=
oh
paxis
=
ih
s
[
data_vec
]
.
parallel
(
paxis
)
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_launch_point"
)
s
[
data_vec
]
.
pragma
(
paxis
,
"parallel_stride_pattern"
)
s
[
data_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule kernel packing
##### Schedule kernel packing
co
,
_
,
_
,
_
,
_
,
_
=
s
[
kernel_vec
]
.
op
.
axis
co
,
_
,
_
,
_
,
_
,
_
=
s
[
kernel_vec
]
.
op
.
axis
if
bc
==
1
:
cfg
.
define_split
(
"tile_bco"
,
cfg
.
axis
(
co
),
policy
=
"all"
,
num_outputs
=
2
,
max_factor
=
32
)
oaxis
=
co
oco
,
ico
=
cfg
[
"tile_bco"
]
.
apply
(
s
,
kernel_vec
,
co
)
paxis
=
co
s
[
kernel_vec
]
.
parallel
(
oco
)
else
:
oco
,
ico
=
s
[
kernel_vec
]
.
split
(
co
,
bc
)
oaxis
=
oco
paxis
=
ico
s
[
kernel_vec
]
.
parallel
(
paxis
)
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_launch_point"
)
s
[
kernel_vec
]
.
pragma
(
paxis
,
"parallel_stride_pattern"
)
s
[
kernel_vec
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
##### Schedule Convolution
##### Schedule Convolution
n
,
oh
,
ow
,
co
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
n
,
oh
,
ow
,
co
,
vh
,
vw
,
vc
=
s
[
conv_out
]
.
op
.
axis
dh
,
dw
,
ci
,
b1
,
b2
=
s
[
conv_out
]
.
op
.
reduce_axis
dh
,
dw
,
ci
,
b1
,
b2
=
s
[
conv_out
]
.
op
.
reduce_axis
s
[
conv_out
]
.
reorder
(
n
,
oh
,
ow
,
co
,
vh
,
vw
,
dh
,
dw
,
ci
,
vc
,
b1
,
b2
)
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg
[
"reorder_0"
]
.
apply
(
s
,
conv_out
,
[
n
,
oh
,
ow
,
co
,
vh
,
vw
,
dh
,
dw
,
ci
,
vc
,
b1
,
b2
])
cfg
[
"ann_reduce"
]
.
apply
(
s
,
conv_out
,
[
b1
,
b2
,
dh
,
dw
],
axis_lens
=
[
get_const_int
(
b1
.
dom
.
extent
),
get_const_int
(
b2
.
dom
.
extent
),
get_const_int
(
dh
.
dom
.
extent
),
get_const_int
(
dw
.
dom
.
extent
)],
max_unroll
=
16
,
cfg
=
cfg
)
s
[
conv_out
]
.
unroll
(
b1
)
s
[
conv_out
]
.
unroll
(
b1
)
s
[
conv_out
]
.
unroll
(
b2
)
s
[
conv_out
]
.
unroll
(
b2
)
...
@@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
...
@@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
s
[
output
]
.
compute_inline
()
s
[
output
]
.
compute_inline
()
s
[
conv_out
]
.
compute_at
(
s
[
last
],
ow
)
s
[
conv_out
]
.
compute_at
(
s
[
last
],
ow
)
if
bc
==
1
:
oho
,
iho
=
cfg
[
"tile_oh"
]
.
apply
(
s
,
last
,
oh
)
# reuse parameter
oaxis
=
oh
s
[
last
]
.
parallel
(
oho
)
paxis
=
oh
else
:
oho
,
iho
=
s
[
last
]
.
split
(
oh
,
bc
)
oaxis
=
oho
paxis
=
iho
s
[
last
]
.
parallel
(
paxis
)
s
[
last
]
.
pragma
(
oaxis
,
"parallel_launch_point"
)
s
[
last
]
.
pragma
(
paxis
,
"parallel_stride_pattern"
)
s
[
last
]
.
pragma
(
oaxis
,
"parallel_barrier_when_finish"
)
return
s
return
s
topi/tests/python/test_topi_bitserial_conv2d.py
View file @
17351875
...
@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype):
...
@@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype):
return
np
.
random
.
randint
(
min_val
,
max_val
,
size
=
shape
)
.
astype
(
out_dtype
)
return
np
.
random
.
randint
(
min_val
,
max_val
,
size
=
shape
)
.
astype
(
out_dtype
)
def
verify_bitserial_conv2d_nchw
(
batch
,
in_size
,
in_channel
,
num_filter
,
kernel
,
stride
,
padding
,
def
verify_bitserial_conv2d_nchw
(
batch
,
in_size
,
in_channel
,
num_filter
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
dorefa
):
activation_bits
,
weight_bits
,
unipolar
):
in_height
=
in_width
=
in_size
in_height
=
in_width
=
in_size
input_type
=
'uint32'
input_
d
type
=
'uint32'
out_dtype
=
'int32'
out_dtype
=
'int32'
with
tvm
.
target
.
create
(
'llvm'
):
with
tvm
.
target
.
create
(
'llvm'
):
A
=
tvm
.
placeholder
((
batch
,
in_channel
,
in_height
,
in_width
),
dtype
=
input_type
,
name
=
'A'
)
A
=
tvm
.
placeholder
((
batch
,
in_channel
,
in_height
,
in_width
),
dtype
=
input_
d
type
,
name
=
'A'
)
W
=
tvm
.
placeholder
((
num_filter
,
in_channel
,
kernel
,
kernel
),
dtype
=
input_type
,
name
=
'W'
)
W
=
tvm
.
placeholder
((
num_filter
,
in_channel
,
kernel
,
kernel
),
dtype
=
input_
d
type
,
name
=
'W'
)
B
=
topi
.
nn
.
bitserial_conv2d
(
A
,
W
,
stride
,
padding
,
activation_bits
,
weight_bits
,
B
=
topi
.
nn
.
bitserial_conv2d
_nchw
(
A
,
W
,
stride
,
padding
,
activation_bits
,
weight_bits
,
out_dtype
=
out_dtype
,
layout
=
"NCHW"
,
dorefa
=
dorefa
)
out_dtype
=
out_dtype
,
unipolar
=
unipolar
)
s
=
topi
.
generic
.
schedule_bitserial_conv2d_nchw
([
B
])
s
=
topi
.
generic
.
schedule_bitserial_conv2d_nchw
([
B
])
a_shape
=
get_const_tuple
(
A
.
shape
)
a_shape
=
get_const_tuple
(
A
.
shape
)
...
@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
...
@@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
@memoize
(
"topi.tests.test_topi_bitseral_conv2d_nchw"
)
@memoize
(
"topi.tests.test_topi_bitseral_conv2d_nchw"
)
def
get_ref_data
():
def
get_ref_data
():
a_np
=
generate_quantized_np
(
get_const_tuple
(
a_shape
),
activation_bits
,
input_type
)
a_np
=
generate_quantized_np
(
get_const_tuple
(
a_shape
),
activation_bits
,
input_
d
type
)
w_np
=
generate_quantized_np
(
get_const_tuple
(
w_shape
),
weight_bits
,
input_type
)
w_np
=
generate_quantized_np
(
get_const_tuple
(
w_shape
),
weight_bits
,
input_
d
type
)
if
dorefa
:
if
unipolar
:
w_
=
np
.
copy
(
w_np
)
.
astype
(
out_dtype
)
w_
=
np
.
copy
(
w_np
)
.
astype
(
out_dtype
)
for
x
in
np
.
nditer
(
w_
,
op_flags
=
[
'readwrite'
]):
for
x
in
np
.
nditer
(
w_
,
op_flags
=
[
'readwrite'
]):
x
[
...
]
=
1
if
x
==
1
else
-
1
x
[
...
]
=
1
if
x
==
1
else
-
1
...
@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
...
@@ -49,16 +49,16 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
tvm
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
def
verify_bitserial_conv2d_nhwc
(
batch
,
in_size
,
in_channel
,
num_filter
,
kernel
,
stride
,
padding
,
def
verify_bitserial_conv2d_nhwc
(
batch
,
in_size
,
in_channel
,
num_filter
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
dorefa
):
activation_bits
,
weight_bits
,
unipolar
):
in_height
=
in_width
=
in_size
in_height
=
in_width
=
in_size
input_type
=
'uint32'
input_
d
type
=
'uint32'
out_dtype
=
'int32'
out_dtype
=
'int32'
with
tvm
.
target
.
create
(
'llvm'
):
with
tvm
.
target
.
create
(
'llvm'
):
A
=
tvm
.
placeholder
((
batch
,
in_height
,
in_width
,
in_channel
),
dtype
=
input_type
,
name
=
'A'
)
A
=
tvm
.
placeholder
((
batch
,
in_height
,
in_width
,
in_channel
),
dtype
=
input_
d
type
,
name
=
'A'
)
W
=
tvm
.
placeholder
((
kernel
,
kernel
,
in_channel
,
num_filter
),
dtype
=
input_type
,
name
=
'W'
)
W
=
tvm
.
placeholder
((
kernel
,
kernel
,
in_channel
,
num_filter
),
dtype
=
input_
d
type
,
name
=
'W'
)
B
=
topi
.
nn
.
bitserial_conv2d
(
A
,
W
,
stride
,
padding
,
activation_bits
,
weight_bits
,
out_dtype
=
out_dtype
,
B
=
topi
.
nn
.
bitserial_conv2d
_nhwc
(
A
,
W
,
stride
,
padding
,
activation_bits
,
weight_bits
,
layout
=
"NHWC"
,
dorefa
=
dorefa
)
out_dtype
=
out_dtype
,
unipolar
=
unipolar
)
s
=
topi
.
generic
.
schedule_bitserial_conv2d_nhwc
([
B
])
s
=
topi
.
generic
.
schedule_bitserial_conv2d_nhwc
([
B
])
a_shape
=
get_const_tuple
(
A
.
shape
)
a_shape
=
get_const_tuple
(
A
.
shape
)
...
@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
...
@@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
@memoize
(
"topi.tests.test_topi_bitseral_conv2d_nhwc"
)
@memoize
(
"topi.tests.test_topi_bitseral_conv2d_nhwc"
)
def
get_ref_data
():
def
get_ref_data
():
a_np
=
generate_quantized_np
(
get_const_tuple
(
a_shape
),
activation_bits
,
input_type
)
a_np
=
generate_quantized_np
(
get_const_tuple
(
a_shape
),
activation_bits
,
input_
d
type
)
w_np
=
generate_quantized_np
(
get_const_tuple
(
w_shape
),
weight_bits
,
input_type
)
w_np
=
generate_quantized_np
(
get_const_tuple
(
w_shape
),
weight_bits
,
input_
d
type
)
if
dorefa
:
if
unipolar
:
w_
=
np
.
copy
(
w_np
)
.
astype
(
out_dtype
)
w_
=
np
.
copy
(
w_np
)
.
astype
(
out_dtype
)
for
x
in
np
.
nditer
(
w_
,
op_flags
=
[
'readwrite'
]):
for
x
in
np
.
nditer
(
w_
,
op_flags
=
[
'readwrite'
]):
x
[
...
]
=
1
if
x
==
1
else
-
1
x
[
...
]
=
1
if
x
==
1
else
-
1
...
...
topi/tests/python/test_topi_bitserial_conv2d_rasp.py
View file @
17351875
...
@@ -4,6 +4,7 @@ import numpy as np
...
@@ -4,6 +4,7 @@ import numpy as np
import
tvm
import
tvm
import
topi
import
topi
import
topi.testing
import
topi.testing
from
topi.util
import
get_const_tuple
def
generate_quantized_np
(
shape
,
bits
,
out_dtype
):
def
generate_quantized_np
(
shape
,
bits
,
out_dtype
):
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype):
...
@@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype):
# Verify that certain special instructions from the tensorize pass exist
# Verify that certain special instructions from the tensorize pass exist
def
verify_bitserial_conv2d_nhwc
(
batch
,
in_size
,
in_channel
,
num_filter
,
kernel
,
stride
,
padding
,
def
verify_bitserial_conv2d_nhwc
(
batch
,
in_size
,
in_channel
,
num_filter
,
kernel
,
stride
,
padding
,
activation_bits
,
weight_bits
,
dorefa
):
activation_bits
,
weight_bits
,
unipolar
):
in_height
=
in_width
=
in_size
in_height
=
in_width
=
in_size
input_type
=
'uint32'
input_type
=
'uint32'
out_dtype
=
'int
32
'
out_dtype
=
'int
16
'
with
tvm
.
target
.
arm_cpu
(
'rasp3b'
):
device
=
'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon'
with
tvm
.
target
.
create
(
device
):
A
=
tvm
.
placeholder
((
batch
,
in_height
,
in_width
,
in_channel
),
dtype
=
input_type
,
name
=
'A'
)
A
=
tvm
.
placeholder
((
batch
,
in_height
,
in_width
,
in_channel
),
dtype
=
input_type
,
name
=
'A'
)
W
=
tvm
.
placeholder
((
kernel
,
kernel
,
in_channel
,
num_filter
),
dtype
=
input_type
,
name
=
'W'
)
W
=
tvm
.
placeholder
((
kernel
,
kernel
,
in_channel
,
num_filter
),
dtype
=
input_type
,
name
=
'W'
)
B
=
topi
.
nn
.
bitserial_conv2d
(
A
,
W
,
stride
,
padding
,
activation_bits
,
weight_bits
,
out_dtype
=
out_dtype
,
B
=
topi
.
nn
.
bitserial_conv2d
_nhwc
(
A
,
W
,
stride
,
padding
,
activation_bits
,
weight_bits
,
layout
=
"NHWC"
,
dorefa
=
dorefa
)
pack_dtype
=
'uint8'
,
out_dtype
=
'int16'
,
unipolar
=
unipolar
)
s
=
topi
.
generic
.
schedule_bitserial_conv2d_nhwc
([
B
])
s
=
topi
.
generic
.
schedule_bitserial_conv2d_nhwc
([
B
])
func
=
tvm
.
build
(
s
,
[
A
,
W
,
B
],
tvm
.
target
.
arm_cpu
(
'rasp3b'
)
)
func
=
tvm
.
build
(
s
,
[
A
,
W
,
B
],
device
)
assembly
=
func
.
get_source
(
'asm'
)
assembly
=
func
.
get_source
(
'asm'
)
matches
=
re
.
findall
(
"vpadal"
,
assembly
)
matches
=
re
.
findall
(
"vpadal"
,
assembly
)
...
@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
...
@@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
matches
=
re
.
findall
(
"vpadd"
,
assembly
)
matches
=
re
.
findall
(
"vpadd"
,
assembly
)
assert
(
len
(
matches
)
>
0
)
assert
(
len
(
matches
)
>
0
)
ctx
=
tvm
.
context
(
device
,
0
)
if
'arm'
not
in
os
.
uname
()[
4
]:
print
(
"Skipped running code, not an arm device"
)
return
print
(
"Running on target:
%
s"
%
device
)
def
get_ref_data
():
a_np
=
generate_quantized_np
(
get_const_tuple
(
A
.
shape
),
activation_bits
,
input_type
)
w_np
=
generate_quantized_np
(
get_const_tuple
(
W
.
shape
),
weight_bits
,
input_type
)
if
unipolar
:
w_
=
np
.
copy
(
w_np
)
.
astype
(
out_dtype
)
for
x
in
np
.
nditer
(
w_
,
op_flags
=
[
'readwrite'
]):
x
[
...
]
=
1
if
x
==
1
else
-
1
b_np
=
topi
.
testing
.
conv2d_nhwc_python
(
a_np
,
w_
,
stride
,
padding
)
.
astype
(
out_dtype
)
else
:
b_np
=
topi
.
testing
.
conv2d_nhwc_python
(
a_np
,
w_np
,
stride
,
padding
)
.
astype
(
out_dtype
)
return
a_np
,
w_np
,
b_np
a_np
,
w_np
,
b_np
=
get_ref_data
()
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
w
=
tvm
.
nd
.
array
(
w_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
get_const_tuple
(
B
.
shape
),
dtype
=
B
.
dtype
),
ctx
)
func
=
tvm
.
build
(
s
,
[
A
,
W
,
B
],
device
)
func
(
a
,
w
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
b_np
,
rtol
=
1e-5
)
def
test_bitserial_conv2d
():
def
test_bitserial_conv2d
():
in_size
=
56
in_size
=
56
ic
,
oc
=
64
,
64
ic
,
oc
=
64
,
64
...
@@ -45,6 +74,9 @@ def test_bitserial_conv2d():
...
@@ -45,6 +74,9 @@ def test_bitserial_conv2d():
verify_bitserial_conv2d_nhwc
(
1
,
in_size
,
ic
,
oc
,
k
,
stride
,
pad
,
1
,
1
,
False
)
verify_bitserial_conv2d_nhwc
(
1
,
in_size
,
ic
,
oc
,
k
,
stride
,
pad
,
1
,
1
,
False
)
verify_bitserial_conv2d_nhwc
(
1
,
in_size
,
ic
,
oc
,
k
,
stride
,
pad
,
2
,
1
,
False
)
verify_bitserial_conv2d_nhwc
(
1
,
in_size
,
ic
,
oc
,
k
,
stride
,
pad
,
2
,
1
,
False
)
verify_bitserial_conv2d_nhwc
(
1
,
in_size
,
ic
,
oc
,
k
,
stride
,
pad
,
1
,
1
,
True
)
verify_bitserial_conv2d_nhwc
(
1
,
in_size
,
ic
,
oc
,
k
,
stride
,
pad
,
2
,
1
,
True
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_bitserial_conv2d
()
test_bitserial_conv2d
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment