Commit d430fbb5
Authored Dec 18, 2019 by Alex Gladkov
Committed Dec 18, 2019 by Wuwei Lin
Implement 1d deconvolution (#4476)
Parent: 93843536
Showing 19 changed files with 853 additions and 14 deletions (+853, -14):
include/tvm/relay/attrs/nn.h                              (+58, -0)
python/tvm/autotvm/task/relay_integration.py              (+1, -0)
python/tvm/autotvm/task/topi_integration.py               (+12, -0)
python/tvm/relay/_parser.py                               (+1, -0)
python/tvm/relay/frontend/mxnet.py                        (+7, -13)
python/tvm/relay/op/nn/_nn.py                             (+31, -0)
python/tvm/relay/op/nn/nn.py                              (+66, -0)
src/relay/op/nn/convolution.cc                            (+157, -0)
src/relay/op/op_common.h                                  (+12, -0)
tests/python/relay/test_op_level2.py                      (+20, -0)
topi/python/topi/cuda/__init__.py                         (+1, -1)
topi/python/topi/cuda/conv1d_transpose_ncw.py             (+187, -0)
topi/python/topi/generic/nn.py                            (+18, -0)
topi/python/topi/nn/__init__.py                           (+1, -0)
topi/python/topi/nn/conv1d_transpose.py                   (+83, -0)
topi/python/topi/nn/util.py                               (+39, -0)
topi/python/topi/testing/__init__.py                      (+1, -0)
topi/python/topi/testing/conv1d_transpose_ncw_python.py   (+71, -0)
topi/tests/python/test_topi_conv1d_transpose_ncw.py       (+87, -0)
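Taken together, the commit threads a new `nn.conv1d_transpose` operator through the Relay attribute/type-relation layer, the AutoTVM task machinery, the MXNet frontend, the generic and CUDA TOPI back ends, and the test suites (per-file diffs follow). A minimal end-to-end sketch of how the new op is exercised, modelled on the Relay test added below; it assumes a TVM build containing this commit and the pre-0.7 `relay.create_executor` API:

import numpy as np
import tvm
import topi.testing
from tvm import relay

# Same shapes as test_conv1d_transpose_ncw_run in tests/python/relay/test_op_level2.py.
dshape = (1, 3, 18)    # (batch, in_channels, width), NCW layout
kshape = (3, 10, 3)    # (in_channels, out_channels, kernel_width), OIW layout
x = relay.var("x", shape=dshape)
w = relay.var("w", shape=kshape)
y = relay.nn.conv1d_transpose(x, w, channels=10, kernel_size=(3,),
                              strides=(2,), padding=(1,), output_padding=(2,))
func = relay.Function([x, w], y)

data = np.random.uniform(size=dshape).astype("float32")
kernel = np.random.uniform(size=kshape).astype("float32")
# Reference result from the NumPy/SciPy implementation added in this commit.
ref = topi.testing.conv1d_transpose_ncw_python(data, kernel, 2, 1)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data, kernel).asnumpy()
# output_padding only appends zeros on the right, so the overlapping part must match.
np.testing.assert_allclose(out[:, :, :ref.shape[2]], ref, rtol=1e-5)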
include/tvm/relay/attrs/nn.h
@@ -315,6 +315,64 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
   }
 };
 
+/*! \brief Attributes used in 1D transposed convolution operator */
+struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> output_padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") {
+    TVM_ATTR_FIELD(channels)
+        .set_default(NullValue<IndexExpr>())
+        .describe("The dimensionality of the output space"
+                  "i.e. the number of output channels in the convolution.");
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("The dimensions of the convolution window.")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
+        .describe("The strides of the convolution.");
+    TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0}))
+        .describe("Zero-padding added to one side of the output.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
+        .describe("Symmetric or asymmetric padding."
+                  "Single value: the input is implicitly zero-padded on both sides."
+                  "Two values: padding[0] is used for left input padding, "
+                  "padding[1] is used for right input padding,");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1}))
+        .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+        .describe("Controls the connections between inputs and outputs."
+                  "At groups=1, all inputs are convolved to all outputs."
+                  "At groups=2, the operation becomes equivalent to having two convolution"
+                  "layers side by side, each seeing half the input channels, and producing"
+                  "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(data_layout).set_default("NCW")
+        .describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc."
+                  "'N', 'C', 'W' stands for batch, channel, and width"
+                  "dimensions respectively. Convolution is applied on the"
+                  "'W' dimension.");
+    TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
+        .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
+                  "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+                  "dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout).set_default("")
+        .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc."
+                  "'N', 'C', 'W' stands for batch, channel, and width"
+                  "dimensions respectively. Default to be same as input layout.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
 /*! \brief Attributes for max pool operator */
 struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
   Array<IndexExpr> pool_size;
...
python/tvm/autotvm/task/relay_integration.py
@@ -128,6 +128,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
             tvm.relay.op.nn.dense: [topi.nn.dense],
             tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
             tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
+            tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
         }
 
     topi_funcs = []
...
python/tvm/autotvm/task/topi_integration.py
@@ -92,6 +92,7 @@ class TaskExtractEnv:
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
             topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
             topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
+            topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
         }
 
         self.topi_to_schedule = {
@@ -109,6 +110,7 @@ class TaskExtractEnv:
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
             topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
             topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
+            topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
         }
 
         # function reflection for tracing
@@ -125,6 +127,7 @@ class TaskExtractEnv:
            topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
            topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
            topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
+           topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
         }
 
         self.allow_duplicate = allow_duplicate
@@ -214,6 +217,15 @@ class TaskExtractEnv:
             s = topi.generic.schedule_conv2d_transpose_nchw([C])
             return s, [A, W, C]
 
+        @register("topi_nn_conv1d_transpose_ncw")
+        def _topi_nn_conv1d_transpose_ncw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.conv1d_transpose_ncw(*args, **kwargs)
+            s = topi.generic.schedule_conv1d_transpose_ncw([C])
+            return s, [A, W, C]
+
         @register("topi_nn_dense")
         def _topi_nn_dense(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
...
python/tvm/relay/_parser.py
@@ -141,6 +141,7 @@ FUNC_OPS = {
     "nn.softmax": op.nn.softmax,
     "reshape": op.reshape,
     "nn.conv2d_transpose": op.nn.conv2d_transpose,
+    "nn.conv1d_transpose": op.nn.conv1d_transpose,
     "concatenate": op.concatenate,
     "nn.dropout": op.nn.dropout_raw,
     "zeros": op.zeros,
...
python/tvm/relay/frontend/mxnet.py
@@ -207,29 +207,23 @@ def _mx_conv1d_transpose(inputs, attrs):
     if data_layout != "NCW":
         raise tvm.error.OpAttributeInvalid(
             'Only "NCW" data layout is supported for 1D Convolution')
-    data_layout = "NCHW"
     channel_axis = 1
-    kernel_layout = "OIHW"
+    kernel_layout = "OIW"
 
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
-    new_attrs["kernel_size"] = (1,) + attrs.get_int_tuple("kernel")
-    new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
-    new_attrs["output_padding"] = (0,) + attrs.get_int_tuple("adj", (0,))
-    new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
-    new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
+    new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
+    new_attrs["strides"] = attrs.get_int_tuple("stride", (1,))
+    new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0,))
+    new_attrs["padding"] = attrs.get_int_tuple("pad", (0,))
+    new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1,))
     new_attrs["groups"] = attrs.get_int("num_group", 1)
     new_attrs["data_layout"] = data_layout
     new_attrs["kernel_layout"] = kernel_layout
     use_bias = not attrs.get_bool("no_bias", True)
-    data = _op.expand_dims(inputs[0], axis=2)
-    kernel = _op.expand_dims(inputs[1], axis=2)
-    res = _op.nn.conv2d_transpose(data, kernel, **new_attrs)
+    res = _op.nn.conv1d_transpose(inputs[0], inputs[1], **new_attrs)
 
     if use_bias:
         assert len(inputs) == 3
         res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
-    res = _op.squeeze(res, axis=[2])
     return res
...
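With this change a 1-D MXNet Deconvolution maps directly onto `nn.conv1d_transpose` instead of being rewritten as expand_dims -> conv2d_transpose -> squeeze. A rough sketch of the frontend path, assuming MXNet is installed and that the Gluon-to-Relay converter of this TVM version accepts an initialized `Conv1DTranspose` block (`in_channels` is given explicitly so the parameters have concrete shapes; all names here are illustrative):

import mxnet as mx
from tvm import relay

block = mx.gluon.nn.Conv1DTranspose(channels=10, kernel_size=3, strides=2,
                                    padding=1, in_channels=3)
block.initialize()

# After this commit the printed module should contain nn.conv1d_transpose
# rather than an expand_dims/conv2d_transpose/squeeze chain.
mod, params = relay.frontend.from_mxnet(block, shape={"data": (1, 3, 18)})
print(mod)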
python/tvm/relay/op/nn/_nn.py
@@ -348,6 +348,37 @@ def legalize_conv2d_transpose(attrs, inputs, types):
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
+# conv1d_transpose
+@reg.register_compute("nn.conv1d_transpose")
+def compute_conv1d_transpose(attrs, inputs, out_dtype, target):
+    """Compute definition of conv1d_transpose"""
+    padding = get_const_tuple(attrs.padding)
+    strides = get_const_tuple(attrs.strides)
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    layout = attrs.data_layout
+    out_dtype = attrs.out_dtype
+    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                 else out_dtype)
+    assert layout == "NCW", "conv1d_transpose ncw only supported"
+    assert dilation == (1,), "conv1d_transpose dilation is not supported"
+    assert groups == 1, "conv1d_transpose groups == 1 only supported"
+    out = topi.nn.conv1d_transpose_ncw(inputs[0], inputs[1], strides, padding, out_dtype)
+    output_padding = get_const_tuple(attrs.output_padding)
+    out = topi.nn.pad(out, [0, 0, 0], [0, 0, output_padding[0]])
+    return [out]
+
+
+@reg.register_schedule("nn.conv1d_transpose")
+def schedule_conv1d_transpose(attrs, outs, target):
+    """Schedule definition of conv1d_transpose"""
+    with target:
+        return topi.generic.schedule_conv1d_transpose_ncw(outs)
+
+
+reg.register_pattern("nn.conv1d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # bias_add
 reg.register_schedule("nn.bias_add", schedule_injective)
 reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
...
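Note that `output_padding` is not handled inside the TOPI kernel: the compute callback above runs `topi.nn.conv1d_transpose_ncw` first and then appends the extra zeros with `topi.nn.pad`, on the right end of the width axis only. In NumPy terms the post-processing step is equivalent to this small sketch (shapes taken from the Relay test in this commit):

import numpy as np

out = np.ones((1, 10, 35), dtype="float32")   # shape produced by conv1d_transpose_ncw
output_padding = (2,)
# pad nothing on batch/channel axes, and only on the right of the width axis
padded = np.pad(out, ((0, 0), (0, 0), (0, output_padding[0])), mode="constant")
assert padded.shape == (1, 10, 37)            # matches the type relation's out_width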
python/tvm/relay/op/nn/nn.py
@@ -257,6 +257,72 @@ def conv2d_transpose(data,
                                   kernel_layout, out_layout, output_padding, out_dtype)
 
 
+def conv1d_transpose(data,
+                     weight,
+                     strides=(1,),
+                     padding=(0,),
+                     dilation=(1,),
+                     groups=1,
+                     channels=None,
+                     kernel_size=None,
+                     data_layout="NCW",
+                     kernel_layout="OIW",
+                     out_layout="",
+                     output_padding=(0,),
+                     out_dtype=""):
+    """One dimensional transposed convolution operator.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    strides : Tuple[int], optional
+        The strides of convolution.
+
+    padding : Tuple[int], optional
+        The padding of convolution on both sides of inputs.
+
+    dilation : Tuple[int], optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions of the convolution kernel.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output; by default out_layout is the same as data_layout.
+
+    output_padding : Tuple[int], optional
+        Additional zero-padding to be added to one side of the output.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv1d_transpose.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.conv1d_transpose(data, weight, strides, padding, dilation,
+                                  groups, channels, kernel_size, data_layout,
+                                  kernel_layout, out_layout, output_padding, out_dtype)
+
+
 def softmax(data, axis=-1):
     r"""Computes softmax.
...
src/relay/op/nn/convolution.cc
@@ -28,6 +28,7 @@
 #include <vector>
 
 #include "../../pass/alter_op_layout.h"
+#include "../op_common.h"
 #include "convolution.h"
 
 namespace tvm {
 
@@ -328,6 +329,162 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
 .add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
 
+// relay.nn.conv1d_transpose
+TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
+
+bool Conv1DTransposeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCW("NCW");
+  static const Layout kOIW("OIW");
+
+  const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW);
+  CHECK(trans_in_layout.defined())
+    << "Conv only support input layouts that are convertible from NCW."
+    << " But got " << in_layout;
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW);
+  CHECK(trans_kernel_layout.defined())
+    << "Conv only support kernel layouts that are convertible from OIW."
+    << " But got " << kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW);
+  CHECK(trans_out_layout.defined())
+    << "Conv only support output layouts that are convertible from NCW."
+    << " But got " << out_layout;
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
+
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 1);
+    CHECK_EQ(param->dilation.size(), 1);
+
+    Array<IndexExpr> wshape({dshape_ncw[1],
+            indexdiv(param->channels, param->groups),
+            param->kernel_size[0]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 1);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
+          << "Conv1D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv1D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
+  }
+  // dilation
+  IndexExpr pad_w;
+  GetPaddingWidth(param->padding, &pad_w);
+  Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
+  oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
+                 pad_w + param->output_padding[0]));
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+
+Expr MakeConv1DTranspose(Expr data,
+                         Expr weight,
+                         Array<IndexExpr> strides,
+                         Array<IndexExpr> padding,
+                         Array<IndexExpr> dilation,
+                         int groups,
+                         IndexExpr channels,
+                         Array<IndexExpr> kernel_size,
+                         std::string data_layout,
+                         std::string kernel_layout,
+                         std::string out_layout,
+                         Array<IndexExpr> output_padding,
+                         DataType out_dtype) {
+  auto attrs = make_node<Conv1DTransposeAttrs>();
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->output_padding = std::move(output_padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("nn.conv1d_transpose");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.conv1d_transpose")
+.set_body_typed(MakeConv1DTranspose);
+
+RELAY_REGISTER_OP("nn.conv1d_transpose")
+.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).
+
+The need for transposed convolutions generally arises
+from the desire to use a transformation going in the opposite direction
+of a normal convolution, i.e., from something that has the shape of the
+output of some convolution to something that has the shape of its input
+while maintaining a connectivity pattern that is compatible with
+said convolution.
+
+- **data**: This depends on the `layout` parameter. Input is 3D array of shape
+            (batch_size, in_channels, width) if `layout` is `NCW`.
+- **weight**: (in_channels, channels, kernel_size[0])
+- **bias**: (channels,)
+- **out**: This depends on the `layout` parameter. Output is 3D array of shape
+            (batch_size, channels, out_width) if `layout` is `NCW`.
+
+            out_width is calculated as::
+                out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<Conv1DTransposeAttrs>()
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(2)
+.add_type_rel("Conv1DTranspose", Conv1DTransposeRel);
+
 // relay.nn.contrib_conv2d_winograd_without_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
...
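The out_width formula in the operator description above is the one the type relation implements (via GetPaddingWidth, where a single padding value contributes pad_w = 2*padding[0]). Plugging in the shapes used by the unit tests confirms the expected output width, as a quick check:

# width=18, strides=(2,), padding=(1,), kernel_size=(3,), output_padding=(2,)
width, stride, pad, ksize, out_pad = 18, 2, 1, 3, 2
out_width = (width - 1) * stride - 2 * pad + ksize + out_pad
assert out_width == 37   # matches oshape = (1, 10, 37) in test_op_level2.py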
src/relay/op/op_common.h
@@ -150,6 +150,18 @@ class OpMatch {
   MatchFunc default_;
 };
 
+/*! \brief A utility function to get padding width from a 1 or 2 ints tuple. */
+inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
+  if (padding.size() == 1) {
+    *pad_w = padding[0] * 2;
+  } else if (padding.size() == 2) {
+    *pad_w = padding[0] + padding[1];
+  } else {
+    CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found "
+        << padding.size();
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
...
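GetPaddingWidth collapses the padding attribute into a single total width that the type relation subtracts from the output extent. The same logic, mirrored in Python purely for illustration:

def get_padding_width(padding):
    """Mirror of GetPaddingWidth in op_common.h (illustrative only)."""
    if len(padding) == 1:
        return padding[0] * 2           # symmetric: the same pad on both sides
    if len(padding) == 2:
        return padding[0] + padding[1]  # asymmetric: left + right
    raise ValueError("expected a padding tuple of size 1 or 2")

assert get_padding_width((1,)) == 2
assert get_padding_width((0, 3)) == 3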
tests/python/relay/test_op_level2.py
@@ -413,6 +413,25 @@ def test_conv2d_transpose_nhwc_run():
     c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
     d_np = np.zeros(shape=oshape_nhwc)
     d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np
+
+
+def test_conv1d_transpose_ncw_run():
+    dshape = (1, 3, 18)
+    kshape = (3, 10, 3)
+    oshape = (1, 10, 37)
+    x = relay.var("x", shape=dshape)
+    w = relay.var("w")
+    y = relay.nn.conv1d_transpose(x, w,
+                                  channels=10, kernel_size=(3,), strides=(2,),
+                                  padding=(1,), output_padding=(2,))
+    func = relay.Function([x, w], y)
+    dtype = "float32"
+    data = np.random.uniform(size=dshape).astype(dtype)
+    kernel = np.random.uniform(size=kshape).astype(dtype)
+    c_np = topi.testing.conv1d_transpose_ncw_python(data, kernel, 2, 1)
+    d_np = np.zeros(shape=oshape)
+    d_np[:,:,0:c_np.shape[2]] = c_np
     ref_res = d_np
 
     for target, ctx in ctx_list():
 
@@ -893,6 +912,7 @@ if __name__ == "__main__":
     test_conv2d_transpose_infer_type()
     test_conv2d_transpose_nchw_run()
     test_conv2d_transpose_nhwc_run()
+    test_conv1d_transpose_ncw_run()
     test_conv2d_run()
     test_conv2d_winograd()
     test_conv3d_run()
...
topi/python/topi/cuda/__init__.py
@@ -20,7 +20,7 @@
 from __future__ import absolute_import as _abs
 
 from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \
-    group_conv2d_nchw, dense
+    group_conv2d_nchw, dense, conv1d_transpose_ncw
 from . import conv3d
 from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
...
topi/python/topi/cuda/conv1d_transpose_ncw.py
new file (mode 100644)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Conv1d transpose template for cuda backend"""

import tvm
from tvm import autotvm
from .. import nn, generic
from ..util import get_const_tuple, traverse_inline


@autotvm.task.register_topi_compute(nn.conv1d_transpose_ncw, ['cuda', 'gpu'], "direct")
def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype):
    """Transposed 1D convolution ncw forward operator.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template
    Input : tvm.Tensor
        3-D with shape [batch, in_channel, inp_width]
    Filter : tvm.Tensor
        3-D with shape [in_channel, num_filter, kernel_size]
    stride : tuple of one int
        The spatial stride along width
    padding : int, tuple, or string
        int: padding size
        tuple of 2 ints: (pad_left, pad_right) for left and right padding
        string: ['VALID', 'SAME']
    out_dtype: str
        The output type. This is used in mixed precision

    Returns
    -------
    Output : tvm.Tensor
        3-D with shape [batch, out_channel, out_width]
    """
    if isinstance(stride, (tuple, list)):
        stride = stride[0]
    cfg.stride = stride
    batch, inp_channels, inp_width = get_const_tuple(data.shape)
    _, out_channels, kernel_size = get_const_tuple(kernel.shape)
    pad_left, pad_right = nn.get_pad_tuple1d(padding, kernel_size)
    out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right
    pad_left = kernel_size - 1 - pad_left
    pad_right = kernel_size - 1 - pad_right
    dilated_width = stride * (inp_width - 1) + 1
    data = tvm.compute(
        (batch, inp_channels, pad_left + dilated_width + pad_right),
        lambda n, c, x: tvm.if_then_else(
            tvm.all(x >= pad_left,
                    x < pad_left + dilated_width,
                    tvm.indexmod(x - pad_left, stride).equal(0)),
            data[n, c, tvm.indexdiv(x - pad_left, stride)],
            tvm.const(0., "float32")),
        name='data_pad')

    dc = tvm.reduce_axis((0, inp_channels), name='dc')
    dw = tvm.reduce_axis((0, kernel_size), name='dw')
    data_out = tvm.compute(
        (batch, out_channels, out_width),
        lambda b, c, w: tvm.sum(
            data[b, dc, w + dw].astype(out_dtype) *
            kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype),
            axis=[dc, dw]), tag="conv1d_transpose_ncw")

    return data_out


@autotvm.task.register_topi_schedule(generic.schedule_conv1d_transpose_ncw,
                                     ['cuda', 'gpu'], 'direct')
def schedule_conv1d_transpose_ncw_cuda(cfg, outs):
    """TOPI Schedule callback for conv1d_transpose operator.

    Parameters
    ----------
    cfg: ConfigEntity
        The parameters for this template

    outs: Array of Tensor
        The computation graph description of conv1d transpose
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for conv1d transpose.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == 'conv1d_transpose_ncw':
            pad_data = op.input_tensors[0]
            kernel = op.input_tensors[1]
            conv = op.output(0)

            ##### space definition begin #####
            n, f, x = s[conv].op.axis
            rc = s[conv].op.reduce_axis[0]
            cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
            cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
            cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
            cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
            cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])

            target = tvm.target.current_target()
            if target.target_name in ['nvptx', 'rocm']:
                cfg.define_knob("unroll_explicit", [1])
            else:
                cfg.define_knob("unroll_explicit", [0, 1])

            ##### space definition end #####

            if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
                s[kernel].compute_inline()

            if conv.op in s.outputs:
                output = conv
                OL = s.cache_write(conv, 'local')
            else:
                output = s.outputs[0].output(0)
                s[conv].set_scope('local')
                OL = conv

            # create cache stage
            s[pad_data].set_scope('shared')
            AA = pad_data
            WW = s.cache_read(kernel, 'shared', [OL])

            # tile and bind spatial axes
            n, f, x = s[output].op.axis
            kernel_scope, n = s[output].split(n, nparts=1)
            bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
            bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
            bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

            s[output].reorder(bn, bf, bx, vn, vf, vx, tn, tf, tx, ni, fi, xi)
            s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
            s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
            s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[output].bind(vn, tvm.thread_axis("vthread"))
            s[output].bind(vf, tvm.thread_axis("vthread"))
            s[output].bind(vx, tvm.thread_axis("vthread"))

            s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
            s[OL].compute_at(s[output], tx)
            # number of threads
            n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
            n_tx = cfg["tile_x"].size[2]

            # tile reduction axes
            n, f, x = s[OL].op.axis
            rc, rx = s[OL].op.reduce_axis
            rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
            s[OL].reorder(rco, rcm, rx, rci, n, f, x)

            s[AA].compute_at(s[OL], rx)
            s[WW].compute_at(s[OL], rx)

            # cooperative fetching
            for load in [AA, WW]:
                n, f, x = s[load].op.axis
                fused = s[load].fuse(f, x)
                tz, fused = s[load].split(fused, nparts=n_tz)
                tx, fused = s[load].split(fused, nparts=n_tx)
                s[load].bind(tz, tvm.thread_axis("threadIdx.y"))
                s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

            s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
            s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    traverse_inline(s, outs[0].op, _callback)

    return s
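Because the compute and schedule are registered as an AutoTVM template ("direct" on the 'cuda'/'gpu' targets), the new op can be tuned the same way as the 2-D templates. A rough sketch of the tuning flow, assuming the 0.6-era AutoTVM API, a CUDA-capable device, and a placeholder log file name:

import tvm
from tvm import autotvm, relay

# A tiny Relay program containing the new op (same shapes as the tests).
x = relay.var("x", shape=(1, 3, 18))
w = relay.var("w", shape=(3, 10, 3))
y = relay.nn.conv1d_transpose(x, w, channels=10, kernel_size=(3,), strides=(2,))
func = relay.Function([x, w], y)

# relay_integration.py (changed above) maps nn.conv1d_transpose to
# topi.nn.conv1d_transpose_ncw, so task extraction can find this template.
tasks = autotvm.task.extract_from_program(func, params={}, target="cuda",
                                          ops=(relay.op.nn.conv1d_transpose,))
for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=20,
               measure_option=autotvm.measure_option(
                   builder=autotvm.LocalBuilder(),
                   runner=autotvm.LocalRunner(number=5)),
               callbacks=[autotvm.callback.log_to_file("conv1d_transpose.log")])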
topi/python/topi/generic/nn.py
@@ -262,6 +262,24 @@ def schedule_conv2d_transpose_nchw(outs):
 @tvm.target.generic_func
+def schedule_conv1d_transpose_ncw(outs):
+    """Schedule for conv1d_transpose_ncw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv1d_transpose_ncw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
 def schedule_depthwise_conv2d_nchw(outs):
     """Schedule for depthwise_conv2d_nchw
...
topi/python/topi/nn/__init__.py
@@ -31,6 +31,7 @@ from .mapping import *
 from .pooling import *
 from .softmax import *
 from .conv2d_transpose import *
+from .conv1d_transpose import *
 from .bnn import *
 from .upsampling import *
 from .local_response_norm import *
...
topi/python/topi/nn/conv1d_transpose.py
new file (mode 100644)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Transposed 1D convolution operators (sometimes called Deconvolution)."""
from __future__ import absolute_import as _abs
import tvm
from .dilate import dilate
from .pad import pad
from ..util import simplify
from .util import get_pad_tuple1d


@tvm.target.generic_func
def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype):
    """Transposed 1D convolution ncw forward operator.

    Parameters
    ----------
    data : tvm.Tensor
        3-D with shape [batch, in_channel, in_width]

    kernel : tvm.Tensor
        3-D with shape [in_channel, num_filter, filter_width]

    stride : ints
        The spatial stride along width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    out_dtype : str
        The output data type. This is used for mixed precision.

    Returns
    -------
    output : tvm.Tensor
        3-D with shape [batch, out_channel, out_width]
    """
    # dilate and pad
    if isinstance(stride, (tuple, list)):
        stride = stride[0]
    batch, channels_in, data_width = data.shape
    _, channels_out, kernel_width = kernel.shape
    channels_out = simplify(channels_out)
    data = dilate(data, [1, 1, stride], name='data_dilate')
    pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
    pad_left = kernel_width - 1 - pad_left
    pad_right = kernel_width - 1 - pad_right
    data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name='data_pad')

    # transpose kernel, switch kernel layout to IOW
    kernel = tvm.compute((channels_out, channels_in, kernel_width), \
                        lambda o, i, w: kernel[i][o][kernel_width-1-w],\
                        name='kernel')

    # convolution
    _, _, data_width = data.shape
    out_w = simplify(data_width - kernel_width + 1)
    dc = tvm.reduce_axis((0, channels_in), name='dc')
    dw = tvm.reduce_axis((0, kernel_width), name='dw')
    output = tvm.compute(
        (batch, channels_out, out_w),
        lambda b, c, w: tvm.sum(
            data[b, dc, w+dw].astype(out_dtype) *
            kernel[c, dc, dw].astype(out_dtype),
            axis=[dc, dw]), tag="conv1d_transpose_ncw")

    return output
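The implementation realizes the transposed convolution as dilate, then pad, then an ordinary convolution with an IOW-transposed, width-flipped kernel. The shape bookkeeping for a small case (in_width=4, stride=2, kernel_width=3, no padding), written out as a sketch:

in_width, stride, kernel_width, pad_left, pad_right = 4, 2, 3, 0, 0

dilated = stride * (in_width - 1) + 1                      # 7: stride-1 zeros between samples
padded = dilated + (kernel_width - 1 - pad_left) \
                 + (kernel_width - 1 - pad_right)          # 11
out_w = padded - kernel_width + 1                          # 9

# agrees with the usual transposed-convolution output width
assert out_w == (in_width - 1) * stride - pad_left - pad_right + kernel_width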
topi/python/topi/nn/util.py
@@ -172,3 +172,42 @@ def get_pad_tuple3d(padding, kernel):
     pad_left = (pad_w + 1) // 2
     pad_front = (pad_d + 1) // 2
     return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
+
+
+def get_pad_tuple1d(padding, kernel):
+    """Common code to get the pad option
+
+    Parameters
+    ----------
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    kernel : tuple of int
+        Conv kernel size
+
+    Returns
+    -------
+    pad_left : int
+        Padding size on left
+
+    pad_right : int
+        Padding size on right.
+    """
+    # compute the padding size
+    if isinstance(padding, (tuple, list)):
+        if len(padding) == 1:
+            pad_w = padding[0] * 2
+        elif len(padding) == 2:
+            return padding[0], padding[1]
+        else:
+            raise ValueError("Size of padding can only be 2 or 4")
+    elif isinstance(padding, int):
+        pad_w = padding * 2
+    elif padding == "VALID":
+        pad_w = 0
+    elif padding == "SAME":
+        pad_w = kernel[0] - 1
+    else:
+        raise ValueError("Unknown padding option %s" % padding)
+    pad_left = (pad_w + 1) // 2
+    return pad_left, pad_w - pad_left
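For reference, how the helper resolves the different padding options (traced from the code above, with a kernel width of 5):

from topi.nn.util import get_pad_tuple1d

assert get_pad_tuple1d(1, (5,)) == (1, 1)        # single int: symmetric
assert get_pad_tuple1d((0, 3), (5,)) == (0, 3)   # explicit (left, right)
assert get_pad_tuple1d("VALID", (5,)) == (0, 0)  # no padding
assert get_pad_tuple1d("SAME", (5,)) == (2, 2)   # pad_w = kernel[0] - 1 = 4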
topi/python/topi/testing/__init__.py
@@ -26,6 +26,7 @@ from .conv2d_nchw_python import conv2d_nchw_python
 from .conv2d_nhwc_python import conv2d_nhwc_python
 from .conv3d_ncdhw_python import conv3d_ncdhw_python
 from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
+from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
 from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
...
topi/python/topi/testing/conv1d_transpose_ncw_python.py
new file (mode 100644)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable
"""Transposed 1D convolution in python"""
import numpy as np
import scipy
import topi
from topi.nn.util import get_pad_tuple1d


def conv1d_transpose_ncw_python(a_np, w_np, stride, padding):
    """Transposed 1D convolution operator in NCW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        3-D with shape [batch, in_channel, in_width]

    w_np : numpy.ndarray
        3-D with shape [in_channel, num_filter, filter_width]

    stride : int or a list/tuple of one int
        Stride size, or [stride_width]

    padding : int, tuple, or str
        Single int for padding size, or
        tuple of 2 ints for left and right padding, or
        ['VALID', 'SAME']

    Returns
    -------
    b_np : np.ndarray
        3-D with shape [batch, out_channel, out_width]
    """
    batch, in_c, in_w = a_np.shape
    _, out_c, filter_w = w_np.shape
    if isinstance(stride, int):
        stride_w = stride
    else:
        stride_w = stride[0]
    fpad_left, fpad_right = get_pad_tuple1d(padding, filter_w)
    # dilate stage
    dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_w])
    # padding stage
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = filter_w - 1 - fpad_right
    padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_left+bpad_right))
    padded_a_np[:, :, bpad_left:dilated_a_np.shape[2]+bpad_left] = dilated_a_np
    # convolution stage
    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    b_np = np.zeros((batch, out_c, out_w))
    for n in range(batch):
        for f in range(out_c):
            for c in range(in_c):
                out = scipy.signal.convolve(padded_a_np[n, c], w_np[c, f], mode='valid')
                b_np[n, f] += out
    return b_np
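A quick sanity check of the reference implementation: with a unit impulse, stride 2 and no padding, the transposed convolution simply places an (unflipped) copy of the kernel at the impulse position, and the output width is (in_width - 1) * stride + filter_w. For illustration:

import numpy as np
import topi.testing

a = np.zeros((1, 1, 4))
a[0, 0, 0] = 1.0                      # unit impulse at position 0
w = np.array([[[1.0, 2.0, 3.0]]])     # (in_channel, num_filter, filter_width)

b = topi.testing.conv1d_transpose_ncw_python(a, w, stride=2, padding=0)
assert b.shape == (1, 1, 9)           # (4 - 1) * 2 + 3
np.testing.assert_allclose(b[0, 0, :3], [1.0, 2.0, 3.0])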
topi/tests/python/test_topi_conv1d_transpose_ncw.py
new file (mode 100644)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for transposed convolution."""
import numpy as np
import itertools
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend


def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
    in_width = in_size
    A = tvm.placeholder((batch, in_channel, in_width), name='A')
    W = tvm.placeholder((in_channel, num_filter, kernel), name='W')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    dtype = A.dtype

    @memoize("topi.tests.test_topi_conv1d_transpose.verify_conv1d_transpose_ncw")
    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding)
        c_np = np.maximum(b_np, 0)
        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        with tvm.target.create(device):
            B = topi.nn.conv1d_transpose_ncw(A, W, stride, padding, A.dtype)
            C = topi.nn.relu(B)
            s1 = topi.generic.schedule_conv1d_transpose_ncw([B])
            s2 = topi.generic.schedule_conv1d_transpose_ncw([C])
        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)

        func1 = tvm.build(s1, [A, W, B], device)
        func2 = tvm.build(s2, [A, W, C], device)
        func1(a, w, b)
        func2(a, w, c)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    for device in get_all_backend():
        check_device(device)


def test_conv1d_transpose_ncw():
    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0)
    verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2)
    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1)
    verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 0)
    verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 1, 0)
    verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 2, 1)
    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 1, 256)
    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256)
    verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256)
    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3))
    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3))
    verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3))


if __name__ == "__main__":
    test_conv1d_transpose_ncw()