Commit 62d34ca5 by MORITA Kazutaka Committed by Tianqi Chen

[NNVM][KERAS] Support multiple outputs (#1648)

parent e3365445
...@@ -532,15 +532,15 @@ def from_keras(model): ...@@ -532,15 +532,15 @@ def from_keras(model):
# they are named uniquely to input_1, input_2, input_3 ... by default. # they are named uniquely to input_1, input_2, input_3 ... by default.
for pred_idx, pred in zip(node.node_indices, node.inbound_layers): for pred_idx, pred in zip(node.node_indices, node.inbound_layers):
if isinstance(pred, keras.engine.InputLayer): if isinstance(pred, keras.engine.InputLayer):
_sym = symtab.get_var(pred.name, must_contain=True) sym = symtab.get_var(pred.name, must_contain=True)
else: else:
_sym = symtab.get_var(pred.name + ':' + str(pred_idx), must_contain=True) sym = symtab.get_var(pred.name + ':' + str(pred_idx), must_contain=True)
insym.append(_sym) insym.append(sym)
if len(insym) == 1: if len(insym) == 1:
insym = insym[0] insym = insym[0]
keras_op_to_nnvm(insym, keras_layer, keras_layer.name + ':' + str(my_idx), symtab) keras_op_to_nnvm(insym, keras_layer, keras_layer.name + ':' + str(my_idx), symtab)
outsym = symtab.get_var(model._output_layers[0].name + ':0') outsym = [symtab.get_var(layer.name + ':0') for layer in model._output_layers]
tvmparams = {k:tvm.nd.array(np.array(v, dtype=np.float32)) for k, v in symtab.params.items()} tvmparams = {k:tvm.nd.array(np.array(v, dtype=np.float32)) for k, v in symtab.params.items()}
return outsym, tvmparams return _sym.Group(outsym), tvmparams
...@@ -20,7 +20,9 @@ def verify_keras_frontend(keras_model): ...@@ -20,7 +20,9 @@ def verify_keras_frontend(keras_model):
in_shapes = [] in_shapes = []
for layer in keras_model._input_layers: for layer in keras_model._input_layers:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape)) in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
out_shape = [dim.value if dim.value is not None else 1 for dim in keras_model._output_layers[0].output.shape] out_shapes = []
for layer in keras_model._output_layers:
out_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.output.shape))
def get_keras_output(xs, dtype='float32'): def get_keras_output(xs, dtype='float32'):
return keras_model.predict(xs) return keras_model.predict(xs)
...@@ -35,8 +37,10 @@ def verify_keras_frontend(keras_model): ...@@ -35,8 +37,10 @@ def verify_keras_frontend(keras_model):
m.set_input(name, tvm.nd.array(x.astype(dtype))) m.set_input(name, tvm.nd.array(x.astype(dtype)))
m.set_input(**params) m.set_input(**params)
m.run() m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy() out = [m.get_output(i, tvm.nd.empty(shape, dtype)).asnumpy()
for i, shape in enumerate(out_shapes)]
return out if len(out) > 1 else out[0]
xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
keras_out = get_keras_output(xs) keras_out = get_keras_output(xs)
...@@ -192,6 +196,16 @@ def test_forward_multi_inputs(): ...@@ -192,6 +196,16 @@ def test_forward_multi_inputs():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_multi_outputs():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
x = keras.layers.GlobalAveragePooling2D()(x)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, [x, y])
verify_keras_frontend(keras_model)
def test_forward_reuse_layers(): def test_forward_reuse_layers():
# reuse conv2d # reuse conv2d
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32,32,3))
...@@ -230,4 +244,5 @@ if __name__ == '__main__': ...@@ -230,4 +244,5 @@ if __name__ == '__main__':
test_forward_mobilenet() test_forward_mobilenet()
test_forward_multi_inputs() test_forward_multi_inputs()
test_forward_multi_outputs()
test_forward_reuse_layers() test_forward_reuse_layers()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment