Unverified Commit 702db6f9 by Matthew Brookhart Committed by GitHub

Add RoiAlign to Onnx frontend (#5454)

parent 17cd27da
......@@ -26,6 +26,7 @@ from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import vision as _vision
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
......@@ -1495,6 +1496,34 @@ class TopK(OnnxOpConverter):
return _op.topk(inputs[0], k=K, axis=axis)
class RoiAlign(OnnxOpConverter):
"""Operator converter for TopK
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if len(inputs) != 3:
raise ValueError("Expect 3 inputs only")
x = inputs[0]
rois = inputs[1]
batch_indices = inputs[2]
mode = attr.get("mode", "avg")
if mode != b'avg':
raise ValueError("RoiAlign in Relay only uses avg mode")
output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)
sampling_ratio = attr.get("sampling_ratio", 0)
spatial_scale = attr.get("spatial_scale", 1.0)
batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1)
batch_indices = _op.cast(
batch_indices, infer_type(rois).type_annotation.dtype)
rois = _op.concatenate([batch_indices, rois], 1)
return _vision.roi_align(x, rois, [output_height, output_width],
spatial_scale, sampling_ratio)
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -1592,6 +1621,9 @@ def _get_convert_map(opset):
# Recurrent Layers
'LSTM': LSTM.get_converter(opset),
# defs/vision
'RoiAlign': RoiAlign.get_converter(opset),
# defs/reduction
'ReduceMax': ReduceMax.get_converter(opset),
'ReduceMin': ReduceMin.get_converter(opset),
......
......@@ -36,6 +36,8 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* rois = types[1].as<TensorTypeNode>();
CHECK(data);
CHECK(rois);
const auto& dshape = data->shape;
const auto& rshape = rois->shape;
CHECK(roi_align_attrs);
......
......@@ -2432,6 +2432,68 @@ def test_topk():
verify_topk([n, n, n], 5, 2)
def test_roi_align():
def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):
output_dims = [num_roi, input_dims[1], output_height, output_width]
node = helper.make_node('RoiAlign',
inputs=['X', 'rois', 'batch_indicies'],
outputs=['Y'],
mode="avg",
output_height=output_height,
output_width=output_width,
sampling_ratio=sampling_ratio,
spatial_scale=spatial_scale,
)
graph = helper.make_graph([node],
"roialign_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
helper.make_tensor_value_info(
"rois", TensorProto.FLOAT, [num_roi, 4]),
helper.make_tensor_value_info(
"batch_indicies", TensorProto.INT64, [num_roi, ]),
],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)])
model = helper.make_model(graph, producer_name='roialign_test')
np_data = np.random.uniform(size=input_dims).astype("float32")
np_rois = np.random.uniform(size=[num_roi, 4]).astype(
'float32') * input_dims[2]
np_batch_indicies = np.random.randint(
low=0, high=input_dims[0], size=num_roi)
onnx_out = get_onnxruntime_output(
model, [np_data, np_rois, np_batch_indicies])
for target, ctx in [('llvm', tvm.cpu())]:
tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims,
output_dtype='float32')
tvm.testing.assert_allclose(
onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
verify_roi_align((1, 4, 16, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((4, 4, 16, 32), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 8, 16, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 8, 8), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 16, 5, 7,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 12), 8, 7, 3,
sampling_ratio=0, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=0.5)
verify_roi_align((3, 4, 12, 16), 32, 7, 7,
sampling_ratio=0, spatial_scale=1.5)
verify_roi_align((5, 4, 16, 14), 32, 7, 7,
sampling_ratio=1, spatial_scale=1.0)
verify_roi_align((1, 4, 16, 16), 32, 7, 7,
sampling_ratio=2, spatial_scale=1.0)
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -2498,3 +2560,4 @@ if __name__ == '__main__':
test_resize()
test_nonzero()
test_topk()
test_roialign()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment