Commit 5bcd3313 by Jon Soifer Committed by Tianqi Chen

[Relay][Frontend][ONNX] Add support for broadcasting to Where and MatMul (#4267)

parent 14a5a358
...@@ -298,6 +298,12 @@ class MatMul(OnnxOpConverter): ...@@ -298,6 +298,12 @@ class MatMul(OnnxOpConverter):
# Convert a and b into 3 dimensional tensors. # Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]]) a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]]) b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
# Broadcast b to match batch size of a
new_b_shape = list(infer_shape(b))
new_a_shape = infer_shape(a)
if new_a_shape[0] > new_b_shape[0]:
new_b_shape[0] = new_a_shape[0]
b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b. # Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1]) b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul. # Perform a batch matmul.
...@@ -987,6 +993,14 @@ class Where(OnnxOpConverter): ...@@ -987,6 +993,14 @@ class Where(OnnxOpConverter):
""" """
@classmethod @classmethod
def _impl_v9(cls, inputs, attr, params): def _impl_v9(cls, inputs, attr, params):
# x and y can be broadcasted
condition_shape = infer_shape(inputs[0])
x_shape = infer_shape(inputs[1])
y_shape = infer_shape(inputs[2])
if len(condition_shape) > len(x_shape):
inputs[1] = _op.broadcast_to(inputs[1], condition_shape)
if len(condition_shape) > len(y_shape):
inputs[2] = _op.broadcast_to(inputs[2], condition_shape)
return _op.where(inputs[0], inputs[1], inputs[2]) return _op.where(inputs[0], inputs[1], inputs[2])
class Or(Elemwise): class Or(Elemwise):
...@@ -996,6 +1010,7 @@ class Or(Elemwise): ...@@ -996,6 +1010,7 @@ class Or(Elemwise):
def _impl_v7(cls, inputs, attr, params): def _impl_v7(cls, inputs, attr, params):
return _op.logical_or(inputs[0], inputs[1]) return _op.logical_or(inputs[0], inputs[1])
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
......
...@@ -498,11 +498,7 @@ def test_matmul(): ...@@ -498,11 +498,7 @@ def test_matmul():
model, [a_array, b_array], target, ctx, out_np.shape) model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def verify_batch_matmul(a_shape, b_shape):
def test_batch_matmul():
a_shape = (2, 3, 4, 3)
b_shape = (2, 3, 3, 4)
a_array = np.random.uniform(size=a_shape).astype('float32') a_array = np.random.uniform(size=a_shape).astype('float32')
b_array = np.random.uniform(size=b_shape).astype('float32') b_array = np.random.uniform(size=b_shape).astype('float32')
out_np = np.matmul(a_array, b_array) out_np = np.matmul(a_array, b_array)
...@@ -525,6 +521,10 @@ def test_batch_matmul(): ...@@ -525,6 +521,10 @@ def test_batch_matmul():
model, [a_array, b_array], target, ctx, out_np.shape) model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_batch_matmul():
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
verify_batch_matmul((2, 4, 3), (3, 4))
verify_batch_matmul((2, 3, 4, 3), (3, 4))
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype) in_array = np.random.uniform(size=shape).astype(dtype)
...@@ -1600,6 +1600,11 @@ def test_where(): ...@@ -1600,6 +1600,11 @@ def test_where():
outdata = np.where(condition, x, y) outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata) verify_where(condition, x, y, TensorProto.FLOAT, outdata)
x = np.array(1, dtype=np.float32)
y = np.array([2], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
def verify_or(indata, dtype): def verify_or(indata, dtype):
x = indata[0].astype(dtype) x = indata[0].astype(dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment