Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9049d669
Commit
9049d669
authored
Nov 22, 2019
by
Alexander Pivovarov
Committed by
Yao Wang
Nov 22, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Legalize] Legalize conv2d_transpose for NHWC (#4399)
parent
87bd799e
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
165 additions
and
5 deletions
+165
-5
python/tvm/relay/op/nn/_nn.py
+20
-0
python/tvm/relay/op/op_attrs.py
+5
-0
tests/python/relay/test_op_level2.py
+32
-4
topi/python/topi/nn/conv2d_transpose.py
+60
-0
topi/python/topi/testing/__init__.py
+1
-1
topi/python/topi/testing/conv2d_transpose_python.py
+47
-0
No files found.
python/tvm/relay/op/nn/_nn.py
View file @
9049d669
...
@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
...
@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
return
topi
.
generic
.
schedule_conv2d_transpose_nchw
(
outs
)
return
topi
.
generic
.
schedule_conv2d_transpose_nchw
(
outs
)
@reg.register_legalize
(
"nn.conv2d_transpose"
)
def
legalize_conv2d_transpose
(
attrs
,
inputs
,
types
):
"""Legalize conv2d_transpose op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current Transposed convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return
topi
.
nn
.
conv2d_transpose_legalize
(
attrs
,
inputs
,
types
)
reg
.
register_pattern
(
"nn.conv2d_transpose"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
reg
.
register_pattern
(
"nn.conv2d_transpose"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
# bias_add
# bias_add
...
...
python/tvm/relay/op/op_attrs.py
View file @
9049d669
...
@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
...
@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
@register_relay_attr_node
@register_relay_attr_node
class
BinaryDenseAttrs
(
Attrs
):
class
BinaryDenseAttrs
(
Attrs
):
"""Attributes used in bitserial dense operators"""
"""Attributes used in bitserial dense operators"""
@register_relay_attr_node
class
Conv2DTransposeAttrs
(
Attrs
):
"""Attributes used in Transposed Conv2D operators"""
tests/python/relay/test_op_level2.py
View file @
9049d669
...
@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
...
@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
(
10
,
15
,
3
,
3
),
"float32"
)
(
10
,
15
,
3
,
3
),
"float32"
)
# infer by shape of w, mixed precision
# infer by shape of w, mixed precision
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
10
,
10
,
12
n
,
h
,
w
,
c
=
tvm
.
var
(
"n"
),
10
,
10
,
12
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
h
,
w
,
c
),
"float32"
))
w
=
relay
.
var
(
"w"
,
relay
.
TensorType
((
12
,
11
,
5
,
5
),
"float32"
))
w
=
relay
.
var
(
"w"
,
relay
.
TensorType
((
12
,
11
,
5
,
5
),
"float32"
))
y
=
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
y
=
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
output_padding
=
(
1
,
1
),
output_padding
=
(
1
,
1
),
...
@@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type():
...
@@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type():
(
n
,
15
,
15
,
11
),
"float32"
)
(
n
,
15
,
15
,
11
),
"float32"
)
def
test_conv2d_transpose_run
():
def
test_conv2d_transpose_
nchw_
run
():
dshape
=
(
1
,
3
,
18
,
18
)
dshape
=
(
1
,
3
,
18
,
18
)
kshape
=
(
3
,
10
,
3
,
3
)
kshape
=
(
3
,
10
,
3
,
3
)
oshape
=
(
1
,
10
,
37
,
37
)
oshape
=
(
1
,
10
,
37
,
37
)
...
@@ -348,6 +348,33 @@ def test_conv2d_transpose_run():
...
@@ -348,6 +348,33 @@ def test_conv2d_transpose_run():
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
,
atol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
,
atol
=
1e-5
)
def
test_conv2d_transpose_nhwc_run
():
dshape_nhwc
=
(
1
,
18
,
18
,
3
)
kshape_hwoi
=
(
3
,
3
,
10
,
3
)
oshape_nhwc
=
(
1
,
37
,
37
,
10
)
x
=
relay
.
var
(
"x"
,
shape
=
dshape_nhwc
)
w
=
relay
.
var
(
"w"
)
# kshape and kernel_layout should have swapped IO.
# kshape is HWOI and kernel_layout is HWIO
y
=
relay
.
nn
.
conv2d_transpose
(
x
,
w
,
channels
=
10
,
kernel_size
=
(
3
,
3
),
strides
=
(
2
,
2
),
padding
=
(
1
,
1
),
output_padding
=
(
2
,
2
),
data_layout
=
"NHWC"
,
kernel_layout
=
"HWIO"
)
func
=
relay
.
Function
([
x
,
w
],
y
)
dtype
=
"float32"
data
=
np
.
random
.
uniform
(
size
=
dshape_nhwc
)
.
astype
(
dtype
)
kernel
=
np
.
random
.
uniform
(
size
=
kshape_hwoi
)
.
astype
(
dtype
)
# use true kshape layout here - HWOI
c_np
=
topi
.
testing
.
conv2d_transpose_nhwc_python
(
data
,
kernel
,
'HWOI'
,
2
,
1
)
d_np
=
np
.
zeros
(
shape
=
oshape_nhwc
)
d_np
[:,
0
:
c_np
.
shape
[
1
],
0
:
c_np
.
shape
[
2
],:]
=
c_np
ref_res
=
d_np
for
target
,
ctx
in
ctx_list
():
intrp1
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
op_res1
=
intrp1
.
evaluate
(
func
)(
data
,
kernel
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
,
atol
=
1e-5
)
def
test_upsampling_infer_type
():
def
test_upsampling_infer_type
():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
...
@@ -819,7 +846,8 @@ if __name__ == "__main__":
...
@@ -819,7 +846,8 @@ if __name__ == "__main__":
test_pad_infer_type
()
test_pad_infer_type
()
test_pad_run
()
test_pad_run
()
test_conv2d_transpose_infer_type
()
test_conv2d_transpose_infer_type
()
test_conv2d_transpose_run
()
test_conv2d_transpose_nchw_run
()
test_conv2d_transpose_nhwc_run
()
test_conv2d_run
()
test_conv2d_run
()
test_conv2d_winograd
()
test_conv2d_winograd
()
test_bitserial_conv2d_infer_type
()
test_bitserial_conv2d_infer_type
()
...
...
topi/python/topi/nn/conv2d_transpose.py
View file @
9049d669
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
tvm
import
tvm
from
tvm
import
relay
from
.dilate
import
dilate
from
.dilate
import
dilate
from
.pad
import
pad
from
.pad
import
pad
from
.util
import
get_pad_tuple
from
.util
import
get_pad_tuple
...
@@ -102,3 +103,62 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
...
@@ -102,3 +103,62 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
axis
=
[
dc
,
dh
,
dw
]),
tag
=
"conv2d_transpose_nchw"
)
axis
=
[
dc
,
dh
,
dw
]),
tag
=
"conv2d_transpose_nchw"
)
return
Output
return
Output
@tvm.target.generic_func
def
conv2d_transpose_legalize
(
attrs
,
inputs
,
types
):
"""Legalizes Transposed 2D convolution op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current Transposed 2D convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
if
attrs
[
'data_layout'
]
==
'NHWC'
:
data
,
kernel
=
inputs
kernel_layout
=
attrs
[
'kernel_layout'
]
# Convert Kernel layout to IOHW
# kernel_layout is different from input kernel layout - IO is swapped
if
kernel_layout
==
'HWIO'
:
# input kernel layout is swapped to HWOI
# output kernel layout will be IOHW
kernel
=
relay
.
transpose
(
kernel
,
axes
=
(
3
,
2
,
0
,
1
))
elif
kernel_layout
==
'HWOI'
:
# input kernel layout is swapped to HWIO
# output kernel layout will be IOHW
kernel
=
relay
.
transpose
(
kernel
,
axes
=
(
2
,
3
,
0
,
1
))
elif
kernel_layout
==
'IOHW'
:
# input kernel layout is swapped to OIHW
# output kernel layout will be IOHW
kernel
=
relay
.
transpose
(
kernel
,
axes
=
(
1
,
0
,
2
,
3
))
elif
kernel_layout
==
'OIHW'
:
# input kernel layout is swapped to IOHW
# output kernel layout will be IOHW
pass
else
:
# Skip legalize. Let relay.nn.conv2d_transpose to handle the case
return
None
# Set new attrs for conv2d_transpose.
new_attrs
=
{
k
:
attrs
[
k
]
for
k
in
attrs
.
keys
()}
new_attrs
[
'data_layout'
]
=
'NCHW'
# layout of kernel should be IOHW, but kernel_layout should be swapped - OIHW
new_attrs
[
'kernel_layout'
]
=
'OIHW'
# Convert data to NCHW.
data
=
relay
.
transpose
(
data
,
axes
=
(
0
,
3
,
1
,
2
))
deconv
=
relay
.
nn
.
conv2d_transpose
(
data
,
kernel
,
**
new_attrs
)
# Convert back to original NHWC layout.
out
=
relay
.
transpose
(
deconv
,
axes
=
(
0
,
2
,
3
,
1
))
return
out
return
None
topi/python/topi/testing/__init__.py
View file @
9049d669
...
@@ -24,7 +24,7 @@ from __future__ import absolute_import as _abs
...
@@ -24,7 +24,7 @@ from __future__ import absolute_import as _abs
from
.conv2d_hwcn_python
import
conv2d_hwcn_python
from
.conv2d_hwcn_python
import
conv2d_hwcn_python
from
.conv2d_nchw_python
import
conv2d_nchw_python
from
.conv2d_nchw_python
import
conv2d_nchw_python
from
.conv2d_nhwc_python
import
conv2d_nhwc_python
from
.conv2d_nhwc_python
import
conv2d_nhwc_python
from
.conv2d_transpose_
nchw_python
import
conv2d_transpose_nchw
_python
from
.conv2d_transpose_
python
import
conv2d_transpose_nchw_python
,
conv2d_transpose_nhwc
_python
from
.deformable_conv2d_nchw_python
import
deformable_conv2d_nchw_python
from
.deformable_conv2d_nchw_python
import
deformable_conv2d_nchw_python
from
.depthwise_conv2d_python
import
depthwise_conv2d_python_nchw
,
depthwise_conv2d_python_nhwc
from
.depthwise_conv2d_python
import
depthwise_conv2d_python_nchw
,
depthwise_conv2d_python_nhwc
from
.dilate_python
import
dilate_python
from
.dilate_python
import
dilate_python
...
...
topi/python/topi/testing/conv2d_transpose_
nchw_
python.py
→
topi/python/topi/testing/conv2d_transpose_python.py
View file @
9049d669
...
@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
...
@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
padded_a_np
[
n
,
c
],
w_np
[
c
,
f
],
mode
=
'valid'
)
padded_a_np
[
n
,
c
],
w_np
[
c
,
f
],
mode
=
'valid'
)
b_np
[
n
,
f
]
+=
out
b_np
[
n
,
f
]
+=
out
return
b_np
return
b_np
def
conv2d_transpose_nhwc_python
(
a_nhwc
,
weight
,
weight_format
,
stride
,
padding
):
"""Transposed convolution operator in NHWC layout.
Parameters
----------
a_nhwc : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
weight : numpy.ndarray
4-D in formats HWIO, HWOI, OIHW or IOHW
weight_format : str
['HWIO', 'HWOI', 'OIHW', 'IOHW']
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert
a_nhwc
.
ndim
==
4
,
"a_nhwc number of dimensions should be 4"
assert
weight
.
ndim
==
4
,
"weight number of dimensions should be 4"
a_nchw
=
np
.
transpose
(
a_nhwc
,
(
0
,
3
,
1
,
2
))
# conv2d_transpose_nchw_python needs kernel layout to be IOHW
if
weight_format
==
'HWIO'
:
w_iohw
=
np
.
transpose
(
weight
,
(
2
,
3
,
0
,
1
))
elif
weight_format
==
'HWOI'
:
w_iohw
=
np
.
transpose
(
weight
,
(
3
,
2
,
0
,
1
))
elif
weight_format
==
'OIHW'
:
w_iohw
=
np
.
transpose
(
weight
,
(
1
,
0
,
2
,
3
))
elif
weight_format
==
'IOHW'
:
w_iohw
=
weight
else
:
raise
ValueError
(
'Valid weight_formats are HWIO, HWOI, OIHW or IOHW'
)
res_nchw
=
conv2d_transpose_nchw_python
(
a_nchw
,
w_iohw
,
stride
,
padding
)
res_nhwc
=
np
.
transpose
(
res_nchw
,
(
0
,
2
,
3
,
1
))
return
res_nhwc
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment