Commit de919cbd by Jon Soifer Committed by Jared Roesch

[Relay][Frontend][ONNX] Broadcast condition, x, and y for Where op (#4774)

* ONNX frontend broadcast condition

* fix

* fix style

Co-authored-by: Jon Soifer <jonso@microsoft.com>
parent f71a10c5
...@@ -1105,14 +1105,33 @@ class Where(OnnxOpConverter): ...@@ -1105,14 +1105,33 @@ class Where(OnnxOpConverter):
""" """
@classmethod @classmethod
def _impl_v9(cls, inputs, attr, params): def _impl_v9(cls, inputs, attr, params):
# x and y can be broadcasted
condition_shape = infer_shape(inputs[0]) condition_shape = infer_shape(inputs[0])
x_shape = infer_shape(inputs[1]) x_shape = infer_shape(inputs[1])
y_shape = infer_shape(inputs[2]) y_shape = infer_shape(inputs[2])
if len(condition_shape) > len(x_shape):
inputs[1] = _op.broadcast_to(inputs[1], condition_shape) # condition, x, and y can all be broadcasted.
if len(condition_shape) > len(y_shape): # broadcast each of them to the longest shape.
inputs[2] = _op.broadcast_to(inputs[2], condition_shape) # if two shapes have the same number of dimensions,
# try to choose the one that doesn't have "1" as
# a dimension.
shapes = [condition_shape, x_shape, y_shape]
shape_lens = [len(shape) for shape in shapes]
max_size = max(shape_lens)
max_size_idxs = [i for i, x in enumerate(shape_lens) if x == max_size]
broadcast_idx = max_size_idxs[0]
if len(max_size_idxs) > 1:
for idx in max_size_idxs:
if 1 not in shapes[idx]:
broadcast_idx = idx
broadcast_shape = shapes[broadcast_idx]
if condition_shape != broadcast_shape:
inputs[0] = _op.broadcast_to(inputs[0], broadcast_shape)
if x_shape != broadcast_shape:
inputs[1] = _op.broadcast_to(inputs[1], broadcast_shape)
if y_shape != broadcast_shape:
inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape)
return _op.where(inputs[0], inputs[1], inputs[2]) return _op.where(inputs[0], inputs[1], inputs[2])
class Or(Elemwise): class Or(Elemwise):
......
...@@ -1684,6 +1684,22 @@ def test_where(): ...@@ -1684,6 +1684,22 @@ def test_where():
outdata = np.where(condition, x, y) outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata) verify_where(condition, x, y, TensorProto.FLOAT, outdata)
x = np.array([2], dtype=np.float32)
y = np.array(1, dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
condition = np.array(1, dtype=np.bool)
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[5, 6], [7, 8]], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[1], [7]], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
def verify_or(indata, dtype): def verify_or(indata, dtype):
x = indata[0].astype(dtype) x = indata[0].astype(dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment