Commit e274e66e by Takato Yamada Committed by Zhi

[relay][op] add expand op (from ONNX) to relay frontend (#4483)

* Add Expand to onnx.py

* add test function for expand

* Fix a onnx frontend test

* Add tests for the value itself instead of shape only on test_expand

* Cleaned up some unnecessary modifications.
parent d430fbb5
......@@ -1080,6 +1080,52 @@ class Or(Elemwise):
def _impl_v7(cls, inputs, attr, params):
    """Map ONNX Or (opset 7) to Relay's element-wise logical_or."""
    lhs, rhs = inputs[0], inputs[1]
    return _op.logical_or(lhs, rhs)
class Expand(OnnxOpConverter):
    """Operator converter for ONNX Expand.

    Lowers Expand to `relay.op.broadcast_to` after reconciling the requested
    target shape with the input shape, because the two operators use
    different broadcasting rules (see comments below).
    """

    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        # Static shape of the data input; Expand needs it to resolve
        # multi-directional broadcasting below.
        in_shape = np.array(infer_shape(inputs[0])).astype('int32')
        if get_name(inputs[1]) in params:
            # Target shape is a graph initializer: read it directly.
            shape = params[inputs[1].name_hint].asnumpy().astype('int32')
        else:
            # Target shape is computed: evaluate it via simulation.
            shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32')
        # Currently 'op.broadcast_to' expects the rank of the given 'shape'
        # (the 2nd input) to always be higher than that of the given 'input'
        # (the 1st input). However, ONNX Expand supports multi-directional
        # broadcasting, which allows the above pattern and also allows some
        # extent of 'shape' to be smaller than the corresponding extent of
        # 'input'; in that case the extent of 'shape' must be 1.
        # https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
        # In those cases we cannot directly apply 'op.broadcast_to', so we
        # solve the problem by expanding the given 'shape' itself.
        def expand_shape(in_shape, shape):
            """Expand the target shape when its rank is lower than that of
            the given input, and replace any extent of 1 in the target shape
            with the corresponding extent of the input.
            """
            # Flip both shapes so the innermost dimension sits at index 0;
            # right-aligned broadcasting is simpler to express that way.
            in_shape = np.flip(in_shape, axis=0)
            shape = np.flip(shape, axis=0)

            if in_shape.size < shape.size:
                for i in range(shape.size):
                    if i < in_shape.size and in_shape[i] > shape[i]:
                        shape[i] = in_shape[i]
            else:
                for i in range(in_shape.size):
                    if i >= shape.size:
                        # BUG FIX: np.append returns a new array instead of
                        # mutating in place; the original code discarded the
                        # result, silently dropping dimensions whenever the
                        # input rank exceeded the target rank.
                        shape = np.append(shape, in_shape[i])
                    elif shape[i] == 1:
                        shape[i] = in_shape[i]

            new_shape = np.flip(shape, axis=0)
            return new_shape

        shape = expand_shape(in_shape, shape)
        return _op.broadcast_to(inputs[0], shape=tuple(shape))
# Compatible operators that do NOT require any conversion: ONNX ops listed
# here map 1:1 onto Relay ops of the same name. (Currently empty.)
_identity_list = []
......@@ -1187,6 +1233,7 @@ def _get_convert_map(opset):
# defs/tensor
'Cast': Cast.get_converter(opset),
'Reshape': Reshape.get_converter(opset),
'Expand': Expand.get_converter(opset),
'Concat': Concat.get_converter(opset),
'Split': Split.get_converter(opset),
'Slice': Slice.get_converter(opset),
......
......@@ -142,6 +142,46 @@ def test_reshape():
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def test_expand():
    """Check the ONNX Expand converter against numpy-computed references."""

    def _verify_expand(name, data, shape, ref_data):
        # Feed the target shape in as a Constant node, as ONNX Expand expects
        # its second input to come from the graph.
        target = np.array(shape)
        shape_node = onnx.helper.make_node(
            'Constant',
            inputs=[],
            outputs=['shape'],
            value=onnx.helper.make_tensor(name='const_tensor',
                                          data_type=onnx.TensorProto.INT32,
                                          dims=target.shape,
                                          vals=target.flatten().astype('int32')))
        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])

        in_info = helper.make_tensor_value_info("in", TensorProto.FLOAT,
                                                list(data.shape))
        out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT,
                                                 list(ref_data.shape))
        graph = helper.make_graph([shape_node, expand_node], "expand_test",
                                  inputs=[in_info], outputs=[out_info])
        model = helper.make_model(graph, producer_name=name)

        for target_str, ctx in ctx_list():
            tvm_out = get_tvm_output(model, data, target_str, ctx,
                                     ref_data.shape, 'float32')
            tvm.testing.assert_allclose(ref_data, tvm_out)

    # Case 1: target rank equals input rank; only the unit dimension grows.
    in_shape = (3, 1)
    shape = (3, 4)
    data = np.random.uniform(size=in_shape).astype(np.float32)
    ref_data = np.tile(data, 4)
    _verify_expand('expand_with_dim_unchanged_test', data, shape, ref_data)

    # Case 2: target rank exceeds input rank (multi-directional broadcasting).
    in_shape = (3, 1)
    shape = (2, 1, 6)
    data = np.random.uniform(size=in_shape).astype(np.float32)
    ref_data = data * np.ones(shape, dtype=np.float32)
    _verify_expand('expand_with_dim_changed_test', data, shape, ref_data)
def verify_depth_to_space(inshape, outshape, mode, blockSize):
node = onnx.helper.make_node('DepthToSpace',
inputs=['x'],
......@@ -1710,6 +1750,7 @@ if __name__ == '__main__':
test_flatten()
test_reshape()
test_shape()
test_expand()
test_power()
test_squeeze()
test_unsqueeze()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment