Unverified Commit 1014fefa by Samuel Committed by GitHub

[KERAS]Embedding layer (#5444)

parent 3f47b327
...@@ -207,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _): ...@@ -207,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _):
return _op.transpose(inexpr, axes=(0,) + keras_layer.dims) return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
def _convert_embedding(inexpr, keras_layer, etab):
indices = inexpr
weightList = keras_layer.get_weights()
weight = etab.new_const(weightList[0])
out = _op.take(weight, indices.astype('int32'), axis=0)
return out
def _convert_dense(inexpr, keras_layer, etab): def _convert_dense(inexpr, keras_layer, etab):
weightList = keras_layer.get_weights() weightList = keras_layer.get_weights()
weight = etab.new_const(weightList[0].transpose([1, 0])) weight = etab.new_const(weightList[0].transpose([1, 0]))
...@@ -893,7 +901,7 @@ _convert_map = { ...@@ -893,7 +901,7 @@ _convert_map = {
'Maximum' : _convert_merge, 'Maximum' : _convert_merge,
'Dot' : _convert_merge, 'Dot' : _convert_merge,
'Permute' : _convert_permute, 'Permute' : _convert_permute,
# 'Embedding' : _convert_embedding, 'Embedding' : _convert_embedding,
# 'RepeatVector' : _convert_repeat_vector, # 'RepeatVector' : _convert_repeat_vector,
'InputLayer' : _default_skip, 'InputLayer' : _default_skip,
......
...@@ -466,6 +466,24 @@ class TestKeras: ...@@ -466,6 +466,24 @@ class TestKeras:
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC') verify_keras_frontend(keras_model, layout='NDHWC')
def test_forward_embedding(self, keras):
data = keras.layers.Input(shape=(2, 4), dtype="int32")
x = keras.layers.Embedding(10, 3)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
data = keras.layers.Input(shape=(2, 3, 4), dtype="int32")
x = keras.layers.Embedding(4, 5)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
data = keras.layers.Input(shape=(6, 2, 3, 4), dtype="int32")
x = keras.layers.Embedding(4, 5)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
if __name__ == '__main__': if __name__ == '__main__':
for k in [keras, tf_keras]: for k in [keras, tf_keras]:
sut = TestKeras() sut = TestKeras()
...@@ -497,4 +515,4 @@ if __name__ == '__main__': ...@@ -497,4 +515,4 @@ if __name__ == '__main__':
sut.test_forward_pool3d(keras=k) sut.test_forward_pool3d(keras=k)
sut.test_forward_upsample3d(keras=k) sut.test_forward_upsample3d(keras=k)
sut.test_forward_zero_padding3d(keras=k) sut.test_forward_zero_padding3d(keras=k)
sut.test_forward_embedding(keras=k)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment