Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
62d34ca5
Commit
62d34ca5
authored
Aug 24, 2018
by
MORITA Kazutaka
Committed by
Tianqi Chen
Aug 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM][KERAS] Support multiple outputs (#1648)
parent
e3365445
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
8 deletions
+23
-8
nnvm/python/nnvm/frontend/keras.py
+5
-5
nnvm/tests/python/frontend/keras/test_forward.py
+18
-3
No files found.
nnvm/python/nnvm/frontend/keras.py
View file @
62d34ca5
...
...
@@ -532,15 +532,15 @@ def from_keras(model):
# they are named uniquely to input_1, input_2, input_3 ... by default.
for
pred_idx
,
pred
in
zip
(
node
.
node_indices
,
node
.
inbound_layers
):
if
isinstance
(
pred
,
keras
.
engine
.
InputLayer
):
_
sym
=
symtab
.
get_var
(
pred
.
name
,
must_contain
=
True
)
sym
=
symtab
.
get_var
(
pred
.
name
,
must_contain
=
True
)
else
:
_
sym
=
symtab
.
get_var
(
pred
.
name
+
':'
+
str
(
pred_idx
),
must_contain
=
True
)
insym
.
append
(
_
sym
)
sym
=
symtab
.
get_var
(
pred
.
name
+
':'
+
str
(
pred_idx
),
must_contain
=
True
)
insym
.
append
(
sym
)
if
len
(
insym
)
==
1
:
insym
=
insym
[
0
]
keras_op_to_nnvm
(
insym
,
keras_layer
,
keras_layer
.
name
+
':'
+
str
(
my_idx
),
symtab
)
outsym
=
symtab
.
get_var
(
model
.
_output_layers
[
0
]
.
name
+
':0'
)
outsym
=
[
symtab
.
get_var
(
layer
.
name
+
':0'
)
for
layer
in
model
.
_output_layers
]
tvmparams
=
{
k
:
tvm
.
nd
.
array
(
np
.
array
(
v
,
dtype
=
np
.
float32
))
for
k
,
v
in
symtab
.
params
.
items
()}
return
outsym
,
tvmparams
return
_sym
.
Group
(
outsym
)
,
tvmparams
nnvm/tests/python/frontend/keras/test_forward.py
View file @
62d34ca5
...
...
@@ -20,7 +20,9 @@ def verify_keras_frontend(keras_model):
in_shapes
=
[]
for
layer
in
keras_model
.
_input_layers
:
in_shapes
.
append
(
tuple
(
dim
.
value
if
dim
.
value
is
not
None
else
1
for
dim
in
layer
.
input
.
shape
))
out_shape
=
[
dim
.
value
if
dim
.
value
is
not
None
else
1
for
dim
in
keras_model
.
_output_layers
[
0
]
.
output
.
shape
]
out_shapes
=
[]
for
layer
in
keras_model
.
_output_layers
:
out_shapes
.
append
(
tuple
(
dim
.
value
if
dim
.
value
is
not
None
else
1
for
dim
in
layer
.
output
.
shape
))
def
get_keras_output
(
xs
,
dtype
=
'float32'
):
return
keras_model
.
predict
(
xs
)
...
...
@@ -35,8 +37,10 @@ def verify_keras_frontend(keras_model):
m
.
set_input
(
name
,
tvm
.
nd
.
array
(
x
.
astype
(
dtype
)))
m
.
set_input
(
**
params
)
m
.
run
()
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
out_shape
,
dtype
))
return
out
.
asnumpy
()
out
=
[
m
.
get_output
(
i
,
tvm
.
nd
.
empty
(
shape
,
dtype
))
.
asnumpy
()
for
i
,
shape
in
enumerate
(
out_shapes
)]
return
out
if
len
(
out
)
>
1
else
out
[
0
]
xs
=
[
np
.
random
.
uniform
(
size
=
shape
,
low
=-
1.0
,
high
=
1.0
)
for
shape
in
in_shapes
]
keras_out
=
get_keras_output
(
xs
)
...
...
@@ -192,6 +196,16 @@ def test_forward_multi_inputs():
verify_keras_frontend
(
keras_model
)
def
test_forward_multi_outputs
():
data
=
keras
.
layers
.
Input
(
shape
=
(
32
,
32
,
3
))
x
=
keras
.
layers
.
Conv2D
(
8
,
(
3
,
3
),
padding
=
"same"
)(
data
)
x
=
keras
.
layers
.
GlobalAveragePooling2D
()(
x
)
y
=
keras
.
layers
.
Conv2D
(
8
,
(
3
,
3
),
padding
=
"same"
)(
data
)
y
=
keras
.
layers
.
GlobalAveragePooling2D
()(
y
)
keras_model
=
keras
.
models
.
Model
(
data
,
[
x
,
y
])
verify_keras_frontend
(
keras_model
)
def
test_forward_reuse_layers
():
# reuse conv2d
data
=
keras
.
layers
.
Input
(
shape
=
(
32
,
32
,
3
))
...
...
@@ -230,4 +244,5 @@ if __name__ == '__main__':
test_forward_mobilenet
()
test_forward_multi_inputs
()
test_forward_multi_outputs
()
test_forward_reuse_layers
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment