Commit b64f3f1c by Yao Wang Committed by Yizhi Liu

[Relay][OP] MultiboxPrior (#1882)

* Relay MultiboxPrior Operator

* Fix lint

* Fix build

* Add test for default args
parent 201cfdc5
...@@ -12,6 +12,34 @@ ...@@ -12,6 +12,34 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*! \brief Attributes used in multibox_prior operators */
struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
Array<IndexExpr> sizes;
Array<IndexExpr> ratios;
Array<IndexExpr> steps;
Array<IndexExpr> offsets;
bool clip;
TVM_DECLARE_ATTRS(MultiBoxPriorAttrs, "relay.attrs.MultiBoxPriorAttrs") {
TVM_ATTR_FIELD(sizes)
.set_default(Array<IndexExpr>({static_cast<float>(1.0)}))
.describe("List of sizes of generated MultiBoxPriores.");
TVM_ATTR_FIELD(ratios)
.set_default(Array<IndexExpr>({static_cast<float>(1.0)}))
.describe("List of aspect ratios of generated MultiBoxPriores.");
TVM_ATTR_FIELD(steps)
.set_default(Array<IndexExpr>({static_cast<float>(-1.0),
static_cast<float>(-1.0)}))
.describe("Priorbox step across y and x, -1 for auto calculation.");
TVM_ATTR_FIELD(offsets)
.set_default(Array<IndexExpr>({static_cast<float>(0.5),
static_cast<float>(0.5)}))
.describe("Priorbox center offsets, y and x respectively.");
TVM_ATTR_FIELD(clip).set_default(false)
.describe("Whether to clip out-of-boundary boxes.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_ #endif // TVM_RELAY_ATTRS_VISION_H_
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
"""Vision network related operators.""" """Vision network related operators."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .multibox import *
"""Multibox operations."""
from __future__ import absolute_import as _abs
from . import _make
def multibox_prior(data,
sizes=(1.0,),
ratios=(1.0,),
steps=(-1.0, -1.0),
offsets=(0.5, 0.5),
clip=False):
"""Generate prior(anchor) boxes from data, sizes and ratios.
Parameters
----------
data : relay.Expr
The input data tensor.
sizes : tuple of float, optional
Tuple of sizes for anchor boxes.
ratios : tuple of float, optional
Tuple of ratios for anchor boxes.
steps : Tuple of float, optional
Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int, optional
Priorbox center offsets, y and x respectively.
clip : boolean, optional
Whether to clip out-of-boundary boxes.
Returns
-------
out : relay.Expr
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
return _make.multibox_prior(data, sizes, ratios, steps, offsets, clip)
/*!
* Copyright (c) 2018 by Contributors
* \file multibox_op.cc
* \brief Multibox related operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
#include <vector>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs);
bool MultiboxPriorRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
const MultiBoxPriorAttrs* param = attrs.as<MultiBoxPriorAttrs>();
const auto& dshape = data->shape;
CHECK_EQ(dshape.size(), 4) << "Input data should be 4D: "
"[batch, channel, height, width]";
IndexExpr in_height = dshape[2];
IndexExpr in_width = dshape[3];
int num_sizes = static_cast<int>(param->sizes.size());
int num_ratios = static_cast<int>(param->ratios.size());
// since input sizes are same in each batch, we could share MultiBoxPrior
std::vector<IndexExpr> oshape(
{1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Expr MakeMultiBoxPrior(Expr data,
Array<IndexExpr> sizes,
Array<IndexExpr> ratios,
Array<IndexExpr> steps,
Array<IndexExpr> offsets,
bool clip) {
auto attrs = make_node<MultiBoxPriorAttrs>();
attrs->sizes = std::move(sizes);
attrs->ratios = std::move(ratios);
attrs->steps = std::move(steps);
attrs->offsets = std::move(offsets);
attrs->clip = clip;
static const Op& op = Op::Get("vision.multibox_prior");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxPrior, args, rv);
});
RELAY_REGISTER_OP("vision.multibox_prior")
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("MultiBoxPrior", MultiboxPriorRel);
} // namespace relay
} // namespace tvm
...@@ -124,6 +124,37 @@ def test_binary_broadcast(): ...@@ -124,6 +124,37 @@ def test_binary_broadcast():
ftype = func.checked_type ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
def test_multibox_prior():
sizes = (0.3, 1.5, 0.7)
ratios = (1.3, 2.4)
steps = (2.0, 1.5)
offsets = (0.2, 0.3)
clip = True
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 3, 56, 56
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x.var, sizes, ratios,
steps, offsets, clip))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 24, 32, 32
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w, 4), "float32")
def test_where(): def test_where():
ib = relay.ir_builder.IRBuilder() ib = relay.ir_builder.IRBuilder()
...@@ -144,3 +175,4 @@ if __name__ == "__main__": ...@@ -144,3 +175,4 @@ if __name__ == "__main__":
test_binary_op() test_binary_op()
test_binary_broadcast_op() test_binary_broadcast_op()
test_where() test_where()
test_multibox_prior()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment