Commit 19b8b3a4 by Yong Wu Committed by MORITA Kazutaka

[Relay][Keras] Dot (#3668)

* [Relay][Keras] Dot

* fix reshape

* fix comments
parent 07a83a66
...@@ -156,7 +156,26 @@ def _convert_advanced_activation(inexpr, keras_layer, etab): ...@@ -156,7 +156,26 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
def _convert_merge(inexpr, keras_layer, _): def _convert_merge(inexpr, keras_layer, _):
merge_type = type(keras_layer).__name__ merge_type = type(keras_layer).__name__
ret = inexpr[0] ret = inexpr[0]
if merge_type == 'Subtract': if merge_type == 'Dot':
axes = keras_layer.axes
if isinstance(keras_layer.axes, int):
axes = [keras_layer.axes, keras_layer.axes]
if isinstance(axes, list):
if len(axes) != 2:
raise tvm.error.OpAttributeUnimplemented(
'Dot with axes {} is not supported.'.format(keras_layer.axes))
for i, axis in enumerate(axes):
if axis not in [1, 2]:
raise tvm.error.OpAttributeUnimplemented(
'Dot with axes {} is not supported.'.format(keras_layer.axes))
if axes[i] == 2:
inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
else:
raise tvm.error.OpAttributeUnImplemented(
'Dot with axes {} is not supported.'.format(keras_layer.axes))
ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
ret = _op.transpose(ret_dot, axes=[0, 2, 1])
elif merge_type == 'Subtract':
assert len(inexpr) == 2, "Subtract merge takes 2 inputs." assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
ret = _op.subtract(ret, inexpr[1]) ret = _op.subtract(ret, inexpr[1])
elif merge_type in ['Add', 'Multiply', 'Maximum']: elif merge_type in ['Add', 'Multiply', 'Maximum']:
...@@ -635,7 +654,7 @@ _convert_map = { ...@@ -635,7 +654,7 @@ _convert_map = {
'Average' : _convert_merge, 'Average' : _convert_merge,
'Maximum' : _convert_merge, 'Maximum' : _convert_merge,
# 'Dot' : _convert_merge, 'Dot' : _convert_merge,
'Permute' : _convert_permute, 'Permute' : _convert_permute,
# 'Embedding' : _convert_embedding, # 'Embedding' : _convert_embedding,
# 'RepeatVector' : _convert_repeat_vector, # 'RepeatVector' : _convert_repeat_vector,
......
...@@ -84,13 +84,26 @@ def test_forward_merge(): ...@@ -84,13 +84,26 @@ def test_forward_merge():
keras.layers.Average(), keras.layers.Average(),
keras.layers.Concatenate()] keras.layers.Concatenate()]
for merge_func in merge_funcs: for merge_func in merge_funcs:
if isinstance(merge_func, keras.layers.merge.Subtract): if isinstance(merge_func, (keras.layers.merge.Subtract, keras.layers.merge.Dot)):
out = merge_func([x, y]) out = merge_func([x, y])
else: else:
out = merge_func([x, y, z]) out = merge_func([x, y, z])
keras_model = keras.models.Model(data, out) keras_model = keras.models.Model(data, out)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_merge_dot():
data1 = keras.layers.Input(shape=(2, 2))
data2 = keras.layers.Input(shape=(2, 2))
merge_funcs = [keras.layers.Dot(axes=[1, 2]),
keras.layers.Dot(axes=[2, 1]),
keras.layers.Dot(axes=[1, 1]),
keras.layers.Dot(axes=[2, 2]),
keras.layers.Dot(axes=1),
keras.layers.Dot(axes=2)]
for merge_func in merge_funcs:
out = merge_func([data1, data2])
keras_model = keras.models.Model([data1, data2], out)
verify_keras_frontend(keras_model)
def test_forward_activations(): def test_forward_activations():
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
...@@ -281,6 +294,7 @@ def test_forward_mobilenet(): ...@@ -281,6 +294,7 @@ def test_forward_mobilenet():
if __name__ == '__main__': if __name__ == '__main__':
test_forward_merge() test_forward_merge()
test_forward_merge_dot()
test_forward_activations() test_forward_activations()
test_forward_dense() test_forward_dense()
test_forward_permute() test_forward_permute()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment