Commit 8b1d07ff by Wuwei Lin, committed by masahi

[RELAY][OP] ROI Align (#2618)

parent 80f8e982
......@@ -74,6 +74,30 @@ struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
}
};
/*! \brief Attributes used in roi_align operators */
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
Array<IndexExpr> pooled_size;
double spatial_scale;
int sample_ratio;
std::string layout;
TVM_DECLARE_ATTRS(ROIAlignAttrs, "relay.attrs.ROIAlignAttrs") {
TVM_ATTR_FIELD(pooled_size).describe("Output size of roi align.");
TVM_ATTR_FIELD(spatial_scale)
.describe(
"Ratio of input feature map height (or w) to raw image height (or w). "
"Equals the reciprocal of total stride in convolutional layers, which should be "
"in range (0.0, 1.0]");
TVM_ATTR_FIELD(sample_ratio)
.set_default(-1)
.describe("Optional sampling ratio of ROI align, using adaptive size by default.");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
    "Dimension ordering of the input data. Can be 'NCHW', 'NHWC', etc. "
    "'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
    "dimensions respectively. ROI align is applied on the 'H' and "
    "'W' dimensions.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
......@@ -268,6 +268,15 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)
def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
new_attrs["spatial_scale"] = attrs.get_float("spatial_scale")
new_attrs["sample_ratio"] = attrs.get_int("sample_ratio", -1)
new_attrs["layout"] = "NCHW"
return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)
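For context, a minimal sketch of exercising this converter end to end; the symbol names, shapes, and pooling parameters below are illustrative, and the exact return values of relay.frontend.from_mxnet vary between TVM versions.

import mxnet as mx
from tvm import relay

# Hypothetical graph: a 1x256x14x14 feature map and 100 ROIs, pooled to 7x7
# with a total backbone stride of 16 (spatial_scale = 1/16).
data = mx.sym.var("data")
rois = mx.sym.var("rois")
net = mx.sym.contrib.ROIAlign(data=data, rois=rois,
                              pooled_size=(7, 7), spatial_scale=1.0 / 16)
shape_dict = {"data": (1, 256, 14, 14), "rois": (100, 5)}
# Conversion goes through _mx_roi_align above; the return signature
# (function vs. module) depends on the TVM version.
func, params = relay.frontend.from_mxnet(net, shape_dict)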
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -357,6 +366,7 @@ _convert_map = {
# vision
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
"_contrib_ROIAlign" : _mx_roi_align,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
......
......@@ -4,4 +4,6 @@ from __future__ import absolute_import as _abs
from .multibox import *
from .nms import *
from .rcnn import *
from . import _multibox
from . import _rcnn
# pylint: disable=invalid-name, unused-argument
"""Faster R-CNN and Mask R-CNN operations."""
import topi
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern
@reg.register_compute("vision.roi_align")
def compute_roi_align(attrs, inputs, _, target):
"""Compute definition of roi_align"""
assert attrs.layout == "NCHW"
return [topi.vision.rcnn.roi_align_nchw(
inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)]
@reg.register_schedule("vision.roi_align")
def schedule_roi_align(_, outs, target):
"""Schedule definition of roi_align"""
with target:
return topi.generic.vision.schedule_roi_align(outs)
reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
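For illustration, a rough sketch of driving the TOPI compute and schedule referenced above directly, outside of Relay; the placeholder shapes and pooling parameters are made up, and on a plain llvm target the generic schedule falls back to a default.

import tvm
import topi

data = tvm.placeholder((1, 256, 14, 14), name="data")   # NCHW feature map
rois = tvm.placeholder((100, 5), name="rois")            # [batch_idx, w1, h1, w2, h2]
out = topi.vision.rcnn.roi_align_nchw(data, rois, pooled_size=(7, 7),
                                      spatial_scale=1.0 / 16, sample_ratio=2)
with tvm.target.create("llvm"):
    s = topi.generic.vision.schedule_roi_align([out])
f = tvm.build(s, [data, rois, out], "llvm")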
"""Faster R-CNN and Mask R-CNN operations."""
from . import _make
def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'):
"""ROI align operator.
Parameters
----------
data : relay.Expr
    4-D tensor with shape [batch, channel, height, width]
rois : relay.Expr
    2-D tensor with shape [num_roi, 5]. The last dimension should be in the format
    [batch_index, w_start, h_start, w_end, h_end]
pooled_size : list/tuple of two ints
    Output size (pooled_height, pooled_width)
spatial_scale : float
    Ratio of input feature map height (or width) to raw image height (or width). Equals
    the reciprocal of the total stride in the convolutional layers; should be in range
    (0.0, 1.0]
sample_ratio : int
    Sampling ratio of ROI align. If -1 (the default), the sampling grid size is chosen
    adaptively from the ROI size.
layout : str
    Layout of the input data. Currently only 'NCHW' is supported.
Returns
-------
output : relay.Expr
    4-D tensor with shape [num_roi, channel, pooled_size[0], pooled_size[1]]
"""
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
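A short usage sketch of the API above; the shapes and parameters are illustrative, and type inference uses relay.ir_pass.infer_type, the same pass exercised by the tests further down.

from tvm import relay

data = relay.var("data", shape=(1, 256, 14, 14))
rois = relay.var("rois", shape=(100, 5))
out = relay.vision.roi_align(data, rois, pooled_size=(7, 7),
                             spatial_scale=1.0 / 16, sample_ratio=2)
func = relay.Function([data, rois], out)
# Inferred output type: Tensor[(100, 256, 7, 7), float32] -- one pooled map per ROI.
print(relay.ir_pass.infer_type(func))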
/*!
* Copyright (c) 2019 by Contributors
* \file rcnn_op.cc
* \brief Faster RCNN and Mask RCNN operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/vision.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ROIAlignAttrs);
bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
auto roi_align_attrs = attrs.as<ROIAlignAttrs>();
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* rois = types[1].as<TensorTypeNode>();
const auto& dshape = data->shape;
const auto& rshape = rois->shape;
CHECK(roi_align_attrs);
CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout";
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
int sample_ratio, std::string layout) {
auto attrs = make_node<ROIAlignAttrs>();
attrs->pooled_size = pooled_size;
attrs->spatial_scale = spatial_scale;
attrs->sample_ratio = sample_ratio;
attrs->layout = layout;
static const Op& op = Op::Get("vision.roi_align");
return CallNode::make(op, {data, rois}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.vision._make.roi_align")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv);
});
RELAY_REGISTER_OP("vision.roi_align")
.describe(R"doc(ROI Align operator.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **rois**: 2D array of shape (num_roi, 5). The last dimension should be in the format
  [batch_index, w_start, h_start, w_end, h_end].
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("rois", "Tensor", "The input rois")
.set_support_level(5)
.add_type_rel("ROIAlign", ROIAlignRel);
} // namespace relay
} // namespace tvm
......@@ -273,9 +273,44 @@ def test_multibox_transform_loc():
test_threshold()
def test_roi_align():
def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio):
data = relay.var("data", relay.ty.TensorType(data_shape, "float32"))
rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32"))
z = relay.vision.roi_align(data, rois, pooled_size=(pooled_size, pooled_size),
spatial_scale=spatial_scale, sample_ratio=sample_ratio,
layout="NCHW")
zz = relay.ir_pass.infer_type(z)
batch, channel, in_size, _ = data_shape
num_roi = rois_shape[0]
assert zz.checked_type == relay.ty.TensorType(
(num_roi, channel, pooled_size, pooled_size), "float32")
func = relay.Function([data, rois], z)
func = relay.ir_pass.infer_type(func)
np_data = np.random.uniform(size=data_shape).astype("float32")
np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
np_rois[:, 0] = np.random.randint(low=0, high=batch, size=num_roi)
ref_res = topi.testing.roi_align_nchw_python(np_data, np_rois, pooled_size=pooled_size,
spatial_scale=spatial_scale,
sample_ratio=sample_ratio)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(np_data, np_rois)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(np_data, np_rois)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-4)
verify_roi_align((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0, sample_ratio=-1)
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)
if __name__ == "__main__":
test_resize_infer_type()
test_resize()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_roi_align()
......@@ -68,8 +68,8 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
if sample_ratio > 0:
roi_bin_grid_h = roi_bin_grid_w = tvm.const(sample_ratio, 'int32')
else:
- roi_bin_grid_h = tvm.ceil(roi_h / pooled_size).astype('int32')
- roi_bin_grid_w = tvm.ceil(roi_w / pooled_size).astype('int32')
+ roi_bin_grid_h = tvm.ceil(roi_h / pooled_size_h).astype('int32')
+ roi_bin_grid_w = tvm.ceil(roi_w / pooled_size_w).astype('int32')
count = roi_bin_grid_h * roi_bin_grid_w
rh = tvm.reduce_axis((0, roi_bin_grid_h))
......
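The change above switches the adaptive grid computation from a single scalar pooled_size to the per-axis pooled_size_h and pooled_size_w. A tiny pure-Python sketch of the same rule, with made-up numbers:

import math

def bin_grid(roi_h, roi_w, pooled_h, pooled_w, sample_ratio=-1):
    # A fixed, user-supplied grid when sample_ratio > 0 ...
    if sample_ratio > 0:
        return sample_ratio, sample_ratio
    # ... otherwise roughly one sampling point per unit of ROI extent per output
    # bin, computed independently for the H and W axes.
    return (int(math.ceil(roi_h / pooled_h)),
            int(math.ceil(roi_w / pooled_w)))

# e.g. a 35x21 ROI pooled to 7x7 bins samples a 5x3 grid inside each bin
print(bin_grid(35.0, 21.0, 7, 7))   # -> (5, 3)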