Commit 786c49f3 by Yong Wu Committed by Yao Wang

[Relay][TF] add BatchMatMul (#3634)

parent 18d0ad31
......@@ -51,18 +51,16 @@ def _infer_value(input_val, params):
return m.get_output(0)
def _get_relay_op(op_name):
    """Look up a Relay operator by name across the known operator namespaces.

    Parameters
    ----------
    op_name : str
        Operator name as registered in Relay, e.g. 'batch_matmul'.

    Returns
    -------
    op : callable
        The Relay operator function.

    Raises
    ------
    tvm.error.OpNotImplemented
        If *op_name* is not found in any of the searched namespaces.
    """
    # Search the top-level namespace first, then the nested ones, instead of
    # hard-coding a chain of nested try/except fallbacks.
    ops = [_op, _op.nn, _op.image, _op.vision]
    for operator in ops:
        try:
            return getattr(operator, op_name)
        except AttributeError:
            continue
    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported for frontend TensorFlow.'.format(op_name))
class AttrCvt(object):
"""Common attribute converter. An AttrConverter instance is a callable:
......@@ -611,6 +609,16 @@ def _matmul():
return _impl
def _batch_matmul():
def _impl(inputs, attr, params):
adj_x = attr['adj_x']
adj_y = attr['adj_y']
input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0]
input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1]
ret = _get_relay_op('batch_matmul')(input_x, input_y)
return ret
return _impl
def _undef():
def _impl(inputs, attr, params):
return _sym.__undef__()
......@@ -1309,6 +1317,8 @@ _convert_map = {
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
'BatchMatMul' : _batch_matmul(),
'BatchMatMulV2' : _batch_matmul(),
'BatchNormWithGlobalNormalization' : _batch_norm(),
'BatchToSpaceND' : _batch_to_space_nd(),
'BiasAdd' : _bias_add(),
......
......@@ -622,8 +622,8 @@ def test_forward_variable():
#######################################################################
# MatMul
# ------
# MatMul, BatchMatMul, BatchMatMulV2
# ----------------------------------
def _test_matmul(i, j, k, dtype, outer=None):
""" One iteration of matmul """
......@@ -647,10 +647,28 @@ def _test_matmul(i, j, k, dtype, outer=None):
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
def test_forward_matmul():
    """ MatMul op test"""
    _test_matmul(1, 3, 6, 'int32')
    _test_matmul(5, 3, 1, 'float64')
    # TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
    """One iteration of batch matmul: build a TF graph for
    tf.matmul(adjoint_a, adjoint_b) and compare TF output with TVM output."""
    # Random operands in [0, 5) cast to the requested dtype.
    lhs_data = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
    rhs_data = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
    with tf.Graph().as_default():
        lhs = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
        rhs = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
        out = tf.matmul(lhs, rhs, adjoint_a=adjoint_a,
                        adjoint_b=adjoint_b, name='batchmatmul')
        compare_tf_with_tvm([lhs_data, rhs_data], [lhs.name, rhs.name], out.name)
def test_forward_batch_matmul():
    """ TF op BatchMatMul, BatchMatMulV2 test"""
    # (A_shape, B_shape, dtype, adjoint_a, adjoint_b)
    cases = [
        ((3, 5, 4), (3, 4, 5), 'int32', False, False),
        ((3, 5, 4), (3, 4, 5), 'float32', True, True),
        ((3, 5, 4), (3, 5, 4), 'int32', True, False),
        ((3, 5, 4), (3, 5, 4), 'float32', False, True),
    ]
    for a_shape, b_shape, dtype, adj_a, adj_b in cases:
        _test_batch_matmul(a_shape, b_shape, dtype, adj_a, adj_b)
#######################################################################
......@@ -2197,6 +2215,7 @@ if __name__ == '__main__':
test_forward_rel_ops()
test_forward_logical()
test_forward_where()
test_forward_matmul()
# TODO missing tests: rank, range
test_forward_batch_matmul()
# TODO missing tests: rank
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment