Commit 62d34ca5 by MORITA Kazutaka Committed by Tianqi Chen

[NNVM][KERAS] Support multiple outputs (#1648)

parent e3365445
......@@ -532,15 +532,15 @@ def from_keras(model):
# they are named uniquely to input_1, input_2, input_3 ... by default.
for pred_idx, pred in zip(node.node_indices, node.inbound_layers):
if isinstance(pred, keras.engine.InputLayer):
_sym = symtab.get_var(pred.name, must_contain=True)
sym = symtab.get_var(pred.name, must_contain=True)
else:
_sym = symtab.get_var(pred.name + ':' + str(pred_idx), must_contain=True)
insym.append(_sym)
sym = symtab.get_var(pred.name + ':' + str(pred_idx), must_contain=True)
insym.append(sym)
if len(insym) == 1:
insym = insym[0]
keras_op_to_nnvm(insym, keras_layer, keras_layer.name + ':' + str(my_idx), symtab)
outsym = symtab.get_var(model._output_layers[0].name + ':0')
outsym = [symtab.get_var(layer.name + ':0') for layer in model._output_layers]
tvmparams = {k:tvm.nd.array(np.array(v, dtype=np.float32)) for k, v in symtab.params.items()}
return outsym, tvmparams
return _sym.Group(outsym), tvmparams
......@@ -20,7 +20,9 @@ def verify_keras_frontend(keras_model):
in_shapes = []
for layer in keras_model._input_layers:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
out_shape = [dim.value if dim.value is not None else 1 for dim in keras_model._output_layers[0].output.shape]
out_shapes = []
for layer in keras_model._output_layers:
out_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.output.shape))
def get_keras_output(xs, dtype='float32'):
return keras_model.predict(xs)
......@@ -35,8 +37,10 @@ def verify_keras_frontend(keras_model):
m.set_input(name, tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
out = [m.get_output(i, tvm.nd.empty(shape, dtype)).asnumpy()
for i, shape in enumerate(out_shapes)]
return out if len(out) > 1 else out[0]
xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
keras_out = get_keras_output(xs)
......@@ -192,6 +196,16 @@ def test_forward_multi_inputs():
verify_keras_frontend(keras_model)
def test_forward_multi_outputs():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
x = keras.layers.GlobalAveragePooling2D()(x)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, [x, y])
verify_keras_frontend(keras_model)
def test_forward_reuse_layers():
# reuse conv2d
data = keras.layers.Input(shape=(32,32,3))
......@@ -230,4 +244,5 @@ if __name__ == '__main__':
test_forward_mobilenet()
test_forward_multi_inputs()
test_forward_multi_outputs()
test_forward_reuse_layers()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment