wenyuanbo / tic

Commit 5c410037, authored Jul 25, 2019 by Wuwei Lin, committed by Tianqi Chen on Jul 24, 2019
[TOPI][Relay] max_pool2d & avg_pool2d gradient (#3601)
Parent: 440df0aa
Showing 15 changed files with 529 additions and 19 deletions (+529 -19):

  python/tvm/relay/op/_tensor_grad.py            +18   -0
  python/tvm/relay/op/nn/_nn.py                  +22   -0
  python/tvm/relay/op/nn/nn.py                   +82   -0
  python/tvm/relay/op/op_attrs.py                +10   -0
  src/relay/op/nn/pooling.cc                     +158  -2
  topi/include/topi/detail/ravel_unravel.h       +1    -1
  topi/include/topi/nn/pooling.h                 +0    -0
  topi/include/topi/reduction.h                  +19   -15
  topi/include/topi/transform.h                  +2    -1
  topi/python/topi/generic/nn.py                 +13   -0
  topi/python/topi/nn/pooling.py                 +62   -0
  topi/python/topi/testing/__init__.py           +1    -0
  topi/python/topi/testing/pool_grad_python.py   +65   -0
  topi/src/topi.cc                               +7    -0
  topi/tests/python/test_topi_pooling.py         +69   -0
python/tvm/relay/op/_tensor_grad.py
@@ -22,6 +22,7 @@ from .op import register_gradient
from .transform import collapse_sum_like, broadcast_to_like, where
from .tensor import exp, negative, power, less
from .tensor import zeros_like, ones_like
from . import nn as _nn


@register_gradient("log")

@@ -146,3 +147,20 @@ def clip_grad(orig, grad):
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, a_mins), zeros,
                  where(less(a_maxs, x), zeros, ones * grad))]


@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
    attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode)
    return [pool_grad]


@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
    attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode,
                                    count_include_pad=attrs.count_include_pad)
    return [pool_grad]
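The two registrations above only dispatch to the new nn.max_pool2d_grad / nn.avg_pool2d_grad ops; the gradient semantics themselves live in TOPI. As a rough mental model, the following numpy sketch (not part of the patch) shows what the max-pool gradient does for a single window: the incoming gradient is routed entirely to the element that produced the window's maximum.

# Illustrative numpy sketch (not part of this commit): gradient of one 2x2
# max-pooling window. Only the argmax position receives the output gradient.
import numpy as np

window = np.array([[1.0, 5.0],
                   [3.0, 2.0]])   # one pooling window of the input
out_grad = 7.0                    # gradient arriving at the pooled output
grad_in = np.zeros_like(window)
grad_in[np.unravel_index(np.argmax(window), window.shape)] = out_grad
print(grad_in)                    # [[0. 7.]
                                  #  [0. 0.]]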
python/tvm/relay/op/nn/_nn.py
@@ -255,6 +255,28 @@ def schedule_avg_pool2d(attrs, outs, target):
reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
    """Schedule definition of max_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d_grad
@reg.register_schedule("nn.avg_pool2d_grad")
def schedule_avg_pool2d_grad(attrs, outs, target):
    """Schedule definition of avg_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_max_pool2d
@reg.register_schedule("nn.global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
python/tvm/relay/op/nn/nn.py
@@ -327,6 +327,88 @@ def avg_pool2d(data,
    return _make.avg_pool2d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)


def max_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
                    strides=(1, 1),
                    padding=(0, 0),
                    layout="NCHW",
                    ceil_mode=False):
    r"""Gradient of 2D maximum pooling operator.

    This operator takes out_grad and data as input and calculates gradient of max_pool2d.

    Parameters
    ----------
    out_grad : tvm.relay.Expr
        The output gradient

    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding,
                                 layout, ceil_mode)


def avg_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
                    strides=(1, 1),
                    padding=(0, 0),
                    layout="NCHW",
                    ceil_mode=False,
                    count_include_pad=False):
    r"""Gradient of 2D average pooling operator.

    This operator takes out_grad and data as input and calculates gradient of avg_pool2d.

    Parameters
    ----------
    out_grad : tvm.relay.Expr
        The output gradient

    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding to compute the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.avg_pool2d_grad(out_grad, data, pool_size, strides, padding,
                                 layout, ceil_mode, count_include_pad)


def global_max_pool2d(data,
                      layout="NCHW"):
    r"""2D global maximum pooling operator.
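For orientation, building one of these gradient calls from Python looks roughly like the sketch below. It assumes the Relay API of this commit's era; the variable names and shapes are made up for illustration, and the (1, 16, 16, 16) out_grad shape corresponds to a 2x2, stride-2 pool over a 32x32 input.

# Minimal sketch (assumes tvm/relay as of this commit; names/shapes illustrative).
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 16, 32, 32))           # input of the forward pool
out_grad = relay.var("out_grad", shape=(1, 16, 16, 16))   # gradient w.r.t. pooled output
grad = relay.nn.max_pool2d_grad(out_grad, data,
                                pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([data, out_grad], grad)
print(func)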
python/tvm/relay/op/op_attrs.py
@@ -251,3 +251,13 @@ class YoloReorgAttrs(Attrs):
@register_relay_attr_node
class ProposalAttrs(Attrs):
    """Attributes used in proposal operators"""


@register_relay_attr_node
class MaxPool2DAttrs(Attrs):
    """Attributes used in max_pool2d operators"""


@register_relay_attr_node
class AvgPool2DAttrs(Attrs):
    """Attributes used in avg_pool2d operators"""
src/relay/op/nn/pooling.cc
@@ -6,9 +6,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY

@@ -557,5 +557,161 @@ RELAY_REGISTER_OP("contrib.adaptive_max_pool2d")
                               Pool2DInferCorrectLayout<AdaptivePool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);


bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[1].as<TensorTypeNode>();

  if (data == nullptr) return false;

  // assign output type
  reporter->Assign(types[2], types[1]);
  return true;
}

template <typename AttrType, topi::nn::PoolType mode>
Array<Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<Tensor>& inputs,
                                const Type& out_type, const Target& target) {
  static const Layout kNCHW("NCHW");
  const auto* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  CHECK_EQ(inputs.size(), 2);
  auto pool_size = param->pool_size;
  auto strides = param->strides;
  auto padding = param->padding;
  auto ceil_mode = param->ceil_mode;

  Layout layout(param->layout);
  CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
      << "pool2d_grad currently only supports layouts that are convertible from NCHW";
  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
      << "pool2d_grad does not support input split on height";
  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
      << "pool2d_grad does not support input split on width";

  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
      << "Pool2DGrad only support 4-D output gradient (e.g., NCHW)"
      << " or 5-D output gradient (last dimension is a split of channel)";
  CHECK(inputs[1].ndim() == 4U || inputs[1].ndim() == 5U)
      << "Pool2DGrad only support 4-D input (e.g., NCHW)"
      << " or 5-D input (last dimension is a split of channel)";

  if (param->padding.size() == 1) {
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
  } else if (param->padding.size() == 2) {
    padding.push_back(padding[0]);
    padding.push_back(padding[1]);
  }
  if (mode == topi::nn::kAvgPool) {
    bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
    return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
                                             mode, ceil_mode, layout.name(),
                                             count_include_pad)};
  } else {
    return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
                                             mode, ceil_mode, layout.name())};
  }
}


// MaxPool2DGrad
Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
                       Array<IndexExpr> strides, Array<IndexExpr> padding,
                       std::string layout, bool ceil_mode) {
  auto attrs = make_node<MaxPool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->layout = std::move(layout);
  attrs->ceil_mode = ceil_mode;
  static const Op& op = Op::Get("nn.max_pool2d_grad");
  return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad);

RELAY_REGISTER_OP("nn.max_pool2d_grad")
.describe(R"code(Gradient of max pooling operation for two dimensional data.

- **out_grad**: This depends on the `layout` parameter. Output gradient is 4D array of
                shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
                out_height and out_width are the output size of the pooling operation,
                which are calculated as::

                    out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
                    out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1

                where padding will be an expanded array based on number of values passed as::

                    one int : all sides same padding used.
                    two int : bottom, right use same as top and left.
                    four int: padding width in the order of (top, left, bottom, right).

                When `ceil_mode` is `True`, ceil will be used instead of floor in this
                equation.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.
- **grad**: This depends on the `layout` parameter. Grad is 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);


// AvgPool2DGrad
Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
                       Array<IndexExpr> strides, Array<IndexExpr> padding,
                       std::string layout, bool ceil_mode, bool count_include_pad) {
  auto attrs = make_node<AvgPool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->layout = std::move(layout);
  attrs->ceil_mode = ceil_mode;
  attrs->count_include_pad = count_include_pad;
  static const Op& op = Op::Get("nn.avg_pool2d_grad");
  return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad);

RELAY_REGISTER_OP("nn.avg_pool2d_grad")
.describe(R"code(Gradient of average pooling operation for two dimensional data.

- **out_grad**: This depends on the `layout` parameter. Output gradient is 4D array of
                shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
                out_height and out_width are the output size of the pooling operation,
                which are calculated as::

                    out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
                    out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1

                where padding will be an expanded array based on number of values passed as::

                    one int : all sides same padding used.
                    two int : bottom, right use same as top and left.
                    four int: padding width in the order of (top, left, bottom, right).

                When `ceil_mode` is `True`, ceil will be used instead of floor in this
                equation.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.
- **grad**: This depends on the `layout` parameter. Grad is 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);

}  // namespace relay
}  // namespace tvm
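The output-size rule and padding expansion quoted in the operator descriptions above can be checked with a few lines of plain Python. This sketch is only an illustration of the documented formula, not code from the patch.

# Plain-Python illustration of the documented pooling output-size rule.
import math

def pool_out_shape(height, width, pool_size, strides, padding, ceil_mode=False):
    if len(padding) == 1:            # one int: same padding on all sides
        padding = list(padding) * 4
    elif len(padding) == 2:          # two ints: bottom/right reuse top/left
        padding = [padding[0], padding[1], padding[0], padding[1]]
    rounding = math.ceil if ceil_mode else math.floor
    out_height = rounding((height + padding[0] + padding[2] - pool_size[0]) / strides[0]) + 1
    out_width = rounding((width + padding[1] + padding[3] - pool_size[1]) / strides[1]) + 1
    return out_height, out_width

print(pool_out_shape(32, 32, (2, 2), (2, 2), [0]))           # (16, 16)
print(pool_out_shape(31, 31, (3, 3), (3, 3), [1, 2, 1, 2]))  # (11, 11)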
topi/include/topi/detail/ravel_unravel.h
@@ -42,7 +42,7 @@ using namespace tvm;
 *
 * \return The index after flattening
 */
-inline Expr RavelIndex(Array<Var> indices, Array<Expr> shape) {
+inline Expr RavelIndex(Array<Expr> indices, Array<Expr> shape) {
  CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
  CHECK_GT(indices.size(), 0) << "indices must not be empty";
  Expr idx;
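The change above only widens RavelIndex to accept Array<Expr> instead of Array<Var>; the flattening it performs is the ordinary row-major index linearization, which can be stated with numpy for reference (illustrative only):

# Row-major index flattening, the same operation RavelIndex/UnravelIndex perform.
import numpy as np

shape = (2, 3, 4)
indices = (1, 2, 3)
flat = np.ravel_multi_index(indices, shape)   # 1*12 + 2*4 + 3 = 23
print(flat, np.unravel_index(flat, shape))    # 23 (1, 2, 3)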
topi/include/topi/nn/pooling.h
(diff collapsed; not shown)
topi/include/topi/reduction.h
@@ -224,7 +224,7 @@ inline Tensor CommReduceIdx(const Tensor& data,
  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
                 (const Array<Var>& indices) {
    Array<Expr> eval_range;
-    Array<Var> eval_indices;
+    Array<Expr> eval_indices;
    int arg_counter = 0;
    int red_counter = 0;

@@ -466,6 +466,22 @@ inline Tensor argmin(const Tensor& data,
  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
}

inline FCommReduce MakeArgmaxReducer() {
  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
    Array<Expr> result;
    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
    return result;
  };
  auto fidentity = [](std::vector<Type> types) {
    Array<Expr> result;
    result.push_back(tvm::make_const(types[0], -1));  // idx
    result.push_back(types[1].min());  // val
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "argmax");
}

/*!
 * \brief Creates an operation that finds the indices of the maximum
 * values over a given axis.

@@ -484,20 +500,8 @@ inline Tensor argmax(const Tensor& data,
                     const Array<Integer>& axis,
                     bool keepdims = false,
                     bool atleast1d = false) {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
-    Array<Expr> result;
-    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
-    return result;
-  };
-  auto fidentity = [](std::vector<Type> types) {
-    Array<Expr> result;
-    result.push_back(tvm::make_const(types[0], -1));  // idx
-    result.push_back(types[1].min());  // val
-    return result;
-  };
-  auto func = MakeCommReducer(fcombine, fidentity, "argmax");
-  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
+  auto reducer = MakeArgmaxReducer();
+  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

/*!
topi/include/topi/transform.h
@@ -210,7 +210,8 @@ inline Tensor reshape(const Tensor& x,
  auto x_shape = x->shape;
  return compute(
    newshape, [&](const Array<Var>& indices) {
-      return x(UnravelIndex(RavelIndex(indices, newshape), x_shape));
+      return x(UnravelIndex(
+          RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape), x_shape));
    }, name, tag);
}
topi/python/topi/generic/nn.py
@@ -421,6 +421,19 @@ def schedule_pool(outs, layout):
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_pool_grad(outs):
    """Schedule for pool_grad

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool
        in the format of an array of tensors.
    """
    return _default_schedule(outs, False)


@tvm.target.override_native_generic_func("schedule_adaptive_pool")
def schedule_adaptive_pool(outs):
    """Schedule for adaptive pool
topi/python/topi/nn/pooling.py
@@ -114,6 +114,68 @@ def pool(data,
    return cpp.nn.pool(data, kernel, stride, padding,
                       POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)


def pool_grad(grads, data, kernel, stride, padding, pool_type,
              ceil_mode=False, layout="NCHW", count_include_pad=True):
    """Gradient of pooling on height and width dimension of data.
       It decides the height and width dimension according to the layout string,
       in which 'W' and 'H' means width and height respectively.
       Width and height dimension cannot be split.
       For example, NCHW, NCHW16c, etc. are valid for pool,
       while NCHW16w, NCHW16h are not.
       See parameter `layout` for more information of the layout string convention.

    Parameters
    ----------
    grads : tvm.Tensor
        n-D with shape of layout

    data : tvm.Tensor
        n-D with shape of layout

    kernel : list/tuple of two ints
        Kernel size, [kernel_height, kernel_width]

    stride : list/tuple of two ints
        Stride size, [stride_height, stride_width]

    padding : list/tuple of four ints
        Pad size, [pad_top, pad_left, pad_bottom, pad_right]

    pool_type : str
        Pool type, 'max' or 'avg'

    ceil_mode : bool
        Whether to use ceil when calculating output size.

    layout: string
        Layout of the input data.
        The layout is supposed to be composed of upper cases, lower cases and numbers,
        where upper case indicates a dimension and
        the corresponding lower case with factor size indicates the split dimension.
        For example, NCHW16c can describe a 5-D tensor of
        [batch_size, channel, height, width, channel_block],
        in which channel_block=16 is a split of dimension channel.

    count_include_pad: bool
        Whether include padding in the calculation when pool_type is 'avg'

    Returns
    -------
    output : tvm.Tensor
        n-D in the same layout
    """
    return cpp.nn.pool_grad(grads, data, kernel, stride, padding,
                            POOL_TYPE_CODE[pool_type], ceil_mode, layout,
                            count_include_pad)


def adaptive_pool(data,
                  output_size,
                  pool_type,
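The layout-string convention described in the docstring (an upper-case letter names a dimension; a lower-case letter with a leading factor marks a split of the matching dimension) can be made concrete with a toy parser. This is purely illustrative; TVM has its own Layout machinery and does not parse layouts this way.

# Toy parser for layout strings such as "NCHW16c" (illustrative only).
import re

def parse_layout(layout):
    """Return (axis, split_factor) pairs for a layout string."""
    return [(axis, int(factor) if factor else None)
            for factor, axis in re.findall(r"(\d*)([A-Za-z])", layout)]

print(parse_layout("NCHW"))     # [('N', None), ('C', None), ('H', None), ('W', None)]
print(parse_layout("NCHW16c"))  # [..., ('W', None), ('c', 16)] -- 'c' splits 'C' by 16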
topi/python/topi/testing/__init__.py
@@ -24,3 +24,4 @@ from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .pool_grad_python import pool_grad_nchw
topi/python/topi/testing/pool_grad_python.py
new file (0 → 100644)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Gradient of pooling in python"""
import numpy as np


def pool_grad_nchw(a_np, out_grad_np, pool_size, strides, padding, pool_type,
                   ceil_mode, count_include_pad=True):
    """pool_grad for NCHW layout in python"""
    dtype = a_np.dtype
    n, ic, ih, iw = a_np.shape
    kh, kw = pool_size
    sh, sw = strides
    pt, pl, pb, pr = padding

    pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
    no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl)))
    pad_np[np.ix_(*no_zero)] = a_np
    _, oc, oh, ow = out_grad_np.shape
    pool_grad_np = np.zeros(shape=a_np.shape)
    pad_pool_grad_np = np.zeros(shape=pad_np.shape)

    if pool_type == 'avg':
        for i in range(oh):
            for j in range(ow):
                if count_include_pad:
                    shape = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw].shape
                    # this can be different from kh*kw if input size cannot divide stride
                    pad_count = shape[2] * shape[3]
                else:
                    pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3))
                    # take the first element, as they are the same across batch and channel
                    pad_count = pad_count.ravel()[0]
                pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] += \
                    out_grad_np[:, :, i, j].reshape(n, ic, 1, 1) / np.maximum(pad_count, 1)
    elif pool_type == 'max':
        for i in range(oh):
            for j in range(ow):
                a_patch = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw]
                a_patch = np.reshape(a_patch, (n, ic, -1))
                max_indices = np.argmax(a_patch, axis=2)
                c_idx, n_idx = np.meshgrid(range(ic), range(n), sparse=True)
                h_idx, w_idx = np.unravel_index(max_indices, (kh, kw))
                pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw][n_idx, c_idx, h_idx, w_idx] += \
                    out_grad_np[n_idx, c_idx, i, j]
    for i in range(pool_grad_np.shape[2]):
        for j in range(pool_grad_np.shape[3]):
            pool_grad_np[:, :, i, j] = pad_pool_grad_np[:, :, i+pt, j+pl]

    return pool_grad_np
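A quick sanity check of this reference implementation (not part of the patch): with a 2x2, stride-2 average pool and no padding, every input element belongs to exactly one window, so each gradient entry is out_grad / 4.

# Example use of pool_grad_nchw above (assumes it is importable, e.g. through
# topi.testing after this commit).
import numpy as np

a_np = np.arange(16, dtype="float32").reshape(1, 1, 4, 4)
out_grad_np = np.ones((1, 1, 2, 2), dtype="float32")
grad = pool_grad_nchw(a_np, out_grad_np, pool_size=(2, 2), strides=(2, 2),
                      padding=(0, 0, 0, 0), pool_type='avg', ceil_mode=False)
print(grad)   # every entry is 0.25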
topi/src/topi.cc
@@ -473,6 +473,13 @@ TVM_REGISTER_GLOBAL("topi.nn.pool")
                  args[5], args[6], args[7]);
  });

TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4],
                      static_cast<nn::PoolType>(static_cast<int>(args[5])),
                      args[6], args[7], args[8]);
  });

TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = nn::global_pool(args[0],
topi/tests/python/test_topi_pooling.py
@@ -18,6 +18,7 @@
import numpy as np
import tvm
import topi
import topi.testing
import math
from topi.util import get_const_tuple

@@ -85,6 +86,57 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
    for device in get_all_backend():
        check_device(device)


def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
    iw = ih
    kw = kh
    sw = sh
    pt, pl, pb, pr = padding
    layout = "NCHW"
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type=pool_type, ceil_mode=ceil_mode,
                     layout="NCHW", count_include_pad=count_include_pad)
    dtype = A.dtype

    bshape = get_const_tuple(B.shape)
    ashape = get_const_tuple(A.shape)
    if ceil_mode:
        assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1)
        assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1)
    else:
        assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1)
        assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1)
    OutGrad = tvm.placeholder(bshape, name='OutGrad')
    PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                                 pool_type=pool_type, ceil_mode=ceil_mode,
                                 layout="NCHW", count_include_pad=count_include_pad)

    a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
    out_grad_np = np.random.uniform(low=0.001, size=bshape).astype(dtype)
    pool_grad_np = topi.testing.pool_grad_nchw(a_np, out_grad_np, pool_size=(kh, kw),
                                               strides=(sh, sw), padding=padding,
                                               pool_type=pool_type, ceil_mode=ceil_mode,
                                               count_include_pad=count_include_pad)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_pool_grad(PoolGrad)
        a = tvm.nd.array(a_np, ctx)
        out_grad = tvm.nd.array(out_grad_np, ctx)
        pool_grad = tvm.nd.array(np.zeros(get_const_tuple(PoolGrad.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, OutGrad, PoolGrad], device)
        f(a, out_grad, pool_grad)
        tvm.testing.assert_allclose(pool_grad.asnumpy(), pool_grad_np, rtol=1e-5)

    for device in ['llvm']:  # only support llvm
        check_device(device)


def test_pool():
    verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
    verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)

@@ -100,6 +152,23 @@ def test_pool():
    verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)

    verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
    verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
    verify_pool_grad(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 4, 4, [2, 2, 2, 2], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False)
    verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False)
    verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True)

    verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True)
    verify_pool_grad(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False)
    verify_pool_grad(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
    verify_pool_grad(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)
    verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'max', False)
    verify_pool_grad(1, 256, 32, 1, 2, [1, 1, 1, 1], 'avg', False, False)


def verify_global_pool(n, c, h, w, pool_type):
    A = tvm.placeholder((n, c, h, w), name='A')
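As a worked instance of the shape assertion in verify_pool_grad above: the call verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False) pools a 32x32 input with a 3x3 kernel, stride 2, and one pixel of padding on every side, so bshape[2] = floor((32 - 3 + 1 + 1) / 2) + 1 = 16, and OutGrad becomes a (1, 256, 16, 16) placeholder fed to topi.nn.pool_grad.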