Commit 2005f852
Authored Nov 01, 2018 by Wuwei Lin; committed by Tianqi Chen, Oct 31, 2018
[TOPI] Add dilation argument to conv2d and depthwise_conv2d (#1970)
Parent: 7a3e389b
Showing 22 changed files with 312 additions and 331 deletions (+312, -331).
  nnvm/python/nnvm/top/nn.py                                    +8   -16
  python/tvm/autotvm/tophub.py                                  +7   -7
  tests/python/unittest/test_lang_tensor_overload_op.py         +2   -1
  topi/python/topi/arm_cpu/conv2d.py                            +35  -9
  topi/python/topi/cuda/conv2d.py                               +14  -29
  topi/python/topi/cuda/conv2d_int8.py                          +27  -92
  topi/python/topi/cuda/conv2d_winograd.py                      +52  -37
  topi/python/topi/generic/nn.py                                +0   -18
  topi/python/topi/mali/conv2d.py                               +21  -6
  topi/python/topi/nn/conv2d.py                                 +61  -49
  topi/python/topi/nn/depthwise_conv2d.py                       +32  -6
  topi/python/topi/rocm/conv2d.py                               +6   -23
  topi/python/topi/x86/conv2d.py                                +25  -11
  topi/tests/python/test_topi_conv2d_hwcn.py                    +1   -2
  topi/tests/python/test_topi_conv2d_int8.py                    +2   -3
  topi/tests/python/test_topi_conv2d_nchw.py                    +3   -4
  topi/tests/python/test_topi_conv2d_nhwc.py                    +3   -4
  topi/tests/python/test_topi_conv2d_winograd.py                +2   -3
  topi/tests/python/test_topi_depthwise_conv2d.py               +8   -9
  tutorials/autotvm/tune_conv2d_cuda.py                         +1   -1
  tutorials/topi/intro_topi.py                                  +1   -1
  vta/tests/python/integration/test_benchmark_topi_conv2d.py    +1   -0
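A minimal sketch of the call-pattern change this commit makes, assuming illustrative NCHW shapes (the shapes below are not from the commit): the tests switch from pre-dilating the kernel with topi.nn.dilate to passing dilation directly to conv2d.

    import tvm
    import topi

    # Illustrative shapes only: NCHW input and an OIHW 3x3 kernel.
    A = tvm.placeholder((1, 64, 56, 56), name='A')
    W = tvm.placeholder((64, 64, 3, 3), name='W')

    # Before this commit, callers dilated the kernel themselves:
    #   dW = topi.nn.dilate(W, (1, 1, 2, 2))
    #   B = topi.nn.conv2d(A, dW, (1, 1), (1, 1), layout='NCHW')

    # After this commit, dilation is an explicit argument of conv2d:
    B = topi.nn.conv2d(A, W, (1, 1), (1, 1), (2, 2), layout='NCHW')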
nnvm/python/nnvm/top/nn.py
@@ -94,34 +94,26 @@ def compute_conv2d(attrs, inputs, _):
-    (dilation_h, dilation_w) = dilation
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
-    elif layout == "NCHW4c" and (dilation_h > 1 or dilation_w > 1):
-        raise ValueError("not support dilate now")
-    elif dilation == (1, 1):
-        kernel = inputs[1]
-    elif layout == "NCHW":
-        kernel = topi.nn.dilate(inputs[1], [1, 1, dilation_h, dilation_w])
-    else: #layout == NHWC
-        kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
     if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
         # pylint: disable=assignment-from-no-return
-        out = topi.nn.conv2d_NCHWc_int8_prepacked(inputs[0], kernel, strides, padding,
-                                                  layout, out_dtype=out_dtype)
+        out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, dilation,
+                             layout, out_dtype=out_dtype)
         # pylint: enable=assignment-from-no-return
     elif groups == 1:
         out = topi.nn.conv2d(
-            inputs[0], kernel, strides, padding, layout, out_dtype=out_dtype)
+            inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
     elif layout == "NCHW" and \
          groups == get_const_int(inputs[0].shape[1]) and \
          groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(
-            inputs[0], kernel, strides, padding, out_dtype=out_dtype)
+            inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
     elif layout == "NHWC" and \
          kernel_layout == "HWOI" and \
          groups == get_const_int(inputs[0].shape[3]) and \
          groups == channels:
         out = topi.nn.depthwise_conv2d_nhwc(
-            inputs[0], kernel, strides, padding, out_dtype=out_dtype)
+            inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
     else:
         raise ValueError("not support arbitrary group number for now")

@@ -144,7 +136,7 @@ def schedule_conv2d(attrs, outs, target):
     if groups == 1 and layout == "NCHW":
         return topi.generic.schedule_conv2d_nchw(outs)
     elif groups == 1 and layout == "NCHW4c":
-        return topi.generic.schedule_conv2d_NCHWc_int8_prepacked(outs)
+        return topi.generic.schedule_conv2d_nchw(outs)
     elif groups == 1 and layout == "NHWC":
         return topi.generic.schedule_conv2d_nhwc(outs)
     elif groups == channels and layout == "NCHW":

@@ -175,7 +167,7 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
     assert dilation == (1, 1), "not support dilate now"
     if groups == 1:
         # pylint: disable=assignment-from-no-return
-        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
+        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
                                    layout, out_layout, out_dtype)
         # pylint: enable=assignment-from-no-return
     else:

@@ -227,7 +219,7 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, _):
     # pylint: disable=assignment-from-no-return
     out = topi.nn.conv2d_winograd_without_weight_transform(
-        inputs[0], inputs[1], strides, padding, layout, out_dtype,
+        inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype,
         tile_size)
     if attrs.get_bool("use_bias"):
python/tvm/autotvm/tophub.py
@@ -20,15 +20,15 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub"
 # the version of each package
 PACKAGE_VERSION = {
-    'arm_cpu': "v0.03",
-    'llvm':    "v0.01",
+    'arm_cpu': "v0.04",
+    'llvm':    "v0.02",

-    'cuda':    "v0.03",
-    'rocm':    "v0.01",
-    'opencl':  "v0.01",
-    'mali':    "v0.03",
+    'cuda':    "v0.04",
+    'rocm':    "v0.02",
+    'opencl':  "v0.02",
+    'mali':    "v0.04",

-    'vta':     "v0.01",
+    'vta':     "v0.04",
 }

 logger = logging.getLogger('autotvm')
tests/python/unittest/test_lang_tensor_overload_op.py
@@ -175,10 +175,11 @@ def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, str
         print("Running on target: %s" % device)

         k = 10.0
+        dilation = (1, 1)
         with tvm.target.create(device):
             A = tvm.placeholder((batch, in_channel, in_size, in_size), name='A')
             W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-            B = topi.nn.conv2d(A, W, stride, padding)
+            B = topi.nn.conv2d(A, W, stride, padding, dilation)
             if typ == "add":
                 C = B + k
             elif typ == "sub":
topi/python/topi/arm_cpu/conv2d.py
@@ -9,11 +9,11 @@ from tvm import autotvm
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
 from ..util import traverse_inline, get_const_tuple, const_matrix
-from ..nn import pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
+from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
 from ..nn.util import get_const_int, get_pad_tuple

 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
-def conv2d_arm_cpu(cfg, data, kernel, strides, padding, layout, out_dtype):
+def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """TOPI compute callback for conv2d

     Parameters

@@ -35,6 +35,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, layout, out_dtype):
     padding : list of two ints
         [pad_height, pad_width]

+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
     layout : str
         layout of data

@@ -46,7 +49,8 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, layout, out_dtype):
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=2)
+    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+                              num_tile=2)

 @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'arm_cpu', ['direct', 'winograd'])
 def schedule_conv2d_nchw_arm_cpu(cfg, outs):

@@ -96,11 +100,22 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
     return s

-def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
+def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile):
     assert layout == "NCHW", "Only support NCHW"
     # create workload according to raw arguments
     out_dtype = out_dtype or data.dtype
     N, CI, IH, IW = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        dilation_args = (1, 1, dilation_h, dilation_w) if len(kernel.shape) == 4 \
+            else (1, 1, dilation_h, dilation_w, 1)
+        kernel = dilate(kernel, dilation_args)
+
     if len(kernel.shape) == 4:
         pre_packed = False
         CO, _, KH, KW = get_const_tuple(kernel.shape)

@@ -242,17 +257,27 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
-def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
+def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ TOPI compute callback. Use winograd template """
     tile_size = 4
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+                          tile_size)

-def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
     N, CI, IH, IW = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
     if len(kernel.shape) == 4:
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+
         pre_computed = False
         CO, _, KH, KW = get_const_tuple(kernel.shape)
     else:
+        assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
         pre_computed = True
         H_CAT, W_CAT, CO, CI, VC = get_const_tuple(kernel.shape)
         CO *= VC

@@ -459,9 +484,10 @@ def _schedule_winograd(cfg, s, output, last):
 ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
 @autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'arm_cpu', ['winograd'])
-def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
     """TOPI compute callback"""
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, \
+                          tile_size)

 @autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
topi/python/topi/cuda/conv2d.py
@@ -5,7 +5,7 @@ from tvm import autotvm
 from tvm.contrib import cudnn

 from .. import nn, generic
-from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..util import get_const_tuple, traverse_inline
 from .conv2d_direct import schedule_direct_cuda
 from .conv2d_winograd import winograd_cuda, schedule_winograd_cuda

@@ -13,7 +13,7 @@ from .conv2d_int8 import conv2d_NCHWc_int8, schedule_conv2d_NCHWc_int8
 @autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd', 'int8'])
-def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
+def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'):
     """Conv2D operator for cuda backend.

     Parameters

@@ -36,6 +36,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
     padding : int or a list/tuple of two ints
         padding size, or [pad_height, pad_width]

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     layout : str
         layout of data

@@ -63,32 +66,15 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
         # handle dilation
         stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
         pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
+        dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

         OH = (H + 2 * pad_h - KH) // stride_h + 1
         OW = (W + 2 * pad_w - KW) // stride_w + 1
-        cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
-
-        dilation_h = dilation_w = 1
-        kernel_before_dilation = kernel
-        if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
-            kernel_before_dilation = kernel.op.input_tensors[0]
-            if layout == 'NCHW':
-                dilation_h = (get_const_int(kernel.shape[2]) +
-                              get_const_int(kernel_before_dilation.shape[2]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[2])
-                dilation_w = (get_const_int(kernel.shape[3]) +
-                              get_const_int(kernel_before_dilation.shape[3]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[2])
-            elif layout == 'NHWC':
-                dilation_h = (get_const_int(kernel.shape[1]) +
-                              get_const_int(kernel_before_dilation.shape[1]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[1])
-                dilation_w = (get_const_int(kernel.shape[2]) +
-                              get_const_int(kernel_before_dilation.shape[2]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[2])
+        cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \
+                     ((KW - 1) * dilation_w + 1))

         return cudnn.conv2d_forward(data,
-                                    kernel_before_dilation,
+                                    kernel,
                                     stride_h,
                                     stride_w,
                                     pad_h,

@@ -100,16 +86,15 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
                                     algo=-1)  # let CUDNN choose the best algo

     if cfg.template_key == 'winograd':
-        return winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype,
+        return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                              pre_computed=False)
     if cfg.template_key == 'int8':
-        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, layout, out_dtype,
-                                 pre_computed=False)
+        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)

     if layout == 'NCHW':
-        return nn.conv2d_nchw(data, kernel, strides, padding, out_dtype)
+        return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
     elif layout == 'HWCN':
-        return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype)
+        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))

@@ -146,7 +131,7 @@ def schedule_conv2d_nchw_cuda(cfg, outs):
         if op.tag == 'conv2d_nchw_winograd':
             schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
         if op.tag == "conv2d_NCHWc_int8":
-            schedule_conv2d_NCHWc_int8(cfg, s, op.output(0), pre_computed=False)
+            schedule_conv2d_NCHWc_int8(cfg, s, op.output(0))

     traverse_inline(s, outs[0].op, _callback)
     return s
topi/python/topi/cuda/conv2d_int8.py
@@ -4,37 +4,13 @@ import tvm
 from tvm import autotvm

 from .injective import _schedule_injective
-from ..generic import schedule_conv2d_NCHWc_int8_prepacked
 from .tensor_intrin import dp4a
-from ..nn.conv2d import conv2d_NCHWc_int8_prepacked
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple, traverse_inline
-
-
-def _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype):
-    """convert argument to workload"""
-    shape = get_const_tuple(data.shape)
-    if len(shape) == 5:
-        N, ic_chunk, H, W, ic_block = shape
-        raw_data = tvm.placeholder((N, ic_chunk * ic_block, H, W), dtype=data.dtype)
-    else:
-        raw_data = data
-
-    shape = get_const_tuple(kernel.shape)
-    if len(shape) == 6:
-        oc_chunk, ic_chunk, KH, KW, oc_block, ic_block = shape
-        raw_kernel = tvm.placeholder((oc_chunk * oc_block, ic_chunk * ic_block, KH, KW),
-                                     dtype=kernel.dtype)
-    else:
-        raw_kernel = kernel
-
-    return ('conv2d', ) + autotvm.task.task.args_to_workload(
-        [raw_data, raw_kernel, stride, padding, "NCHW", out_dtype])
-
-
-def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre_computed):
+from ..util import get_const_tuple
+
+
+def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
     """Convolution operator in NCHW[x]c layout for int8.

     Parameters

@@ -57,25 +33,25 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre
     padding: int or a list/tuple of two ints
         padding size, or [pad_height, pad_width]

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     layout : str
         layout of data

     out_dtype : str
         The output type. This is used for mixed precision.

-    pre_computed : str
-        Whether packed data and kernel are pre-computed
-
     Returns
     -------
     output : tvm.Tensor
         5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
     """
     assert layout in ["NCHW", "NCHW4c"]
     ic_block_factor = 4
     oc_block_factor = 4

+    pre_computed = len(kernel.shape) == 6
     if not pre_computed:
         batch, channels, height, width = get_const_tuple(data.shape)
         assert channels % ic_block_factor == 0, \

@@ -109,10 +85,15 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre
         packed_kernel.shape)

     if isinstance(stride, int):
-        stride_h, stride_w = stride
+        stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride

+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     # compute graph

@@ -121,8 +102,8 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre
     pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

     # compute the output shape
-    out_height = (in_height - kernel_h + pad_top + pad_down) // stride_h + 1
-    out_width = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1
+    out_height = (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1
+    out_width = (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1

     oshape = (batch, oc_chunk, out_height, out_width, oc_block)

@@ -132,7 +113,8 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre
     kw = tvm.reduce_axis((0, kernel_w), name='kw')

     conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                       tvm.sum(pad_data[n, icc, oh*stride_h+kh, ow*stride_w+kw, icb]
+                       tvm.sum(pad_data[n, icc, oh*stride_h+kh*dilation_h, \
+                                        ow*stride_w+kw*dilation_w, icb]
                                .astype('int32') *
                                packed_kernel[oc_chunk, icc, kh, kw, oc_block, icb]

@@ -141,9 +123,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre
     output = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                          conv[n, oc_chunk, oh, ow, oc_block].astype(out_dtype),
-                         tag="conv2d_NCHWc_int8",
-                         attrs={"workload": _conv2d_NCHWc_int8_arg_to_workload(
-                             data, kernel, stride, padding, out_dtype)})
+                         tag="conv2d_NCHWc_int8")

     # num flop
     num_flop = batch * oc_chunk * oc_block * out_height * out_width * \

@@ -156,7 +136,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype, pre
 _dp4a = dp4a('shared', 'shared', 'local')

-def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
+def schedule_conv2d_NCHWc_int8(cfg, s, output):
     """Schedule conv2d int8 NCHWc template"""
     workload = output.op.attrs["workload"]

@@ -171,22 +151,17 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     else:
         pad_data = packed_data

-    if not pre_computed:
-        kernel, = packed_kernel.op.input_tensors
-        if autotvm.GLOBAL_SCOPE.in_tuning:
-            # skip this part during tuning to make recrods accurate
-            # this part will be pre-computed during NNVM's pre-compute optimization pass
-            s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
-            s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
-        else:
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # skip this part during tuning to make recrods accurate
+        # this part will be pre-computed during NNVM's pre-compute optimization pass
+        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
+        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
+    else:
+        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and \
+                packed_kernel.name == 'packed_kernel':
             # data and kernel are not pre-computed, schedule layout transform here
             _schedule_injective(packed_data.op, s)
             _schedule_injective(packed_kernel.op, s)
-    else:
-        kernel = packed_kernel
-
-    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
-        s[kernel].compute_inline()

     if pad_data != packed_data:
         s[pad_data].compute_inline()

@@ -310,43 +285,3 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     s[output].pragma(kernel_scope, 'unroll_explicit', False)

     return s
-
-
-@conv2d_NCHWc_int8_prepacked.register(["cuda"])
-@autotvm.task.dispatcher
-def conv2d_NCHWc_int8_prepacked_dispatcher(data, kernel, stride, padding, layout, out_dtype):
-    assert layout == 'NCHW4c'
-    return _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype)
-
-
-@conv2d_NCHWc_int8_prepacked_dispatcher.register("int8")
-def _decl_conv2d_NCHWc_int8_prepacked(cfg, data, kernel, stride, padding, layout, out_dtype):
-    return conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, layout, out_dtype,
-                             pre_computed=True)
-
-
-@autotvm.register_topi_schedule(schedule_conv2d_NCHWc_int8_prepacked, ["cuda"], ["int8"])
-def schedule_conv2d_NCHWc_int8_prepacked_cuda(cfg, outs):
-    """TOPI schedule callback of conv2d for cuda
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    outs: Array of Tensor
-        The computation graph description of conv2d
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for conv2d.
-    """
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if 'conv2d_NCHWc_int8' in op.tag:
-            schedule_conv2d_NCHWc_int8(cfg, s, op.output(0), pre_computed=True)
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
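The rewritten out_height/out_width lines in this file fold dilation into the effective kernel extent, (kernel - 1) * dilation + 1. A small self-check of that formula, using illustrative numbers that are not taken from the commit:

    def conv_out_size(in_size, kernel, stride, pad_before, pad_after, dilation):
        # effective kernel extent after dilation is (kernel - 1) * dilation + 1
        return (in_size - (kernel - 1) * dilation - 1 + pad_before + pad_after) // stride + 1

    # A 3x3 kernel with dilation 2 covers a 5x5 window, so with padding 2 on each
    # side and stride 1 the spatial size is preserved: (56 - 5 + 4) // 1 + 1 == 56.
    assert conv_out_size(56, 3, 1, 2, 2, 2) == 56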
topi/python/topi/cuda/conv2d_winograd.py
@@ -7,23 +7,10 @@ import tvm
 from tvm import autotvm

 from .. import nn
-from ..nn import conv2d_winograd_without_weight_transform
+from ..nn import conv2d, conv2d_winograd_without_weight_transform
 from ..util import get_const_int, get_const_tuple, const_matrix, traverse_inline
 from ..generic import schedule_conv2d_winograd_without_weight_transform

-
-def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
-    """convert argument to workload"""
-    K = 3
-    shape = get_const_tuple(kernel.shape)
-    if shape[-2:] == (K, K):
-        raw_kernel = kernel
-    else:
-        # pre-transformed
-        _, _, CI, CO = shape
-        raw_kernel = tvm.placeholder((CO, CI, K, K), dtype=kernel.dtype)
-
-    return ('conv2d', ) + autotvm.task.args_to_workload(
-        [data, raw_kernel, strides, padding, layout, out_dtype])
-
-
 def _infer_tile_size(data, kernel):
     N, CI, H, W = get_const_tuple(data.shape)

@@ -32,7 +19,7 @@ def _infer_tile_size(data, kernel):
         return 4
     return 2

-def winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype, pre_computed):
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed):
     """Compute declaration for winograd"""
     assert layout == 'NCHW'

@@ -41,12 +28,20 @@ def winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype, pre_co
     N, CI, H, W = get_const_tuple(data.shape)

     if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if isinstance(dilation, int):
+            dilation_h = dilation_w = dilation
+        else:
+            dilation_h, dilation_w = dilation
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+
         CO, CI, KH, KW = get_const_tuple(kernel.shape)
         HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
         HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
         assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3
     else:
-        # kernel tensor is pre-transfomred. this op is created by alter op layout, do not check
+        # kernel tensor is pre-transfomred. this op is created by
+        # alter op layout, do not check
+        # dilation is not supported
         HSTR = WSTR = 1
         HPAD = WPAD = 1
         KH = KW = 3

@@ -150,9 +145,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype, pre_co
     # output
     output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
                          inverse[co][n * nH * nW + (h // m) * nW + w // m][h % m][w % m],
-                         name='output', tag='conv2d_nchw_winograd',
-                         attrs={"workload": _winograd_conv_arg_to_workload(
-                             data, kernel, strides, padding, layout, out_dtype)})
+                         name='output', tag='conv2d_nchw_winograd')

     cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)

     return output

@@ -314,16 +307,11 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
     return s

 ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@conv2d_winograd_without_weight_transform.register(['cuda', 'gpu'])
-@autotvm.task.dispatcher
-def winograd_ww_config_dispatcher_cuda(data, kernel, strides, padding, layout, out_dtype,
-                                       tile_size):
-    return _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
-
-
-@winograd_ww_config_dispatcher_cuda.register(['winograd'])
-def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype,
-                         pre_computed=True)
+@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform,
+                               ['cuda', 'gpu'], ['winograd'])
+def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
+    return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+                         pre_computed=True)

 @autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,

@@ -352,36 +340,54 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     new_attrs = {k: attrs[k] for k in attrs.keys()}

-    assert attrs.get_int_tuple("dilation") == (1, 1), "Does not support dilation " \
-                                                      "when alter_op_layout is enabled"
-
     strides = attrs.get_int_tuple("strides")
     padding = attrs.get_int_tuple("padding")
+    dilation = attrs.get_int_tuple("dilation")
     groups = attrs.get_int('groups')
     layout = attrs["layout"]
     out_dtype = attrs["out_dtype"]
     out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype

     data, kernel = tinfos[0:2]
     N, CI, H, W = get_const_tuple(data.shape)
     CO, _, KH, KW = get_const_tuple(kernel.shape)
+    dispatch_ctx = autotvm.DispatchContext.current

     if groups == 1:
         # query config of this workload
         workload = ('conv2d',) + autotvm.task.args_to_workload(
-            [tinfos[0], tinfos[1], strides, padding, layout, out_dtype])
-        cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
+            [tinfos[0], tinfos[1], strides, padding, dilation, layout, out_dtype])
+        target = tvm.target.current_target()
+        cfg = autotvm.DispatchContext.current.query(target, workload)

         if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
+            autotvm.task.clear_fallback_cache(target, workload)
             return None

         if cfg.template_key == 'direct':
             return None

         if cfg.template_key == 'int8':
-            assert 'cuda' in tvm.target.current_target().keys
-            new_attrs['layout'] = 'NCHW4c'
-            new_attrs['out_layout'] = 'NCHW4c'
+            assert 'cuda' in target.keys
+            new_layout = 'NCHW4c'
+            new_attrs['layout'] = new_layout
+            new_attrs['out_layout'] = new_layout
             new_attrs['kernel_layout'] = 'OIHW4o4i'
+            ic_block_factor = oc_block_factor = 4
+
+            new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
+                                       dtype=data.dtype)
+            new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW, \
+                                          oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
+            dispatch_ctx.update(target, new_workload, cfg)
             return sym.conv2d(*copy_inputs, **new_attrs)

+        if attrs.get_int_tuple("dilation") != (1, 1):
+            return None
+
         # pre-compute weight transformation in winograd
         tile_size = _infer_tile_size(tinfos[0], tinfos[1])

@@ -390,6 +396,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
         weight = sym.transpose(weight, axes=[0, 1, 3, 2])
         copy_inputs[1] = weight
         new_attrs['tile_size'] = tile_size
+
+        new_data = data
+        new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
+                                     dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_weight, strides, padding, dilation, layout, out_dtype, tile_size],
+            conv2d_winograd_without_weight_transform)
+        dispatch_ctx.update(target, new_workload, cfg)
+
         return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)

     # do nothing for depthwise convolution
topi/python/topi/generic/nn.py
@@ -122,24 +122,6 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
-@tvm.target.generic_func
-def schedule_conv2d_NCHWc_int8_prepacked(outs):
-    """Schedule for conv2d NCHWc int8 with prepacked data and kernel
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of this operator
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    return _default_schedule(outs, False)
-
-
 @tvm.target.generic_func
 def schedule_conv2d_transpose_nchw(outs):
     """Schedule for conv2d_transpose_nchw
topi/python/topi/mali/conv2d.py
@@ -16,7 +16,7 @@ from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
 @autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
-def conv2d_mali(cfg, data, kernel, strides, padding, layout, out_dtype):
+def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """TOPI compute callback for conv2d

     Parameters

@@ -38,6 +38,9 @@ def conv2d_mali(cfg, data, kernel, strides, padding, layout, out_dtype):
     padding : list of two ints
         [pad_height, pad_width]

+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
     layout : str
         layout of data

@@ -49,7 +52,8 @@ def conv2d_mali(cfg, data, kernel, strides, padding, layout, out_dtype):
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=3)
+    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+                              num_tile=3)

 @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
 def schedule_conv2d_nchw_mali(cfg, outs):

@@ -175,16 +179,26 @@ def _pick_tile_size(data, kernel):
     return 2

 @autotvm.register_topi_compute(conv2d, 'mali', ['winograd'])
-def conv2d_mali_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
+def conv2d_mali_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     tile_size = _pick_tile_size(data, kernel)
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+                          tile_size)

-def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
     N, CI, IH, IW = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
     if len(kernel.shape) == 4:
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+
         pre_computed = False
         CO, _, KH, KW = get_const_tuple(kernel.shape)
     else:
+        assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
         pre_computed = True
         H_CAT, W_CAT, CO, CI, VC = get_const_tuple(kernel.shape)
         CO *= VC

@@ -428,7 +442,8 @@ def _schedule_winograd(cfg, s, op):
 @autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
 def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
     """TOPI compute callback"""
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+                          tile_size)

 @autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
topi/python/topi/nn/conv2d.py
@@ -6,6 +6,7 @@ from collections import namedtuple
 import numpy as np
 import tvm

+from .dilate import dilate
 from .pad import pad
 from .util import get_pad_tuple
 from ..util import simplify, const_matrix, get_const_tuple

@@ -16,7 +17,7 @@ Workload = namedtuple('Workload',
                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

 @tvm.target.generic_func
-def conv2d(input, filter, strides, padding, layout='NCHW', out_dtype=None):
+def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
     """Conv2D operator.

     Parameters

@@ -33,6 +34,9 @@ def conv2d(input, filter, strides, padding, layout='NCHW', out_dtype=None):
     padding : int or a list/tuple of two ints
         padding size, or [pad_height, pad_width]

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     layout : str
         layout of data

@@ -44,11 +48,11 @@ def conv2d(input, filter, strides, padding, layout='NCHW', out_dtype=None):
     # search platform specific declaration first
     # default declaration
     if layout == 'NCHW':
-        return conv2d_nchw(input, filter, strides, padding, out_dtype)
+        return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
     elif layout == 'HWCN':
-        return conv2d_hwcn(input, filter, strides, padding, out_dtype)
+        return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
     elif layout == 'NHWC':
-        return conv2d_nhwc(input, filter, strides, padding, out_dtype)
+        return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))

@@ -85,7 +89,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
     return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)

-def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
+def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     """Convolution operator in NCHW layout.

     Parameters

@@ -102,6 +106,9 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     padding : int or str
         Padding size, or ['VALID', 'SAME']

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     Returns
     -------
     Output : tvm.Tensor

@@ -110,12 +117,22 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     if out_dtype is None:
         out_dtype = Input.dtype
     assert isinstance(stride, int) or len(stride) == 2
-    batch, in_channel, in_height, in_width = Input.shape
-    num_filter, channel, kernel_h, kernel_w = Filter.shape
+    assert isinstance(dilation, int) or len(dilation) == 2
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride

+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        Filter = dilate(Filter, (1, 1, dilation_h, dilation_w))
+
+    batch, in_channel, in_height, in_width = Input.shape
+    num_filter, channel, kernel_h, kernel_w = Filter.shape
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     # compute the output shape

@@ -138,7 +155,7 @@ def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
             axis=[rc, ry, rx]), tag="conv2d_nchw")

-def conv2d_hwcn(Input, Filter, stride, padding, out_dtype=None):
+def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
     """Convolution operator in HWCN layout.

     Parameters

@@ -155,6 +172,9 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype=None):
     padding : int or str
         Padding size, or ['VALID', 'SAME']

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     Returns
     -------
     output : tvm.Tensor

@@ -163,13 +183,23 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype=None):
     if out_dtype is None:
         out_dtype = Input.dtype
     assert isinstance(stride, int) or len(stride) == 2
-    in_height, in_width, in_channel, batch = Input.shape
-    kernel_h, kernel_w, channel, num_filter = Filter.shape
+    assert isinstance(dilation, int) or len(dilation) == 2
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride

+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
+
+    in_height, in_width, in_channel, batch = Input.shape
+    kernel_h, kernel_w, channel, num_filter = Filter.shape
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     # compute the output shape

@@ -191,7 +221,7 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype=None):
     return Output

-def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
+def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     """Convolution operator in NHWC layout.

     Parameters

@@ -208,19 +238,32 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
     padding : int or str
         Padding size, or ['VALID', 'SAME']

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     Returns
     -------
     output : tvm.Tensor
         4-D with shape [batch, out_height, out_width, out_channel]
     """
     assert isinstance(stride, int) or len(stride) == 2
-    batch, in_height, in_width, in_channel = Input.shape
-    kernel_h, kernel_w, channel, num_filter = Filter.shape
+    assert isinstance(dilation, int) or len(dilation) == 2
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride

+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
+
+    batch, in_height, in_width, in_channel = Input.shape
+    kernel_h, kernel_w, channel, num_filter = Filter.shape
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
     # compute the output shape

@@ -243,7 +286,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
 @tvm.target.generic_func
-def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'):
+def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
     """Conv2D operator for nChw[x]c layout.

     Parameters

@@ -262,6 +305,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='f
     padding : int or a list/tuple of two ints
         padding size, or [pad_height, pad_width]

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     layout : str
         Input data layout

@@ -333,7 +379,7 @@ def conv2d_winograd_weight_transform(kernel, tile_size):
 @tvm.target.generic_func
-def conv2d_winograd_without_weight_transform(input, filter, strides, padding,
+def conv2d_winograd_without_weight_transform(input, filter, strides, padding, dilation,
                                              layout, out_dtype, tile_size):
     """Compute convolution in winograd algorithm. The filter is supposed to be transformed
     in advance.

@@ -357,37 +403,3 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding,
         4-D with shape [batch, out_height, out_width, out_channel]
     """
     raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
-
-
-@tvm.target.generic_func
-def conv2d_NCHWc_int8_prepacked(data, kernel, stride, padding, layout, out_dtype):
-    """Convolution operator in NCHW[x]c layout for int8. Data and kernel should be packed in
-    advance.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
-
-    kernel : tvm.Tensor
-        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
-        filter_width, num_filter_block, in_channel_block]
-
-    stride : int or a list/tuple of two ints
-        stride size, or [stride_height, stride_width]
-
-    padding: int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
-
-    layout : str
-        layout of data
-
-    out_dtype: str
-        The output type. This is used for mixed precision.
-
-    Returns
-    -------
-    output : tvm.Tensor
-        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
-    """
-    raise ValueError("missing register for topi.nn.conv2d_NCHWc_int8_prepacked")
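The generic NCHW/HWCN/NHWC kernels in this file all handle dilation the same way: when dilation is not 1, the filter is expanded with topi.nn.dilate and the ordinary convolution runs on the expanded filter. A minimal NumPy sketch of that expansion, with a hypothetical dilate_2d helper standing in for topi.nn.dilate:

    import numpy as np

    def dilate_2d(w, d):
        # insert (d - 1) zeros between kernel taps, mirroring what topi.nn.dilate
        # does along the spatial axes of the filter
        kh, kw = w.shape
        out = np.zeros(((kh - 1) * d + 1, (kw - 1) * d + 1), dtype=w.dtype)
        out[::d, ::d] = w
        return out

    w = np.ones((3, 3), dtype='float32')
    print(dilate_2d(w, 2).shape)  # (5, 5): a 3x3 kernel with dilation 2 has a 5x5 footprint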
topi/python/topi/nn/depthwise_conv2d.py
@@ -10,7 +10,7 @@ from ..util import simplify
 @tvm.target.generic_func
-def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
+def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     """Depthwise convolution nchw forward operator.

     Parameters

@@ -27,6 +27,9 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     padding : int or str
         Padding size, or ['VALID', 'SAME']

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     out_dtype: str, optional
         Output data type

@@ -37,13 +40,23 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     """
     out_dtype = Input.dtype if out_dtype is None else out_dtype

-    batch, in_channel, in_height, in_width = Input.shape
-    filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride

+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        Filter = dilate(Filter, (1, 1, dilation_h, dilation_w))
+
+    batch, in_channel, in_height, in_width = Input.shape
+    # shape of dilated kernel
+    filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
+
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
         padding, (filter_height, filter_width))
     out_channel = simplify(in_channel * channel_multiplier)

@@ -68,7 +81,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
 @tvm.target.generic_func
-def depthwise_conv2d_nhwc(Input, Filter, stride, padding, out_dtype=None):
+def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None):
     """Depthwise convolution nhwc forward operator.

     Parameters

@@ -85,6 +98,9 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, out_dtype=None):
     padding : int or str
         Padding size, or ['VALID', 'SAME']

+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
     out_dtype: str, optional
         Output data type

@@ -95,13 +111,23 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, out_dtype=None):
     """
     out_dtype = Input.dtype if out_dtype is None else out_dtype

-    batch, in_height, in_width, in_channel = Input.shape
-    filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride

+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        Filter = dilate(Filter, (dilation_h, dilation_w, 1, 1))
+
+    batch, in_height, in_width, in_channel = Input.shape
+    # shape of dilated kernel
+    filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
+
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
         padding, (filter_height, filter_width))
     out_channel = simplify(in_channel * channel_multiplier)
topi/python/topi/rocm/conv2d.py
@@ -5,11 +5,11 @@ from tvm import autotvm
 from tvm.contrib import miopen

 from .. import nn, generic
-from ..util import get_const_int, get_const_tuple
+from ..util import get_const_tuple
 from ..cuda.conv2d import conv2d_cuda, schedule_conv2d_nchw_cuda

 @autotvm.register_topi_compute(nn.conv2d, 'rocm', ['direct', 'winograd'])
-def conv2d_rocm(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
+def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'):
     """Conv2D operator for rocm backend.

     Parameters

@@ -47,29 +47,12 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
         # handle dilation
         stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
         pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
+        dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

         OH = (H + 2 * pad_h - KH) // stride_h + 1
         OW = (W + 2 * pad_w - KW) // stride_w + 1
-        cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
-
-        dilation_h = dilation_w = 1
-        kernel_before_dilation = kernel
-        if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
-            kernel_before_dilation = kernel.op.input_tensors[0]
-            if layout == 'NCHW':
-                dilation_h = (get_const_int(kernel.shape[2]) +
-                              get_const_int(kernel_before_dilation.shape[2]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[2])
-                dilation_w = (get_const_int(kernel.shape[3]) +
-                              get_const_int(kernel_before_dilation.shape[3]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[2])
-            elif layout == 'NHWC':
-                dilation_h = (get_const_int(kernel.shape[1]) +
-                              get_const_int(kernel_before_dilation.shape[1]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[1])
-                dilation_w = (get_const_int(kernel.shape[2]) +
-                              get_const_int(kernel_before_dilation.shape[2]) - 1) \
-                    // get_const_int(kernel_before_dilation.shape[2])
+        cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \
+                     ((KW - 1) * dilation_w + 1))

         return miopen.conv2d_forward(data,
                                      kernel_before_dilation,

@@ -81,7 +64,7 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='f
                                      dilation_w,
                                      conv_mode=0)

-    return conv2d_cuda(cfg, data, kernel, strides, padding, layout, out_dtype)
+    return conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)

 @autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'rocm', ["direct", 'winograd'])
topi/python/topi/x86/conv2d.py
@@ -8,6 +8,7 @@ from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple
 from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload
+from ..nn.dilate import dilate
 from ..nn.pad import pad
 from . import conv2d_avx_1x1, conv2d_avx_common

@@ -38,7 +39,7 @@ def _get_default_config(cfg, workload):
         conv2d_avx_common._fallback_schedule(cfg, workload, fp32_vec_len)

-def _create_tuning_space(cfg, data, kernel, strides, padding, layout):
+def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     """Create schedule configuration from input arguments"""
     dshape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)

@@ -65,28 +66,39 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, layout):
 @autotvm.register_topi_compute(conv2d, 'cpu', 'direct')
-def _declaration_conv(cfg, data, kernel, strides, padding, layout, out_dtype):
+def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+
     if layout == 'NCHW':
-        _create_tuning_space(cfg, data, kernel, strides, padding, layout)
+        _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
         if cfg.is_fallback:
             wkl = _get_workload(data, kernel, strides, padding, out_dtype)
             _get_default_config(cfg, wkl)
-        return _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype)
+        return _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout,
+                                      out_dtype)
     elif layout == 'HWCN':
-        return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype)
+        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
     elif layout == 'NHWC':
-        return nn.conv2d_nhwc(data, kernel, strides, padding, out_dtype)
+        return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))

-def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype):
+def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     assert layout == 'NCHW', "only support NCHW convolution for AVX"
+
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(dilation, int):
+        dilation_h, dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if dilation_h != 1 or dilation_w != 1:
+        kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+
     HPAD, WPAD = padding
     HSTR, WSTR = strides

@@ -251,13 +263,13 @@ def schedule_conv2d_nhwc(outs):
 @autotvm.task.register("topi_x86_conv2d_NCHWc")
 def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     assert not kwargs, "Do not support kwargs in template function call"
-    data, kernel, strides, padding, origin_layout, dtype = deserialize_args(args)
+    data, kernel, strides, padding, dilation, origin_layout, dtype = deserialize_args(args)
     raw_data_shape = get_const_tuple(data.shape)
     raw_kernel_shape = get_const_tuple(kernel.shape)

     # get config here
     cfg = get_config()
-    _create_tuning_space(cfg, data, kernel, strides, padding, origin_layout)
+    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)

     # change shape with the value in config
     ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],

@@ -271,7 +283,7 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     new_data = tvm.placeholder(new_data_shape, data.dtype)
     new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)

-    C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding,
+    C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation,
                                 data_layout, out_layout, dtype)
     s = _schedule_conv2d_NCHWc(cfg, [C])
     return s, [new_data, new_kernel, C]

@@ -326,11 +338,13 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
 @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
 def _declaration_conv_NCHWc(cfg, data, kernel, strides,
-                            padding, layout, out_layout, out_dtype):
+                            padding, dilation, layout, out_layout, out_dtype):
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
     HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    assert (dh, dw) == (1, 1), "Does not support dilation"

     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
topi/tests/python/test_topi_conv2d_hwcn.py
@@ -13,8 +13,7 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
     A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
-    B = topi.nn.conv2d_hwcn(A, dW, stride, padding)
+    B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
     C = topi.nn.relu(B)
     s1 = topi.cuda.schedule_conv2d_hwcn([B])
     s2 = topi.cuda.schedule_conv2d_hwcn([C])
topi/tests/python/test_topi_conv2d_int8.py
@@ -15,7 +15,7 @@ oc_block_factor = 4
 def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))

     in_height = in_width = in_size

@@ -63,8 +63,7 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
-            C = topi.nn.conv2d(A, dW, (stride, stride), (padding, padding),
+            C = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), (dilation, dilation),
                                layout='NCHW', out_dtype=dtype)
             if add_bias:
                 C = topi.add(C, bias)
topi/tests/python/test_topi_conv2d_nchw.py
@@ -11,7 +11,7 @@ from topi.util import get_const_tuple
 from common import get_all_backend

 def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))

     in_height = in_width = in_size

@@ -47,9 +47,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
-            C = topi.nn.conv2d(A, dW, (stride, stride), (padding, padding),
-                               layout='NCHW', out_dtype=dtype)
+            C = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), (dilation, dilation),
+                               layout='NCHW', out_dtype=dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
topi/tests/python/test_topi_conv2d_nhwc.py
View file @ 2005f852
...
...
@@ -13,18 +13,17 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
     A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    dW = topi.nn.dilate(W, (1, dilation, dilation, 1))
-    B = topi.nn.conv2d_nhwc(A, dW, stride, padding)
+    B = topi.nn.conv2d_nhwc(A, W, stride, padding, dilation)

     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
     dtype = A.dtype

-    @memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc")
+    @memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc.v2")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        dw_np = topi.testing.dilate_python(w_np, (1, dilation, dilation, 1))
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
         b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
         return a_np, w_np, b_np
     a_np, w_np, b_np = get_ref_data()
...
...
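Two details in this hunk are easy to miss. The memoize key gains a ".v2" suffix so reference data pickled by older test runs is not reused against the new semantics, and the axes handed to dilate_python follow the kernel layout rather than the data layout: the NHWC test uses an HWIO kernel, so its first two axes are the spatial ones. A small sketch of that mapping (the helper is illustrative, not part of the patch):

    # Spatial axes of the kernel tensor per kernel layout, used to build the
    # dilation tuple for topi.testing.dilate_python in the reference path.
    SPATIAL_AXES = {"OIHW": (2, 3), "HWIO": (0, 1)}

    def dilation_tuple(kernel_layout, d, ndim=4):
        t = [1] * ndim
        for ax in SPATIAL_AXES[kernel_layout]:
            t[ax] = d
        return tuple(t)

    assert dilation_tuple("OIHW", 2) == (1, 1, 2, 2)   # NCHW tests
    assert dilation_tuple("HWIO", 2) == (2, 2, 1, 1)   # NHWC / HWCN tests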
topi/tests/python/test_topi_conv2d_winograd.py
View file @ 2005f852
...
...
@@ -11,7 +11,7 @@ from topi.util import get_const_tuple
 def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
     in_height = in_width = in_size
...
...
@@ -47,8 +47,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         return
     print("Running on target: %s" % device)
     with tvm.target.create(device):
-        dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
-        C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype)
+        C = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NCHW', out_dtype=dtype)
         if add_bias:
             C = topi.add(C, bias)
         if add_relu:
...
...
topi/tests/python/test_topi_depthwise_conv2d.py
View file @ 2005f852
...
...
@@ -26,7 +26,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     # placeholder
     Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
     Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
-    DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
...
...
@@ -40,8 +39,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     print("Running on target: %s" % device)
     with tvm.target.create(device):
         # declare
-        DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, (stride_h, stride_w), padding_args, dtype)
+        DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype)
         ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
         Relu = topi.nn.relu(ScaleShift)
         # schedule
...
...
@@ -123,7 +122,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     # placeholder
     Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
     Filter = tvm.placeholder((filter_height, filter_width, filter_channel, channel_multiplier), name='Filter')
-    DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
...
...
@@ -138,8 +136,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     with tvm.target.create(device):
         # declare
-        DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, (stride_h, stride_w), padding_args, dtype)
+        DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype)
         ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
         Relu = topi.nn.relu(ScaleShift)
         # schedule
...
...
@@ -159,11 +157,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     scale_shift_shape = get_const_tuple(ScaleShift.shape)

     # Use memoize, pickle the test data for next time use.
-    @memoize("topi.tests.test_topi_depthwise_conv2d.nhwc")
+    @memoize("topi.tests.test_topi_depthwise_conv2d.nhwc.v2")
     def get_ref_data():
         input_np = np.random.uniform(size=input_shape).astype(dtype)
         filter_np = np.random.uniform(size=filter_shape).astype(dtype)
-        dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
+        dilated_filter_np = topi.testing.dilate_python(filter_np, (dilation, dilation, 1, 1))
         scale_np = np.random.uniform(size=scale_shape).astype(dtype)
         shift_np = np.random.uniform(size=shift_shape).astype(dtype)
         # correctness with scipy
...
...
@@ -232,7 +230,8 @@ def test_depthwise_conv2d():
     depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
     depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")

     # dilation = 2
-    depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
+    # disabled because it uses too large shared memory on cuda
+    # depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)

 if __name__ == "__main__":
     test_depthwise_conv2d()
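For depthwise convolution the same split applies: the TOPI operator now takes dilation as an argument, while the numpy reference in get_ref_data still dilates the filter explicitly. A sketch with concrete shapes, assuming the signatures shown in the hunks above (shapes and values are illustrative only):

    import numpy as np
    import tvm
    import topi

    Input = tvm.placeholder((1, 32, 56, 56), name='Input')    # NCHW data
    Filter = tvm.placeholder((32, 1, 3, 3), name='Filter')    # (channel, multiplier, kh, kw)

    # TOPI path: dilation is an operator argument now.
    Out = topi.nn.depthwise_conv2d_nchw(Input, Filter, (1, 1), "SAME", 2, 'float32')

    # Reference path: the test still builds a dilated filter in numpy.
    filter_np = np.random.uniform(size=(32, 1, 3, 3)).astype('float32')
    dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, 2, 2))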
tutorials/autotvm/tune_conv2d_cuda.py
View file @ 2005f852
...
...
@@ -68,7 +68,7 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     data = tvm.placeholder((N, CI, H, W), name='data')
     kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
-    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, 'float32')
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
     s = tvm.create_schedule([conv.op])

     ##### space definition begin #####
...
...
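Existing autotvm templates that call topi.nn.conv2d_nchw must now supply the extra argument; the tutorial simply pins dilation=1. If dilation were instead meant to be part of the tuned workload, the template could expose it as a parameter, for example (hypothetical extension of the tutorial's conv2d_no_batching, not in the patch):

    def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding, dilation=1):
        data = tvm.placeholder((N, CI, H, W), name='data')
        kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
        conv = topi.nn.conv2d_nchw(data, kernel, stride, padding,
                                   dilation=dilation, out_dtype='float32')
        s = tvm.create_schedule([conv.op])
        return s, [data, kernel, conv]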
tutorials/topi/intro_topi.py
View file @ 2005f852
...
...
@@ -117,7 +117,7 @@ data = tvm.placeholder((1, 3, 224, 224))
 kernel = tvm.placeholder((10, 3, 5, 5))

 with tvm.target.create("cuda"):
-    conv = topi.nn.conv2d(data, kernel, strides=1, padding=2)
+    conv = topi.nn.conv2d(data, kernel, strides=1, padding=2, dilation=1)
     out = topi.nn.relu(conv)
     sconv = topi.generic.nn.schedule_conv2d_nchw(out)
     print(tvm.lower(sconv, [data, kernel], simple_mode=True))
...
...
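The tutorial call changes only in being explicit: with dilation=1 the lowered code should be the same as before. A dilated variant would look like the following (hypothetical usage, not part of the tutorial; padding is enlarged to keep the output size, and dilation support still varies by backend and schedule):

    # Hypothetical dilated variant of the tutorial example above (5x5 kernel, dilation 2).
    with tvm.target.create("cuda"):
        conv = topi.nn.conv2d(data, kernel, strides=1, padding=4, dilation=2)
        out = topi.nn.relu(conv)
        sconv = topi.generic.nn.schedule_conv2d_nchw(out)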
vta/tests/python/integration/test_benchmark_topi_conv2d.py
View file @ 2005f852
...
...
@@ -33,6 +33,7 @@ def test_cpu_conv2d():
             res_conv = topi.nn.conv2d(data, kernel, padding=(wl.hpad, wl.wpad),
                                       strides=(wl.hstride, wl.wstride),
+                                      dilation=(1, 1),
                                       out_dtype="int32")
             res = topi.right_shift(res_conv, 8)
             res = my_clip(res, 0, 127)
...
...