Commit 02c6767a authored by LiangHao, committed by Yao Wang

[Relay][Frontend][TF] fix _parse_param bug (#4711)

parent 4eecd2a7
@@ -2391,7 +2391,7 @@ class GraphProto(object):
             if np_array.dtype == np.dtype(object):
                 # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
                 # Just leave it as placeholder.
-                if shape:
+                if shape and name in shape:
                     var_shape = shape[name]
                 else:
                     var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
......
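A minimal standalone sketch of the guard this hunk changes (the helper name is hypothetical; shape is the user-supplied dict passed to from_tensorflow, and value.tensor is the TensorProto of the DT_STRING constant). Previously, a non-empty shape dict that did not contain this node's name still took the shape[name] branch and raised a KeyError; the new check falls back to the shape recorded in the proto.

# Sketch only: the surrounding _parse_param logic is assumed.
def pick_var_shape(name, shape, value, tensor_util):
    if shape and name in shape:   # new guard: dict given AND it names this node
        return shape[name]
    # Fallback: use the shape stored in the TensorProto. The old 'if shape:'
    # guard reached shape[name] even when 'name' was absent, raising KeyError.
    return tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)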
@@ -20,19 +20,22 @@ import numpy as np
 from tvm import relay
 from tvm.relay.frontend.tensorflow import from_tensorflow

-def run_relay(graph, *vars):
-    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
+def run_relay(graph, shape_dict=None, *vars):
+    mod, params = from_tensorflow(
+        graph.as_graph_def(add_shapes=True),
+        shape=shape_dict)
     ex = relay.create_executor('debug', mod=mod)
     return ex.evaluate()(*vars)

 def test_assert_true():
     g = tf.Graph()
+    shape = (1, 2)
     with g.as_default():
-        x = tf.placeholder(tf.float32, shape=())
-        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])
+        x = tf.placeholder(tf.float32, shape=shape, name="input")
+        assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"])

         with tf.Session() as sess:
-            x_value = np.random.rand()
+            x_value = np.random.rand(*shape)
             assert sess.run(assert_op, feed_dict={x: x_value}) is None

     # In TVM, tf.assert is converted to a no-op which is actually a 0,
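The hunk above also changes run_relay's signature: shape_dict is a new defaulted parameter placed before *vars, so callers that feed tensor values must now pass something for it positionally, even if only None. Both call forms, as used in the hunks below (g, shape, and x_value are defined in the tests):

run_relay(g, {'input': shape})            # shape dict only, no feed values
run_relay(g, None, x_value, x_value)      # feed values need an explicit None
                                          # (or a dict) for shape_dict first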
@@ -44,7 +47,7 @@ def test_assert_true():
     # do that, it's happening in Relay, and that optimization shouldn't
     # affect the arity of the main function. We should have to pass in
     # x_value here.
-    np.testing.assert_allclose(0, run_relay(g).asnumpy())
+    np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy())

 def test_assert_true_var_capture():
     g = tf.Graph()
@@ -65,7 +68,8 @@ def test_assert_true_var_capture():
     # the graph as a boolean, which is not correct - as you can see above,
     # TF believes that the value of this graph is None. In addition, the
     # arity of the translated function should be 1, not 2.
-    np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy())
+    np.testing.assert_allclose(True,
+                               run_relay(g, None, x_value, x_value).asnumpy())

 def test_assert_false():
     g = tf.Graph()
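One more detail from the test_assert_true changes above: with the placeholder now shaped (1, 2) instead of a scalar, tf.less_equal(x, x) produces an element-wise boolean tensor, so the test reduces it to a single scalar with tf.reduce_all before handing it to tf.Assert. A minimal self-contained sketch of that pattern (TF 1.x API, as in the test):

# Sketch: reduce an element-wise comparison to a scalar bool for tf.Assert.
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(1, 2), name="input")
cond = tf.reduce_all(tf.less_equal(x, x))        # scalar tf.bool
assert_op = tf.Assert(cond, ["it failed"])
with tf.Session() as sess:
    sess.run(assert_op, feed_dict={x: np.random.rand(1, 2)})  # passes: returns None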
......