wenyuanbo / tic · Commit 5c410c4c

[TOPI][CUDA] int8 group conv2d (#2075)

Authored Nov 11, 2018 by Wuwei Lin; committed by Tianqi Chen, Nov 10, 2018
Parent: 3ee13fc5

Showing 10 changed files with 690 additions and 16 deletions:

    nnvm/python/nnvm/top/nn.py                      +5    −0
    python/tvm/autotvm/task/nnvm_integration.py     +13   −1
    topi/python/topi/cuda/__init__.py               +2    −1
    topi/python/topi/cuda/group_conv2d_nchw.py      +308  −0
    topi/python/topi/generic/nn.py                  +19   −0
    topi/python/topi/nn/conv2d.py                   +77   −0
    topi/python/topi/testing/conv2d_nchw_python.py  +35   −2
    topi/tests/python/common.py                     +15   −0
    topi/tests/python/test_topi_conv2d_int8.py      +1    −12
    topi/tests/python/test_topi_group_conv2d.py     +215  −0
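For orientation, here is a minimal sketch of how the operator added by this commit is driven end to end. All names come from this diff; the ResNeXt-style shapes and the CUDA target are illustrative assumptions, not part of the commit:

```python
import tvm
import topi

# Illustrative ResNeXt-style workload: 128 -> 128 channels, 32 groups.
A = tvm.placeholder((1, 128, 56, 56), name='A')
W = tvm.placeholder((128, 128 // 32, 3, 3), name='W')

with tvm.target.create('cuda'):
    # Dispatches to the CUDA template registered by this commit.
    C = topi.nn.group_conv2d_nchw(A, W, 1, 1, 1, 32, out_dtype='float32')
    s = topi.generic.schedule_group_conv2d_nchw([C])

func = tvm.build(s, [A, W, C], 'cuda')
```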
nnvm/python/nnvm/top/nn.py

```diff
@@ -108,6 +108,9 @@ def compute_conv2d(attrs, inputs, _):
             groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
+    elif layout == "NCHW":
+        out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation,
+                                        groups, out_dtype=out_dtype)
     elif layout == "NHWC" and \
          kernel_layout == "HWOI" and \
          groups == get_const_int(inputs[0].shape[3]) and \
@@ -143,6 +146,8 @@ def schedule_conv2d(attrs, outs, target):
         return topi.generic.schedule_depthwise_conv2d_nchw(outs)
     elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
         return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+    elif layout == "NCHW":
+        return topi.generic.schedule_group_conv2d_nchw(outs)
     else:
         raise ValueError("No compatible schedule")
```
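To make the new dispatch concrete: a grouped NCHW convolution that is neither dense (groups == 1) nor depthwise (groups == channels) now falls through to the new branch. A small sanity check of the branch logic, with hypothetical ResNeXt-style numbers matching the tests added below:

```python
# Hypothetical workload: 128 input channels, 32 groups.
channels, groups, layout = 128, 32, "NCHW"
assert groups != 1            # not a dense conv2d
assert groups != channels     # not depthwise_conv2d_nchw
assert layout == "NCHW"       # -> topi.nn.group_conv2d_nchw is selected
```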
python/tvm/autotvm/task/nnvm_integration.py

```diff
@@ -58,7 +58,8 @@ class TaskExtractEnv:
         # NOTE: To add more symbols, you only need to change the following lists
         # nnvm symbol -> topi compute
         self.symbol2topi = {
-            nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
+            nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
+                              topi.nn.group_conv2d_nchw],
             nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
             nnvm.sym.dense: [topi.nn.dense],
         }
@@ -67,6 +68,7 @@ class TaskExtractEnv:
         self.topi_to_task = {
             topi.nn.conv2d: "topi_nn_conv2d",
             topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
+            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
             topi.nn.dense: "topi_nn_dense",
         }
@@ -76,6 +78,7 @@ class TaskExtractEnv:
                               topi.generic.schedule_conv2d_nhwc],
             topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
                                             topi.generic.schedule_depthwise_conv2d_nhwc],
+            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
             topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
             topi.nn.dense: [topi.generic.schedule_dense],
         }
@@ -143,6 +146,15 @@ class TaskExtractEnv:
             s = topi.generic.schedule_depthwise_conv2d_nchw([C])
             return s, [A, W, C]
 
+        @register("topi_nn_group_conv2d_nchw")
+        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
+            s = topi.generic.schedule_group_conv2d_nchw([C])
+            return s, [A, W, C]
+
         @register("topi_nn_conv2d_transpose_nchw")
         def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
```
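With group_conv2d_nchw added to symbol2topi and topi_to_task, AutoTVM's NNVM task extraction can now emit tunable tasks for grouped convolutions. A hedged sketch of driving it; `graph`, `shape_dict` and `dtype_dict` are placeholders for a real model, and extract_from_graph is the NNVM-era extraction entry point:

```python
from tvm import autotvm
import nnvm

# Placeholders: `graph`, `shape_dict`, `dtype_dict` come from a real model.
tasks = autotvm.task.extract_from_graph(
    graph, shape=shape_dict, dtype=dtype_dict,
    target='cuda', symbols=(nnvm.sym.conv2d,))
# Grouped conv2d workloads now surface as "topi_nn_group_conv2d_nchw" tasks.
```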
topi/python/topi/cuda/__init__.py

```diff
@@ -2,10 +2,11 @@
 """CUDA specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw
+from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, group_conv2d_nchw
 from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
+from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
```
topi/python/topi/cuda/group_conv2d_nchw.py (new file, mode 100644)

```python
# pylint: disable=invalid-name
"""The template for cuda group_conv2d_nchw"""
import tvm
from tvm import autotvm

from .injective import _schedule_injective
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
from ..util import traverse_inline, get_const_tuple, get_const_int
from .. import nn, generic


@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'],
                               ['direct', 'int8'])
def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                           out_dtype='float32'):
    """Group convolution operator in NCHW layout.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width] or
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width] or
        6-D with shape [num_filter_chunk, in_channel_chunk // groups, filter_height,
        filter_width, num_filter_block, in_channel_block]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    groups : int
        number of groups

    out_dtype : str
        The output type. This is used for mixed precision.

    Returns
    -------
    Output : tvm.Tensor
        5-D with shape [batch, out_channel, out_height, out_width, out_channel_block]
    """
    ic_block_factor = 4
    oc_block_factor = 4

    pre_computed = len(kernel.shape) == 6
    if not pre_computed:
        batch, channels, height, width = get_const_tuple(data.shape)
        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape)

        assert channels % groups == 0, "input channels must divide group size"
        assert out_channels % groups == 0, "output channels must divide group size"
        assert channels % ic_block_factor == 0, \
            "Number of input channels per group must divide {}".format(ic_block_factor)
        assert out_channels % 4 == 0, \
            "Number of output channels per group must divide {}".format(oc_block_factor)

        packed_data = tvm.compute((batch, channels // ic_block_factor, height, width,
                                   ic_block_factor),
                                  lambda n, c, h, w, vc:
                                  data[n, c * ic_block_factor + vc, h, w],
                                  name="packed_data")
        packed_kernel = tvm.compute(
            (out_channels // oc_block_factor, in_channels // ic_block_factor,
             kernel_h, kernel_w, oc_block_factor, ic_block_factor),
            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block:
            kernel[oc_chunk * oc_block_factor + oc_block,
                   ic_chunk * ic_block_factor + ic_block, kh, kw],
            name="packed_kernel")
    else:
        packed_data = data
        packed_kernel = kernel

    batch, ic_chunk, in_height, in_width, _ = get_const_tuple(packed_data.shape)
    oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape)

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
    # compute graph
    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

    # compute the output shape
    out_height = (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) \
        // stride_h + 1
    out_width = (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) \
        // stride_w + 1

    oshape = (batch, oc_chunk, out_height, out_width, oc_block)

    icc = tvm.reduce_axis((0, ic_chunk // groups), name='ic_chunk')
    icb = tvm.reduce_axis((0, ic_block_factor), name='ic_block')
    kh = tvm.reduce_axis((0, kernel_h), name='kh')
    kw = tvm.reduce_axis((0, kernel_w), name='kw')

    conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb: tvm.sum(
        pad_data[n, occ // (oc_chunk // groups) * (ic_chunk // groups) + icc,
                 oh * stride_h + kh * dilation_h,
                 ow * stride_w + kw * dilation_w, icb].astype('int32') *
        packed_kernel[occ, icc, kh, kw, ocb, icb].astype('int32'),
        axis=[icc, kh, kw, icb]))

    output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
                         tag='group_conv2d_NCHWc_int8')

    num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
        ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
    cfg.add_flop(num_flop)

    return output


_dp4a = dp4a('shared', 'shared', 'local')


def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
    """Schedule group conv2d int8 NCHWc template"""
    workload = output.op.attrs["workload"]
    groups = get_const_int(workload[6])

    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    if isinstance(packed_data.op, tvm.tensor.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # skip this part during tuning to make records accurate
        # this part will be pre-computed during NNVM's pre-compute optimization pass
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    else:
        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and \
                packed_kernel.name == 'packed_kernel':
            # data and kernel are not pre-computed, schedule layout transform here
            _schedule_injective(packed_data.op, s)
            _schedule_injective(packed_kernel.op, s)

    if pad_data != packed_data:
        s[pad_data].compute_inline()

    # create cache stage
    AA = s.cache_read(pad_data, 'shared', [conv])
    WW = s.cache_read(packed_kernel, 'shared', [conv])

    s[conv].set_scope('local')

    # handle bias
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    oc_chunk = get_const_int(output.shape[1])
    # tile and bind spatial axes
    n, f, y, x, c = s[output].op.axis
    cfg.define_split("tile_n", n, num_outputs=4)
    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
    cfg.define_split("tile_f", cfg.axis(oc_chunk // groups), num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)

    # this is the scope to attach global config inside this kernel
    kernel_scope, n = s[output].split(n, nparts=1)

    g, f = s[output].split(f, nparts=groups)
    s[output].bind(n, tvm.thread_axis('blockIdx.z'))
    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bg, vg = cfg["tile_g"].apply(s, output, g)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx,
                      tn, tf, ty, tx, ni, fi, yi, xi)
    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
    s[output].bind(vn, tvm.thread_axis("vthread"))
    s[output].bind(vg, tvm.thread_axis("vthread"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])
    # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile and bind reduction axes
    n, f, y, x, c = s[conv].op.axis

    rc, ry, rx, rc_block = s[conv].op.reduce_axis
    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg['tile_rc'].apply(s, conv, rc)
    ryo, ryi = cfg['tile_ry'].apply(s, conv, ry)
    rxo, rxi = cfg['tile_rx'].apply(s, conv, rx)

    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)
    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    s[AA].compute_at(s[conv], rxo)
    s[WW].compute_at(s[conv], rxo)

    # cooperative fetching
    for load in [AA, WW]:
        c = s[load].op.axis[-1]
        c_outer, c = s[load].split(c, factor=4)
        s[load].vectorize(c)
        fused = s[load].op.axis[:-1] + [c_outer]
        fused = s[load].fuse(*fused)

        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # double buffer
    cfg.define_knob('AA_double_buffer', [0, 1])
    cfg.define_knob('WW_double_buffer', [0, 1])
    if cfg['AA_double_buffer'].val:
        s[AA].double_buffer()
    if cfg['WW_double_buffer'].val:
        s[WW].double_buffer()

    # unroll
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
                     cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', False)

    return s


@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw, ["cuda", "gpu"],
                                ["direct", "int8"])
def schedule_conv2d_nchw_cuda(cfg, outs):
    """TOPI schedule callback of group conv2d for cuda gpu

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    outs: Array of Tensor
        The computation graph description of conv2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for group conv2d.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "group_conv2d_NCHWc_int8":
            schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))

    traverse_inline(s, outs[0].op, _callback)
    return s
```
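Two things in this template are worth unpacking. First, when the inputs are not pre-packed, packed_data reinterprets NCHW as NCHW4c (ic_block_factor = 4). Second, the innermost 4-element reduction is tensorized with dp4a, a 4-wide int8 dot product accumulated in int32. A NumPy sketch of both, with illustrative shapes only:

```python
import numpy as np

# NCHW -> NCHW4c packing, mirroring packed_data above (icb = ic_block_factor).
n, c, h, w, icb = 1, 8, 4, 4, 4
data = np.arange(n * c * h * w, dtype='int32').reshape(n, c, h, w)
packed = data.reshape(n, c // icb, icb, h, w).transpose(0, 1, 3, 4, 2)
# packed[n, c_chunk, h, w, vc] == data[n, c_chunk * icb + vc, h, w]
assert packed[0, 1, 2, 3, 1] == data[0, 1 * icb + 1, 2, 3]

# dp4a semantics: one call computes a 4-element int8 dot product in int32.
a = np.array([1, -2, 3, -4], dtype='int8')
b = np.array([5, 6, -7, 8], dtype='int8')
acc = int(a.astype('int32') @ b.astype('int32'))  # -60, never overflows int8
```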
topi/python/topi/generic/nn.py

```diff
@@ -173,6 +173,25 @@ def schedule_depthwise_conv2d_nhwc(outs):
     """
     return _default_schedule(outs, False)
 
+@tvm.target.generic_func
+def schedule_group_conv2d_nchw(outs):
+    """Schedule for group_conv2d_nchw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv2d_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
```
topi/python/topi/nn/conv2d.py

```diff
@@ -403,3 +403,80 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, di
         4-D with shape [batch, out_height, out_width, out_channel]
     """
     raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
+
+
+@tvm.target.generic_func
+def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
+    """Group convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    Filter : tvm.Tensor
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    dilation : int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    groups : int
+        number of groups
+
+    out_dtype : str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
+    num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape)
+
+    assert in_channel % groups == 0, "input channels must divide group size"
+    assert num_filter % groups == 0, "output channels must divide group size"
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
+    # compute the output shape
+    out_channel = num_filter
+    out_height = simplify(
+        (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify(
+        (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1)
+    # compute graph
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+    temp = pad(Input, pad_before, pad_after, name="pad_temp")
+    rc = tvm.reduce_axis((0, in_channel // groups), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    return tvm.compute(
+        (batch, out_channel, out_height, out_width),
+        lambda nn, ff, yy, xx: tvm.sum(
+            temp[nn, ff // (num_filter // groups) * (in_channel // groups) + rc,
+                 yy * stride_h + ry * dilation_h,
+                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
+            Filter[ff, rc, ry, rx].astype(out_dtype),
+            axis=[rc, ry, rx]),
+        tag="conv2d_nchw")
```
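The heart of this compute is the input-channel base index ff // (num_filter // groups) * (in_channel // groups): each output channel reduces only over the in_channel // groups inputs of its own group. A NumPy sketch of the same mapping, restricted to 1x1 kernels and stride 1 so the indexing stands out; shapes are illustrative:

```python
import numpy as np

# Illustrative shapes: 8 -> 8 channels in 2 groups, 1x1 kernel, stride 1.
batch, in_channel, num_filter, groups, size = 1, 8, 8, 2, 5
a = np.random.randn(batch, in_channel, size, size).astype('float32')
w = np.random.randn(num_filter, in_channel // groups, 1, 1).astype('float32')

out = np.zeros((batch, num_filter, size, size), dtype='float32')
for ff in range(num_filter):
    # base input channel of ff's group, exactly as in the compute above
    base = ff // (num_filter // groups) * (in_channel // groups)
    for rc in range(in_channel // groups):
        out[:, ff] += a[:, base + rc] * w[ff, rc, 0, 0]
```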
topi/python/topi/testing/conv2d_nchw_python.py

```diff
@@ -4,8 +4,8 @@ import numpy as np
 import scipy.signal
 
 
-def conv2d_nchw_python(a_np, w_np, stride, padding):
-    """Convolution operator in HWCN layout.
+def _conv2d_nchw_python(a_np, w_np, stride, padding):
+    """Convolution operator in NCHW layout.
 
     Parameters
     ----------
@@ -66,3 +66,36 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
                     apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
             b_np[n, f] += out[::stride_h, ::stride_w]
     return b_np
+
+
+def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
+    """Convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    w_np : numpy.ndarray
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str or a list/tuple of two ints
+        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    a_slices = np.array_split(a_np, groups, axis=1)
+    w_slices = np.array_split(w_np, groups, axis=0)
+    b_slices = [_conv2d_nchw_python(a_slice, w_slice, stride, padding)
+                for a_slice, w_slice in zip(a_slices, w_slices)]
+    b_np = np.concatenate(b_slices, axis=1)
+    return b_np
```
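Since the wrapper just splits both operands with np.array_split and concatenates the per-group results, an easy consistency check is that a 2-group call matches two half-convolutions stitched together. A small sketch with illustrative shapes:

```python
import numpy as np
from topi.testing import conv2d_nchw_python

a = np.random.uniform(size=(1, 4, 8, 8)).astype('float32')
w = np.random.uniform(size=(4, 2, 3, 3)).astype('float32')
full = conv2d_nchw_python(a, w, 1, 1, groups=2)
# Group 0 uses input channels 0:2 and filters 0:2; group 1 uses the rest.
halves = np.concatenate([conv2d_nchw_python(a[:, :2], w[:2], 1, 1),
                         conv2d_nchw_python(a[:, 2:], w[2:], 1, 1)], axis=1)
np.testing.assert_allclose(full, halves)
```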
topi/tests/python/common.py

```diff
 """Common utility for topi test"""
+
+from tvm import autotvm
+from tvm.autotvm.task.space import FallbackConfigEntity
+
 
 def get_all_backend():
     """return all supported target
@@ -10,3 +14,14 @@ def get_all_backend():
     """
     return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
             'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu']
+
+
+class NCHWcInt8Fallback(autotvm.FallbackContext):
+    def _query_inside(self, target, workload):
+        key = (target, workload)
+        if key in self.memory:
+            return self.memory[key]
+        cfg = FallbackConfigEntity()
+        cfg.template_key = 'int8'
+        self.memory[key] = cfg
+        return cfg
```
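NCHWcInt8Fallback is a small AutoTVM FallbackContext: whenever a config lookup misses, _query_inside returns a FallbackConfigEntity pinned to the 'int8' template instead of the default one. Usage is just a with-block around code that queries configs, as the int8 tests below do; a minimal sketch:

```python
from common import NCHWcInt8Fallback

with NCHWcInt8Fallback():
    # Any AutoTVM config lookup that has no tuned record now falls back
    # to the 'int8' template rather than 'direct'.
    ...
```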
topi/tests/python/test_topi_conv2d_int8.py

```diff
@@ -9,7 +9,7 @@ import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
-from common import get_all_backend
+from common import get_all_backend, NCHWcInt8Fallback
 
 oc_block_factor = 4
@@ -88,17 +88,6 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
         check_device(device)
 
-class NCHWcInt8Fallback(autotvm.FallbackContext):
-    def _query_inside(self, target, workload):
-        key = (target, workload)
-        if key in self.memory:
-            return self.memory[key]
-        cfg = FallbackConfigEntity()
-        cfg.template_key = 'int8'
-        self.memory[key] = cfg
-        return cfg
-
 
 def test_conv2d_nchw():
     with NCHWcInt8Fallback():
         # ResNet18 workloads where channels in / out are multiple of oc_block_factor
```
topi/tests/python/test_topi_group_conv2d.py (new file, mode 100644)

```python
"""Example code to do group convolution."""

import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple

from common import get_all_backend, NCHWcInt8Fallback


def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride,
                             padding, dilation, groups, add_bias=False, add_relu=False):
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
          (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation,
           groups))

    in_height = in_width = in_size

    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W')
    bias = tvm.placeholder((num_filter, 1, 1), name='bias')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw")
    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding,
                                               groups).astype(dtype)

        if add_bias:
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)

        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
            print("Skip because int8 intrinsics are not available")
            return

        print("Running on target: %s" % device)
        with tvm.target.create(device):
            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups,
                                          out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_group_conv2d_nchw([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel, stride,
                              padding, dilation, groups))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel, stride,
                              padding, dilation, groups))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    for device in ["llvm"]:
        check_device(device)


oc_block_factor = 4


def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel,
                                   stride, padding, dilation, groups,
                                   add_bias=False, add_relu=False):
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
          (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation,
           groups))

    in_height = in_width = in_size

    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A',
                        dtype='int8')
    W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W',
                        dtype='int8')
    bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor),
                           name='bias', dtype='int8')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8")
    def get_ref_data():
        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding,
                                               groups).astype(dtype)

        # convert to NCHWc
        _, _, out_height, out_width = c_np.shape
        c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
                out_height, out_width)).transpose(0, 1, 3, 4, 2)

        if add_bias:
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)

        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
            print("Skip because int8 intrinsics are not available")
            return

        print("Running on target: %s" % device)
        with tvm.target.create(device):
            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups,
                                          out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_group_conv2d_nchw([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel, stride,
                              padding, dilation, groups))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device,
                             name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %
                             (batch, in_channel, in_size, num_filter, kernel, stride,
                              padding, dilation, groups))
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    for device in ["cuda"]:
        check_device(device)


def test_group_conv2d_nchw():
    # ResNeXt-50 workload
    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32)
    verify_group_conv2d_nchw(1, 256, 56, 256, 3, 2, 1, 1, 32)
    verify_group_conv2d_nchw(1, 256, 28, 256, 3, 1, 1, 1, 32)
    verify_group_conv2d_nchw(1, 512, 28, 512, 3, 2, 1, 1, 32)
    verify_group_conv2d_nchw(1, 512, 14, 512, 3, 1, 1, 1, 32)
    verify_group_conv2d_nchw(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
    verify_group_conv2d_nchw(1, 1024, 7, 1024, 3, 1, 1, 1, 32)

    # bias, relu
    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
                             add_bias=True)

    # dilation
    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 2, 32)

    # batch size
    verify_group_conv2d_nchw(2, 128, 56, 128, 3, 1, 1, 1, 32)
    verify_group_conv2d_nchw(9, 128, 56, 128, 3, 1, 1, 1, 32)


def test_group_conv2d_NCHWc_int8():
    with NCHWcInt8Fallback():
        # ResNeXt-50 workload
        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(1, 256, 56, 256, 3, 2, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(1, 256, 28, 256, 3, 1, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(1, 512, 28, 512, 3, 2, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(1, 512, 14, 512, 3, 1, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32)

        # bias, relu
        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
                                       add_bias=True)

        # dilation
        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)

        # batch size
        verify_group_conv2d_NCHWc_int8(2, 128, 56, 128, 3, 1, 1, 1, 32)
        verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32)


if __name__ == "__main__":
    test_group_conv2d_nchw()
    test_group_conv2d_NCHWc_int8()
```

(Note: the original file's "# dilation" case in test_group_conv2d_nchw called verify_group_conv2d_NCHWc_int8, an apparent copy-paste slip; it is corrected above to verify_group_conv2d_nchw to match the surrounding float tests.)