Commit 02c6767a by LiangHao Committed by Yao Wang

[Relay][Frontend][TF] fix _parse_param bug (#4711)

parent 4eecd2a7
......@@ -2391,7 +2391,7 @@ class GraphProto(object):
if np_array.dtype == np.dtype(object):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
if shape:
if shape and name in shape:
var_shape = shape[name]
else:
var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
......
......@@ -20,19 +20,22 @@ import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
def run_relay(graph, *vars):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
def run_relay(graph, shape_dict=None, *vars):
mod, params = from_tensorflow(
graph.as_graph_def(add_shapes=True),
shape=shape_dict)
ex = relay.create_executor('debug', mod=mod)
return ex.evaluate()(*vars)
def test_assert_true():
g = tf.Graph()
shape = (1, 2)
with g.as_default():
x = tf.placeholder(tf.float32, shape=())
assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])
x = tf.placeholder(tf.float32, shape=shape, name="input")
assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"])
with tf.Session() as sess:
x_value = np.random.rand()
x_value = np.random.rand(*shape)
assert sess.run(assert_op, feed_dict={x: x_value}) is None
# In TVM, tf.assert is converted to a no-op which is actually a 0,
......@@ -44,7 +47,7 @@ def test_assert_true():
# do that, it's happening in Relay, and that optimization shouldn't
# affect the arity of the main function. We should have to pass in
# x_value here.
np.testing.assert_allclose(0, run_relay(g).asnumpy())
np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy())
def test_assert_true_var_capture():
g = tf.Graph()
......@@ -65,7 +68,8 @@ def test_assert_true_var_capture():
# the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the
# arity of the translated function should be 1, not 2.
np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy())
np.testing.assert_allclose(True,
run_relay(g, None, x_value, x_value).asnumpy())
def test_assert_false():
g = tf.Graph()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment