Commit 9998f330 by thefiddler Committed by Tianqi Chen

Fix keras frontend elementwise-ops for lists with len>2 (fixes issue #325) (#326)

* Added elementwise-add test

* Fix typo

* Fixed elem-wise ops for lists with len>2
parent 14de1dec
......@@ -77,18 +77,21 @@ def _convert_advanced_activation(insym, keras_layer, _):
def _convert_merge(insym, keras_layer, _):
merge_type = type(keras_layer).__name__
ret = insym[0]
for i in range(1, len(insym)):
if merge_type == 'Add':
return _sym.elemwise_add(insym[0], insym[1])
ret = _sym.elemwise_add(ret, insym[i])
elif merge_type == 'Subtract':
return _sym.elemwise_sub(insym[0], insym[1])
ret = _sym.elemwise_sub(ret, insym[i])
elif merge_type == 'Multiply':
return _sym.elemwise_mul(insym[0], insym[1])
ret = _sym.elemwise_mul(ret, insym[i])
elif merge_type == 'Average':
raise NotImplementedError('Average merge not implemented')
elif merge_type == 'Maximum':
raise NotImplementedError('Maximum merge not implemented')
else:
raise TypeError("Unsupported merge type : {}".format(merge_type))
return ret
def _convert_dense(insym, keras_layer, symtab):
......
......@@ -39,6 +39,27 @@ def verify_keras_frontend(keras_model):
np.testing.assert_allclose(keras_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_elemwise_add():
print("test_forward_elemwise_add")
r = []
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
r.append(x)
x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
r.append(x)
x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
y = keras.layers.add([keras.layers.add([x, r[0]]), r[1]])
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
y = keras.layers.add([x, r[0], r[1]])
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
def test_forward_softrelu():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Activation('softplus')(data)
......@@ -114,6 +135,7 @@ def test_forward_resnet50():
if __name__ == '__main__':
test_forward_elemwise_add()
test_forward_softrelu()
test_forward_leaky_relu()
test_forward_dense()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment