Commit 9998f330 by thefiddler Committed by Tianqi Chen

Fix keras frontend elementwise-ops for lists with len>2 (fixes issue #325) (#326)

* Added elementwise-add test

* Fix typo

* Fixed elem-wise ops for lists with len>2
parent 14de1dec
...@@ -77,18 +77,21 @@ def _convert_advanced_activation(insym, keras_layer, _): ...@@ -77,18 +77,21 @@ def _convert_advanced_activation(insym, keras_layer, _):
def _convert_merge(insym, keras_layer, _): def _convert_merge(insym, keras_layer, _):
merge_type = type(keras_layer).__name__ merge_type = type(keras_layer).__name__
if merge_type == 'Add': ret = insym[0]
return _sym.elemwise_add(insym[0], insym[1]) for i in range(1, len(insym)):
elif merge_type == 'Subtract': if merge_type == 'Add':
return _sym.elemwise_sub(insym[0], insym[1]) ret = _sym.elemwise_add(ret, insym[i])
elif merge_type == 'Multiply': elif merge_type == 'Subtract':
return _sym.elemwise_mul(insym[0], insym[1]) ret = _sym.elemwise_sub(ret, insym[i])
elif merge_type == 'Average': elif merge_type == 'Multiply':
raise NotImplementedError('Average merge not implemented') ret = _sym.elemwise_mul(ret, insym[i])
elif merge_type == 'Maximum': elif merge_type == 'Average':
raise NotImplementedError('Maximum merge not implemented') raise NotImplementedError('Average merge not implemented')
else: elif merge_type == 'Maximum':
raise TypeError("Unsupported merge type : {}".format(merge_type)) raise NotImplementedError('Maximum merge not implemented')
else:
raise TypeError("Unsupported merge type : {}".format(merge_type))
return ret
def _convert_dense(insym, keras_layer, symtab): def _convert_dense(insym, keras_layer, symtab):
......
...@@ -39,6 +39,27 @@ def verify_keras_frontend(keras_model): ...@@ -39,6 +39,27 @@ def verify_keras_frontend(keras_model):
np.testing.assert_allclose(keras_out, tvm_out, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(keras_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_elemwise_add():
print("test_forward_elemwise_add")
r = []
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
r.append(x)
x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
r.append(x)
x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
y = keras.layers.add([keras.layers.add([x, r[0]]), r[1]])
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
y = keras.layers.add([x, r[0], r[1]])
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
def test_forward_softrelu(): def test_forward_softrelu():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Activation('softplus')(data) x = keras.layers.Activation('softplus')(data)
...@@ -114,6 +135,7 @@ def test_forward_resnet50(): ...@@ -114,6 +135,7 @@ def test_forward_resnet50():
if __name__ == '__main__': if __name__ == '__main__':
test_forward_elemwise_add()
test_forward_softrelu() test_forward_softrelu()
test_forward_leaky_relu() test_forward_leaky_relu()
test_forward_dense() test_forward_dense()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment