Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
02c6767a
Commit
02c6767a
authored
Jan 15, 2020
by
LiangHao
Committed by
Yao Wang
Jan 15, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Frontend][TF] fix _parse_param bug (#4711)
parent
4eecd2a7
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
8 deletions
+12
-8
python/tvm/relay/frontend/tensorflow.py
+1
-1
tests/python/frontend/tensorflow/test_debugging.py
+11
-7
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
02c6767a
...
...
@@ -2391,7 +2391,7 @@ class GraphProto(object):
if
np_array
.
dtype
==
np
.
dtype
(
object
):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
if
shape
:
if
shape
and
name
in
shape
:
var_shape
=
shape
[
name
]
else
:
var_shape
=
tensor_util
.
TensorShapeProtoToList
(
value
.
tensor
.
tensor_shape
)
...
...
tests/python/frontend/tensorflow/test_debugging.py
View file @
02c6767a
...
...
@@ -20,19 +20,22 @@ import numpy as np
from
tvm
import
relay
from
tvm.relay.frontend.tensorflow
import
from_tensorflow
def
run_relay
(
graph
,
*
vars
):
mod
,
params
=
from_tensorflow
(
graph
.
as_graph_def
(
add_shapes
=
True
))
def
run_relay
(
graph
,
shape_dict
=
None
,
*
vars
):
mod
,
params
=
from_tensorflow
(
graph
.
as_graph_def
(
add_shapes
=
True
),
shape
=
shape_dict
)
ex
=
relay
.
create_executor
(
'debug'
,
mod
=
mod
)
return
ex
.
evaluate
()(
*
vars
)
def
test_assert_true
():
g
=
tf
.
Graph
()
shape
=
(
1
,
2
)
with
g
.
as_default
():
x
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
()
)
assert_op
=
tf
.
Assert
(
tf
.
less_equal
(
x
,
x
),
[
"it failed"
])
x
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
shape
,
name
=
"input"
)
assert_op
=
tf
.
Assert
(
tf
.
reduce_all
(
tf
.
less_equal
(
x
,
x
)
),
[
"it failed"
])
with
tf
.
Session
()
as
sess
:
x_value
=
np
.
random
.
rand
()
x_value
=
np
.
random
.
rand
(
*
shape
)
assert
sess
.
run
(
assert_op
,
feed_dict
=
{
x
:
x_value
})
is
None
# In TVM, tf.assert is converted to a no-op which is actually a 0,
...
...
@@ -44,7 +47,7 @@ def test_assert_true():
# do that, it's happening in Relay, and that optimization shouldn't
# affect the arity of the main function. We should have to pass in
# x_value here.
np
.
testing
.
assert_allclose
(
0
,
run_relay
(
g
)
.
asnumpy
())
np
.
testing
.
assert_allclose
(
0
,
run_relay
(
g
,
{
'input'
:
shape
}
)
.
asnumpy
())
def
test_assert_true_var_capture
():
g
=
tf
.
Graph
()
...
...
@@ -65,7 +68,8 @@ def test_assert_true_var_capture():
# the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the
# arity of the translated function should be 1, not 2.
np
.
testing
.
assert_allclose
(
True
,
run_relay
(
g
,
x_value
,
x_value
)
.
asnumpy
())
np
.
testing
.
assert_allclose
(
True
,
run_relay
(
g
,
None
,
x_value
,
x_value
)
.
asnumpy
())
def
test_assert_false
():
g
=
tf
.
Graph
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment