Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c147a31d
Commit
c147a31d
authored
Oct 23, 2019
by
Bjarke Hammersholt Roune
Committed by
Zhi
Oct 23, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend. (#4172)
parent
5408d3a3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
161 additions
and
2 deletions
+161
-2
python/tvm/relay/frontend/tensorflow.py
+25
-2
tests/python/frontend/tensorflow/test_debugging.py
+93
-0
tests/python/frontend/tensorflow/test_no_op.py
+43
-0
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
c147a31d
...
...
@@ -436,6 +436,24 @@ def _check_numerics():
return
AttrCvt
(
op_name
=
"copy"
,
ignores
=
[
'message'
])(
inputs
,
attr
)
return
_impl
def
_assert
():
# ToDo: In general people want asserts to be gone from TensorFlow graphs
# when they are optimizing them, so converting it to a no-op is
# reasonable. However, it would be nice to have the option to keep them
# once Relay gets a Halt or Assert op.
return
_no_op
()
def
_no_op
():
def
_impl
(
inputs
,
attr
,
params
):
# ToDo: This should really be an op that returns nothing, which could
# be represented as an empty tuple. It turns out that TVM
# infrastructure doesn't like running functions that return None and
# also don't like running functions that return an empty tuple. So it
# doesn't work, but it should be made to work and then this could be
# improved. In the mean time, it is hard to imagine a case where it
# matters in any real way that a no-op is converted to a constant 0.
return
tvm
.
relay
.
const
(
0
)
return
_impl
def
_matmul
():
def
_impl
(
inputs
,
attr
,
params
):
...
...
@@ -1326,6 +1344,7 @@ _convert_map = {
'All'
:
_reduce
(
'all'
),
'ArgMax'
:
_argx
(
_op
.
argmax
,
'argmax'
),
'ArgMin'
:
_argx
(
_op
.
argmin
,
'argmin'
),
'Assert'
:
_assert
(),
'AvgPool'
:
_pooling
(
'avg_pool'
),
'BatchMatMul'
:
_batch_matmul
(),
'BatchMatMulV2'
:
_batch_matmul
(),
...
...
@@ -1384,6 +1403,7 @@ _convert_map = {
'Mod'
:
_elemwise
(
'mod'
),
'Mul'
:
_elemwise
(
'multiply'
),
'Neg'
:
AttrCvt
(
'negative'
),
'NoOp'
:
_no_op
(),
'NotEqual'
:
_broadcast
(
'not_equal'
),
'OneHot'
:
_one_hot
(),
'Pack'
:
_pack
(),
...
...
@@ -2196,8 +2216,11 @@ class GraphProto(object):
if
np_array
.
dtype
==
np
.
dtype
(
object
):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
self
.
_nodes
[
name
]
=
[
_expr
.
var
(
name
,
shape
=
shape
[
name
],
dtype
=
'uint8'
)]
if
shape
:
var_shape
=
shape
[
name
]
else
:
var_shape
=
tensor_util
.
TensorShapeProtoToList
(
value
.
tensor
.
tensor_shape
)
self
.
_nodes
[
name
]
=
[
_expr
.
var
(
name
,
shape
=
var_shape
,
dtype
=
'uint8'
)]
return
array_ndim
=
len
(
np_array
.
shape
)
...
...
tests/python/frontend/tensorflow/test_debugging.py
0 → 100644
View file @
c147a31d
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import
tensorflow
as
tf
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay.frontend.tensorflow
import
from_tensorflow
def
run_relay
(
graph
,
*
vars
):
mod
,
params
=
from_tensorflow
(
graph
.
as_graph_def
(
add_shapes
=
True
))
ex
=
relay
.
create_executor
(
'debug'
,
mod
=
mod
)
return
ex
.
evaluate
()(
*
vars
)
def
test_assert_true
():
g
=
tf
.
Graph
()
with
g
.
as_default
():
x
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
())
assert_op
=
tf
.
Assert
(
tf
.
less_equal
(
x
,
x
),
[
"it failed"
])
with
tf
.
Session
()
as
sess
:
x_value
=
np
.
random
.
rand
()
assert
sess
.
run
(
assert_op
,
feed_dict
=
{
x
:
x_value
})
is
None
# In TVM, tf.assert is converted to a no-op which is actually a 0,
# though it should probably be none or an empty tuple.
#
# ToDo: It appears that the frontend converter gets confused here and
# entirely eliminates all operands from main(). Likely because x <= x
# is always true, so the placeholder can be eliminated. But TF doesn't
# do that, it's happening in Relay, and that optimization shouldn't
# affect the arity of the main function. We should have to pass in
# x_value here.
np
.
testing
.
assert_allclose
(
0
,
run_relay
(
g
)
.
asnumpy
())
def
test_assert_true_var_capture
():
g
=
tf
.
Graph
()
with
g
.
as_default
():
x
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
())
# It turns out that tf.assert() creates a large and complex subgraph if
# you capture a variable as part of the error message. So we need to
# test that, too.
assert_op
=
tf
.
Assert
(
tf
.
less_equal
(
x
,
x
),
[
"it failed"
,
x
])
with
tf
.
Session
()
as
sess
:
x_value
=
np
.
random
.
rand
()
assert
sess
.
run
(
assert_op
,
feed_dict
=
{
x
:
x_value
})
is
None
# ToDo: The frontend converter gets confused here as well, thinking
# that it needs to be told what x is twice. It also notes the output of
# the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the
# arity of the translated function should be 1, not 2.
np
.
testing
.
assert_allclose
(
True
,
run_relay
(
g
,
x_value
,
x_value
)
.
asnumpy
())
def
test_assert_false
():
g
=
tf
.
Graph
()
with
g
.
as_default
():
assert_op
=
tf
.
Assert
(
tf
.
constant
(
False
),
[
"it failed"
])
with
tf
.
Session
()
as
sess
:
try
:
print
(
sess
.
run
(
assert_op
))
assert
False
# TF should have thrown an exception
except
tf
.
errors
.
InvalidArgumentError
as
e
:
assert
"it failed"
in
e
.
message
# In TVM, tf.assert is converted to a no-op which is actually a 0,
# though it should probably be none or an empty tuple. For the same
# reason, there should not be an error here, even though the assertion
# argument is false.
np
.
testing
.
assert_allclose
(
0
,
run_relay
(
g
)
.
asnumpy
())
if
__name__
==
"__main__"
:
test_assert_true
()
test_assert_true_var_capture
()
test_assert_false
()
tests/python/frontend/tensorflow/test_no_op.py
0 → 100644
View file @
c147a31d
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import
tensorflow
as
tf
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay.frontend.tensorflow
import
from_tensorflow
def
run_relay
(
graph
):
mod
,
params
=
from_tensorflow
(
graph
.
as_graph_def
(
add_shapes
=
True
))
ex
=
relay
.
create_executor
(
'debug'
,
mod
=
mod
)
return
ex
.
evaluate
()(
**
params
)
def
test_no_op
():
g
=
tf
.
Graph
()
with
g
.
as_default
():
no_op
=
tf
.
no_op
()
with
tf
.
Session
()
as
sess
:
# In TF, the type of a no-op is None.
assert
sess
.
run
(
no_op
)
is
None
# In TVM, no-op is currently translated to 0, though it should
# probably be none or an empty tuple.
np
.
testing
.
assert_allclose
(
0
,
run_relay
(
g
)
.
asnumpy
())
if
__name__
==
"__main__"
:
test_no_op
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment