Commit 5c410037 by Wuwei Lin Committed by Tianqi Chen

[TOPI][Relay] max_pool2d & avg_pool2d gradient (#3601)

parent 440df0aa
......@@ -22,6 +22,7 @@ from .op import register_gradient
from .transform import collapse_sum_like, broadcast_to_like, where
from .tensor import exp, negative, power, less
from .tensor import zeros_like, ones_like
from . import nn as _nn
@register_gradient("log")
......@@ -146,3 +147,20 @@ def clip_grad(orig, grad):
zeros = zeros_like(x)
ones = ones_like(x)
return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
attrs = orig.attrs
pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
strides=attrs.strides, padding=attrs.padding,
layout=attrs.layout, ceil_mode=attrs.ceil_mode)
return [pool_grad]
@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
attrs = orig.attrs
pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
strides=attrs.strides, padding=attrs.padding,
layout=attrs.layout, ceil_mode=attrs.ceil_mode,
count_include_pad=attrs.count_include_pad)
return [pool_grad]
......@@ -255,6 +255,28 @@ def schedule_avg_pool2d(attrs, outs, target):
reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
"""Schedule definition of max_pool2d_grad"""
with target:
return topi.generic.schedule_pool_grad(outs)
reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)
# avg_pool2d_grad
@reg.register_schedule("nn.avg_pool2d_grad")
def schedule_avg_pool2d_grad(attrs, outs, target):
"""Schedule definition of avg_pool2d_grad"""
with target:
return topi.generic.schedule_pool_grad(outs)
reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)
# global_max_pool2d
@reg.register_schedule("nn.global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
......
......@@ -327,6 +327,88 @@ def avg_pool2d(data,
return _make.avg_pool2d(data, pool_size, strides, padding,
layout, ceil_mode, count_include_pad)
def max_pool2d_grad(out_grad,
data,
pool_size=(1, 1),
strides=(1, 1),
padding=(0, 0),
layout="NCHW",
ceil_mode=False):
r"""Gradient of 2D maximum pooling operator.
This operator takes out_grad and data as input and calculates gradient of max_pool2d.
Parameters
----------
out_grad : tvm.relay.Expr
The output gradient
data : tvm.relay.Expr
The input data to the operator.
strides : tuple of int, optional
The strides of pooling.
padding : tuple of int, optional
The padding for pooling.
layout : str, optional
Layout of the input.
ceil_mode : bool, optional
To enable or disable ceil while pooling.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding,
layout, ceil_mode)
def avg_pool2d_grad(out_grad,
data,
pool_size=(1, 1),
strides=(1, 1),
padding=(0, 0),
layout="NCHW",
ceil_mode=False,
count_include_pad=False):
r"""Gradient of 2D average pooling operator.
This operator takes out_grad and data as input and calculates gradient of avg_pool2d.
Parameters
----------
out_grad : tvm.relay.Expr
The output gradient
data : tvm.relay.Expr
The input data to the operator.
strides : tuple of int, optional
The strides of pooling.
padding : tuple of int, optional
The padding for pooling.
layout : str, optional
Layout of the input.
ceil_mode : bool, optional
To enable or disable ceil while pooling.
count_include_pad : bool, optional
To include padding to compute the average.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.avg_pool2d_grad(out_grad, data, pool_size, strides, padding,
layout, ceil_mode, count_include_pad)
def global_max_pool2d(data,
layout="NCHW"):
r"""2D global maximum pooling operator.
......
......@@ -251,3 +251,13 @@ class YoloReorgAttrs(Attrs):
@register_relay_attr_node
class ProposalAttrs(Attrs):
"""Attributes used in proposal operators"""
@register_relay_attr_node
class MaxPool2DAttrs(Attrs):
"""Attributes used in max_pool2d operators"""
@register_relay_attr_node
class AvgPool2DAttrs(Attrs):
"""Attributes used in avg_pool2d operators"""
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -557,5 +557,161 @@ RELAY_REGISTER_OP("contrib.adaptive_max_pool2d")
Pool2DInferCorrectLayout<AdaptivePool2DAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kMaxPool>);
bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
// assign output type
reporter->Assign(types[2], types[1]);
return true;
}
template <typename AttrType, topi::nn::PoolType mode>
Array<Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_type, const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
CHECK_EQ(inputs.size(), 2);
auto pool_size = param->pool_size;
auto strides = param->strides;
auto padding = param->padding;
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
<< "pool2d_grad currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "pool2d_grad does not support input split on height";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "pool2d_grad does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2DGrad only support 4-D output gradient (e.g., NCHW)"
<< " or 5-D output gradient (last dimension is a split of channel)";
CHECK(inputs[1].ndim() == 4U || inputs[1].ndim() == 5U)
<< "Pool2DGrad only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
if (param->padding.size() == 1) {
padding.push_back(padding[0]);
padding.push_back(padding[0]);
padding.push_back(padding[0]);
} else if (param->padding.size() == 2) {
padding.push_back(padding[0]);
padding.push_back(padding[1]);
}
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
mode, ceil_mode, layout.name(), count_include_pad)};
} else {
return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
mode, ceil_mode, layout.name())};
}
}
// MaxPool2DGrad
Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout, bool ceil_mode) {
auto attrs = make_node<MaxPool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->layout = std::move(layout);
attrs->ceil_mode = ceil_mode;
static const Op& op = Op::Get("nn.max_pool2d_grad");
return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad);
RELAY_REGISTER_OP("nn.max_pool2d_grad")
.describe(R"code(Gradient of max pooling operation for two dimensional data.
- **out_grad**: This depends on the `layout` parameter. Output gradient is 4D array of
shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are are the output size of the pooling operation,
which are calculated as::
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **grad**: This depends on the `layout` parameter. Grad is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
// AvgPool2DGrad
Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout, bool ceil_mode,
bool count_include_pad) {
auto attrs = make_node<AvgPool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->layout = std::move(layout);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
static const Op& op = Op::Get("nn.avg_pool2d_grad");
return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad);
RELAY_REGISTER_OP("nn.avg_pool2d_grad")
.describe(R"code(Gradient of average pooling operation for two dimensional data.
- **out_grad**: This depends on the `layout` parameter. Output gradient is 4D array of
shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are are the output size of the pooling operation,
which are calculated as::
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **grad**: This depends on the `layout` parameter. Grad is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MaxPool2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool2DGrad", Pool2DGradRel)
.set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);
} // namespace relay
} // namespace tvm
......@@ -42,7 +42,7 @@ using namespace tvm;
*
* \return The index after flattening
*/
inline Expr RavelIndex(Array<Var> indices, Array<Expr> shape) {
inline Expr RavelIndex(Array<Expr> indices, Array<Expr> shape) {
CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
CHECK_GT(indices.size(), 0) << "indices must not be empty";
Expr idx;
......
......@@ -224,7 +224,7 @@ inline Tensor CommReduceIdx(const Tensor& data,
auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
(const Array<Var>& indices) {
Array<Expr> eval_range;
Array<Var> eval_indices;
Array<Expr> eval_indices;
int arg_counter = 0;
int red_counter = 0;
......@@ -466,6 +466,22 @@ inline Tensor argmin(const Tensor& data,
return CommReduceIdx(data, axis, func, keepdims, atleast1d);
}
inline FCommReduce MakeArgmaxReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<Type> types) {
Array<Expr> result;
result.push_back(tvm::make_const(types[0], -1)); // idx
result.push_back(types[1].min()); // val
return result;
};
return MakeCommReducer(fcombine, fidentity, "argmax");
}
/*!
* \brief Creates an operation that finds the indices of the maximum
* values over a given axis.
......@@ -484,20 +500,8 @@ inline Tensor argmax(const Tensor& data,
const Array<Integer>& axis,
bool keepdims = false,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<Type> types) {
Array<Expr> result;
result.push_back(tvm::make_const(types[0], -1)); // idx
result.push_back(types[1].min()); // val
return result;
};
auto func = MakeCommReducer(fcombine, fidentity, "argmax");
return CommReduceIdx(data, axis, func, keepdims, atleast1d);
auto reducer = MakeArgmaxReducer();
return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}
/*!
......
......@@ -210,7 +210,8 @@ inline Tensor reshape(const Tensor& x,
auto x_shape = x->shape;
return compute(
newshape, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(indices, newshape), x_shape));
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape),
x_shape));
}, name, tag);
}
......
......@@ -421,6 +421,19 @@ def schedule_pool(outs, layout):
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_pool_grad(outs):
"""Schedule for pool_grad
Parameters
----------
outs: Array of Tensor
The computation graph description of pool
in the format of an array of tensors.
"""
return _default_schedule(outs, False)
@tvm.target.override_native_generic_func("schedule_adaptive_pool")
def schedule_adaptive_pool(outs):
"""Schedule for adaptive pool
......
......@@ -114,6 +114,68 @@ def pool(data,
return cpp.nn.pool(data, kernel, stride, padding,
POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)
def pool_grad(grads,
data,
kernel,
stride,
padding,
pool_type,
ceil_mode=False,
layout="NCHW",
count_include_pad=True):
"""Gradient of pooling on height and width dimension of data.
It decides the height and width dimension according to the layout string,
in which 'W' and 'H' means width and height respectively.
Width and height dimension cannot be split.
For example, NCHW, NCHW16c, etc. are valid for pool,
while NCHW16w, NCHW16h are not.
See parameter `layout` for more information of the layout string convention.
Parameters
----------
grads : tvm.Tensor
n-D with shape of layout
data : tvm.Tensor
n-D with shape of layout
kernel : list/tuple of two ints
Kernel size, [kernel_height, kernel_width]
stride : list/tuple of two ints
Stride size, [stride_height, stride_width]
padding : list/tuple of four ints
Pad size, [pad_top, pad_left, pad_bottom, pad_right]]
pool_type : str
Pool type, 'max' or 'avg'
ceil_mode : bool
Whether to use ceil when calculating output size.
layout: string
Layout of the input data.
The layout is supposed to be composed of upper cases, lower cases and numbers,
where upper case indicates a dimension and
the corresponding lower case with factor size indicates the split dimension.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block],
in which channel_block=16 is a split of dimension channel.
count_include_pad: bool
Whether include padding in the calculation when pool_type is 'avg'
Returns
-------
output : tvm.Tensor
n-D in the same layout
"""
return cpp.nn.pool_grad(grads, data, kernel,
stride, padding, POOL_TYPE_CODE[pool_type],
ceil_mode, layout, count_include_pad)
def adaptive_pool(data,
output_size,
pool_type,
......
......@@ -24,3 +24,4 @@ from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .pool_grad_python import pool_grad_nchw
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Gradient of pooling in python"""
import numpy as np
def pool_grad_nchw(a_np, out_grad_np, pool_size, strides, padding, pool_type, ceil_mode,
count_include_pad=True):
"""pool_grad for NCHW layout in python"""
dtype = a_np.dtype
n, ic, ih, iw = a_np.shape
kh, kw = pool_size
sh, sw = strides
pt, pl, pb, pr = padding
pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oh, ow = out_grad_np.shape
pool_grad_np = np.zeros(shape=a_np.shape)
pad_pool_grad_np = np.zeros(shape=pad_np.shape)
if pool_type == 'avg':
for i in range(oh):
for j in range(ow):
if count_include_pad:
shape = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw].shape
# this can be different from kh*kw if input size cannot divide stride
pad_count = shape[2] * shape[3]
else:
pad_count = np.sum(
pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3))
# take the first element, as they are the same across batch and channel
pad_count = pad_count.ravel()[0]
pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] += \
out_grad_np[:, :, i, j].reshape(n,ic,1,1) / np.maximum(pad_count, 1)
elif pool_type =='max':
for i in range(oh):
for j in range(ow):
a_patch = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw]
a_patch = np.reshape(a_patch, (n, ic, -1))
max_indices = np.argmax(a_patch, axis=2)
c_idx, n_idx = np.meshgrid(range(ic), range(n), sparse=True)
h_idx, w_idx = np.unravel_index(max_indices, (kh, kw))
pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw][n_idx, c_idx, h_idx, w_idx] += \
out_grad_np[n_idx, c_idx, i, j]
for i in range(pool_grad_np.shape[2]):
for j in range(pool_grad_np.shape[3]):
pool_grad_np[:, :, i, j] = pad_pool_grad_np[:, :, i + pt, j + pl]
return pool_grad_np
......@@ -473,6 +473,13 @@ TVM_REGISTER_GLOBAL("topi.nn.pool")
args[5], args[6], args[7]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4],
static_cast<nn::PoolType>(static_cast<int>(args[5])),
args[6], args[7], args[8]);
});
TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::global_pool(args[0],
......
......@@ -18,6 +18,7 @@
import numpy as np
import tvm
import topi
import topi.testing
import math
from topi.util import get_const_tuple
......@@ -85,6 +86,57 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
for device in get_all_backend():
check_device(device)
def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
iw = ih
kw = kh
sw = sh
pt, pl, pb, pr = padding
layout = "NCHW"
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
layout="NCHW", count_include_pad=count_include_pad)
dtype = A.dtype
bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1)
OutGrad = tvm.placeholder(bshape, name='OutGrad')
PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
layout="NCHW", count_include_pad=count_include_pad)
a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
out_grad_np = np.random.uniform(low=0.001, size=bshape).astype(dtype)
pool_grad_np = topi.testing.pool_grad_nchw(a_np, out_grad_np, pool_size=(kh, kw),
strides=(sh, sw), padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_pool_grad(PoolGrad)
a = tvm.nd.array(a_np, ctx)
out_grad = tvm.nd.array(out_grad_np, ctx)
pool_grad = tvm.nd.array(np.zeros(get_const_tuple(PoolGrad.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, OutGrad, PoolGrad], device)
f(a, out_grad, pool_grad)
tvm.testing.assert_allclose(pool_grad.asnumpy(), pool_grad_np, rtol=1e-5)
for device in ['llvm']: # only support llvm
check_device(device)
def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
......@@ -100,6 +152,23 @@ def test_pool():
verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)
verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
verify_pool_grad(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
verify_pool_grad(1, 256, 31, 4, 4, [2, 2, 2, 2], 'avg', False, False)
verify_pool_grad(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False)
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False)
verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False)
verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True)
verify_pool_grad(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True)
verify_pool_grad(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False)
verify_pool_grad(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
verify_pool_grad(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)
verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'max', False)
verify_pool_grad(1, 256, 32, 1, 2, [1, 1, 1, 1], 'avg', False, False)
def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment