Commit 25c91d34 by Oldpan Committed by Jared Roesch

Fix a bug of flatten in ONNX to Relay converter (#3180)

* fix onnx frontend flatten bug

* Update onnx.py

* Update onnx.py

* Update onnx.py
parent a364af8f
......@@ -335,6 +335,23 @@ class Reciprocal(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
    """Lower ONNX Reciprocal: elementwise 1.0 / x on the first input."""
    one = _expr.const(1.0)
    return one / inputs[0]
class Flatten(OnnxOpConverter):
    """Operator converter for Flatten.

    ONNX Flatten reshapes the input into a 2-D tensor: all dimensions
    before ``axis`` are collapsed into the first output dimension and the
    remaining dimensions into the second.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # ONNX defaults axis to 1 when the attribute is absent.
        axis = attr.get('axis', 1)
        if axis == 1:
            # Keep the batch dim, flatten the rest — exactly batch_flatten.
            out = _op.nn.batch_flatten(inputs[0])
        elif axis == 0:
            # BUG FIX: the previous code reshaped to (-1,), producing a
            # 1-D tensor; the ONNX spec requires shape (1, total_elements)
            # when there are no leading dimensions.
            out = _op.reshape(inputs[0], (1, -1))
        else:
            # Copy the first `axis` input dims (reshape special value 0)
            # and flatten everything after them into one trailing dim (-1).
            # NOTE(review): this yields an (axis+1)-D tensor rather than
            # the 2-D tensor the ONNX spec mandates for axis > 1; a full
            # fix needs the input shape (shape inference), which is not
            # available here — TODO confirm/fix.
            newshape = [0] * (axis + 1)
            newshape[axis] = -1
            out = _op.reshape(inputs[0], list(newshape))
        return out
class Reshape(OnnxOpConverter):
""" Operator converter for Reshape.
"""
......@@ -850,7 +867,7 @@ def _get_convert_map(opset):
# 'InstanceNormalization'
# 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Renamer('batch_flatten'),
'Flatten': Flatten.get_converter(opset),
'LRN': LRN.get_converter(opset),
# defs/reduction
......
......@@ -211,6 +211,29 @@ def test_squeeze():
tvm.testing.assert_allclose(out_shape, tvm_out.shape)
def test_flatten():
    """Check the ONNX Flatten converter for the default axis=1 case.

    Builds a one-node ONNX graph flattening a (1, 3, 4, 4) input into
    (1, 48) and verifies the TVM output shape on every available context.
    """
    in_shape = (1, 3, 4, 4)
    axis = 1
    ref_shape = (1, 48)

    flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis)

    graph = helper.make_graph(
        [flatten_node],
        "flatten_test",
        inputs=[helper.make_tensor_value_info("in",
                                              TensorProto.FLOAT,
                                              list(in_shape))],
        outputs=[helper.make_tensor_value_info("out",
                                               TensorProto.FLOAT,
                                               list(ref_shape))])

    model = helper.make_model(graph, producer_name='flatten_test')

    for target, ctx in ctx_list():
        # BUG FIX: the input was cast to 'int32' although the graph
        # declares a FLOAT input and the expected output dtype is
        # 'float32' — feed float32 data to match the model signature.
        x = np.random.uniform(size=in_shape).astype('float32')
        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def test_unsqueeze():
in_shape = (3, 3)
axis = (0, 3, 4)
......@@ -1046,6 +1069,7 @@ def test_LogSoftmax():
{'axis': 1})
if __name__ == '__main__':
test_flatten()
test_reshape()
test_shape()
test_power()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment