Commit d9a9a8b6 by Zhi Committed by Haichen Shen

[relay][op] multibox_transform_loc (#2315)

parent 10b6e7e0
......@@ -40,6 +40,24 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
}
};
/*! \brief Attributes used in the multibox_transform_loc operator */
struct MultiBoxTransformLocAttrs
    : public tvm::AttrsNode<MultiBoxTransformLocAttrs> {
  /*! \brief Clip out-of-boundary boxes (defaults to true). */
  bool clip;
  /*! \brief Threshold to be a positive prediction (defaults to 0.01). */
  double threshold;
  /*! \brief Variances to be decoded from box regression output. */
  Array<IndexExpr> variances;

  TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs,
                    "relay.attrs.MultiBoxTransformLocAttrs") {
    // Field declarations are kept in the same order as the members above.
    TVM_ATTR_FIELD(clip)
        .set_default(true)
        .describe("Clip out-of-boundary boxes.");
    TVM_ATTR_FIELD(threshold)
        .set_default(0.01)
        .describe("Threshold to be a positive prediction.");
    TVM_ATTR_FIELD(variances)
        .set_default(Array<IndexExpr>({0.1f, 0.1f, 0.2f, 0.2f}))
        .describe("Variances to be decoded from box regression output.");
  }
};
/*! \brief Attributes used in non_maximum_suppression operators */
struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
double overlap_threshold;
......
......@@ -36,3 +36,39 @@ def multibox_prior(data,
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
return _make.multibox_prior(data, sizes, ratios, steps, offsets, clip)
def multibox_transform_loc(cls_prob,
                           loc_pred,
                           anchor,
                           clip=True,
                           threshold=0.01,
                           variance=(0.1, 0.1, 0.2, 0.2)):
    """Location transformation for multibox detection.

    Builds a call to the ``vision.multibox_transform_loc`` operator by
    forwarding all arguments, in order, to the registered ``_make`` function.

    Parameters
    ----------
    cls_prob : tvm.relay.Expr
        Class probabilities. The operator's type relation requires a 3-D
        tensor whose last dimension equals the number of anchors.

    loc_pred : tvm.relay.Expr
        Location regression predictions. The type relation requires a 2-D
        tensor whose second dimension is 4 times the number of anchors.

    anchor : tvm.relay.Expr
        Prior anchor boxes. The type relation requires a 3-D tensor whose
        last dimension is 4.

    clip : boolean, optional
        Whether to clip out-of-boundary boxes.

    threshold : double, optional
        Threshold to be a positive prediction.

    variance : Tuple of float, optional
        Variances to be decoded from box regression output.

    Returns
    -------
    ret : tvm.relay.Expr
        The operator call. Its inferred type is a tuple of two tensors: one
        of shape [batch, num_anchors, 6] with the dtype of ``cls_prob`` and
        one of shape [batch] with dtype int32.
    """
    # The packed function takes positional arguments only; keep this order.
    call_args = (cls_prob, loc_pred, anchor, clip, threshold, variance)
    return _make.multibox_transform_loc(*call_args)
......@@ -68,5 +68,78 @@ RELAY_REGISTER_OP("vision.multibox_prior")
.set_support_level(5)
.add_type_rel("MultiBoxPrior", MultiboxPriorRel);
// Register the attribute node type used by vision.multibox_transform_loc.
// NOTE(review): presumably required so the attrs can be looked up by their
// type key ("relay.attrs.MultiBoxTransformLocAttrs") — confirm against the
// TVM_REGISTER_NODE_TYPE macro definition.
TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
/*!
 * \brief Type relation for vision.multibox_transform_loc.
 *
 * Expects types = {cls_prob, loc_pred, anchor, output}. Once the three input
 * types are tensor types, validates their ranks and the consistency of the
 * anchor dimensions, then assigns the output type: a tuple of
 *   - a tensor of shape [cls_shape[0], anchor_shape[1], 6] with the dtype of
 *     cls_prob, and
 *   - a tensor of shape [cls_shape[0]] with dtype int32.
 *
 * \param types The input types followed by the output type (4 entries).
 * \param num_inputs Number of inputs (unused here).
 * \param attrs Operator attributes (unused here).
 * \param reporter Used to assert shape constraints and assign the output type.
 * \return true when the output type has been assigned; false when an input is
 *         not (yet) a tensor type and the relation cannot be resolved.
 */
bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs,
                             const Attrs& attrs, const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* cls_prob = types[0].as<TensorTypeNode>();
  const auto* loc_pred = types[1].as<TensorTypeNode>();
  const auto* anchor = types[2].as<TensorTypeNode>();

  // An input that is not yet known to be a tensor type cannot be checked.
  // Report "not resolved" instead of aborting so the relation can be retried
  // once more type information is available.
  if (cls_prob == nullptr || loc_pred == nullptr || anchor == nullptr) {
    return false;
  }

  const auto& cls_shape = cls_prob->shape;
  const auto& loc_shape = loc_pred->shape;
  const auto& anchor_shape = anchor->shape;

  // Rank checks.
  CHECK_EQ(cls_shape.size(), 3U)
      << "The dimension of class probability should be 3, but received "
      << cls_shape.size();
  CHECK_EQ(loc_shape.size(), 2U)
      << "The dimension of location prediction should be 2, but received "
      << loc_shape.size();
  CHECK_EQ(anchor_shape.size(), 3U)
      << "The dimension of anchor should be 3, but received "
      << anchor_shape.size();

  // The anchor count must agree across all three inputs, and every anchor
  // carries exactly 4 values.
  CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1]))
      << "Number of anchors mismatch found";
  CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1]))
      << "# anchors mismatch with # loc.";
  CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0.";
  CHECK(reporter->AssertEQ(anchor_shape[2], 4));

  std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
  std::vector<IndexExpr> oshape1({cls_shape[0]});
  std::vector<Type> fields;
  fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
  fields.push_back(TensorTypeNode::make(oshape1, Int(32)));

  // assign output type
  reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}
/*!
 * \brief Build a call to the vision.multibox_transform_loc operator.
 *
 * Stores clip, threshold and variances in a freshly created
 * MultiBoxTransformLocAttrs node and calls the operator with
 * {cls_prob, loc_pred, anchor} as arguments.
 */
Expr MakeMultiBoxTransformLoc(Expr cls_prob,
                              Expr loc_pred,
                              Expr anchor,
                              bool clip,
                              double threshold,
                              Array<IndexExpr> variances) {
  static const Op& op = Op::Get("vision.multibox_transform_loc");

  auto node = make_node<MultiBoxTransformLocAttrs>();
  // Scalars are copied; only the array is worth moving.
  node->clip = clip;
  node->threshold = threshold;
  node->variances = std::move(variances);

  return CallNode::make(op, {cls_prob, loc_pred, anchor}, Attrs(node), {});
}
// Expose MakeMultiBoxTransformLoc under the global API name
// "relay.op.vision._make.multibox_transform_loc"; the 6 packed arguments are
// unpacked and forwarded to it, and its Expr result is written to the return
// value.
TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
.set_body([](const TVMArgs& packed_args, TVMRetValue* ret) {
    runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxTransformLoc,
                                          packed_args, ret);
  });
// Operator registration for vision.multibox_transform_loc: three tensor
// inputs, attributes of type relay.attrs.MultiBoxTransformLocAttrs, and the
// MultiBoxTransformLocRel type relation.
// The description is a raw string literal, so it must not be wrapped in an
// extra pair of double quotes (they would become part of the text).
RELAY_REGISTER_OP("vision.multibox_transform_loc")
.describe(R"doc(Location transformation for multibox detection.
)doc" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MultiBoxTransformLocAttrs")
.set_num_inputs(3)
.add_argument("cls_prob", "Tensor", "Class probabilities.")
.add_argument("loc_pred", "Tensor", "Location regression predictions.")
.add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel)
.set_support_level(5);
} // namespace relay
} // namespace tvm
......@@ -102,8 +102,64 @@ def test_nms():
(n, num_anchors, 6), "float32")
def test_multibox_transform_loc():
    """Check the inferred type of relay.vision.multibox_transform_loc."""
    num_anchors = 5
    num_classes = 5

    def infer_and_compare(batch, **op_kwargs):
        # Build the three inputs; only cls_prob and loc_pred use `batch`,
        # the anchors always have a leading dimension of 1.
        cls_prob = relay.var(
            "cls_prob",
            relay.ty.TensorType((batch, num_anchors, num_classes), "float32"))
        loc_pred = relay.var(
            "loc_pred",
            relay.ty.TensorType((batch, num_anchors * 4), "float32"))
        anchors = relay.var(
            "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32"))

        call = relay.vision.multibox_transform_loc(
            cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors, **op_kwargs)
        inferred = relay.ir_pass.infer_type(call)

        expected_fields = [
            relay.ty.TensorType((batch, num_anchors, 6), "float32"),
            relay.ty.TensorType((batch, ), "int"),
        ]
        expected = relay.ty.TupleType(tvm.convert(expected_fields))
        assert inferred.checked_type == expected

    def test_default_value():
        # Static batch size, all optional arguments left at their defaults.
        infer_and_compare(1)

    def test_threshold():
        # Symbolic batch size with non-default threshold and variance.
        n = tvm.var("n")
        infer_and_compare(n, threshold=0.02, variance=(0.2, 0.2, 0.3, 0.3))

    test_default_value()
    test_threshold()
if __name__ == "__main__":
    # Run every test in this file, in order, when executed as a script.
    for run_test in (test_resize_infer_type,
                     test_resize,
                     test_multibox_prior,
                     test_multibox_transform_loc,
                     test_nms):
        run_test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment