Commit 83bef9ff by Jon Soifer Committed by Tianqi Chen

[Relay][Frontend][TensorFlow] Support BatchMatMul with input dimensions larger than 3 (#3732)

* Support BatchMatMul with shapes greater than length 3

* Fixes

* Add tests

* Remove dependency on Python3

* Clean up

* Merge with master

* Resolve comments
parent 4c01e8ee
......@@ -448,11 +448,31 @@ def _matmul():
def _batch_matmul():
def _impl(inputs, attr, params):
input_x = inputs[0]
input_y = inputs[1]
orig_shape_x = attr['_input_shapes'][input_x]
orig_shape_y = attr['_input_shapes'][input_y]
# reshape n-dimensional batch matmul into 3d
if len(orig_shape_x) > 3:
outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
num_outer_elts = np.prod(outer_dims)
new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
input_x = _op.reshape(input_x, newshape=new_shape_x)
input_y = _op.reshape(input_y, newshape=new_shape_y)
adj_x = attr['adj_x']
adj_y = attr['adj_y']
input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0]
input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1]
input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
ret = get_relay_op('batch_matmul')(input_x, input_y)
# reshape result back to n-dimensional
if len(orig_shape_x) > 3:
final_shape = attr['_output_shapes'][0]
ret = _op.reshape(ret, newshape=final_shape)
return ret
return _impl
......
......@@ -685,6 +685,10 @@ def test_forward_batch_matmul():
_test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
_test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False)
_test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32')
_test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), 'float32', True, True)
_test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False)
_test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True)
#######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment