Unverified commit 7bc0b27e by Yao Wang, committed by GitHub

[Frontend][TensorFlow]TensorFlow Parser Control Flow Enhancement (#5020)

* Improve TF control flow major logic

* Pass mod into operator convert function

* Fix LoopBound

* Add more control flow tests

* Add two test cases for stridedslice

* Fix docstring

* Fix lint

* Fix import

* Fix test assert

* Minor fix conv3d

* Add more comments

* Fix for dilation2d

* Change newly added atan

* Change newly added unravel
parent a422589c
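
For context on the control-flow work above: the TensorFlow frontend lowers a
tf.while_loop into a recursive Relay function whose parameters are the loop
variables. Below is a minimal sketch of that target form, written with the
public tvm.relay.loops.while_loop helper rather than the parser's internal
builder (the variable names are illustrative, not the parser's actual code):

from tvm import relay
from tvm.relay.loops import while_loop

# One scalar loop variable, mirroring `y` in the tests added below.
i = relay.var("i", shape=(), dtype="int32")

def _cond(i):
    # Counterpart of `lambda x, y: tf.less(y, 20)`.
    return relay.less(i, relay.const(20, "int32"))

def _body(i):
    # Counterpart of the TF loop body; returns the updated loop variables.
    return [relay.add(i, relay.const(1, "int32"))]

# while_loop builds a let-bound recursive function; calling it with the
# initial values yields the final loop state as a tuple.
loop = while_loop(_cond, [i], _body)
result = relay.TupleGetItem(loop(relay.const(0, "int32")), 0)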
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=broad-except
 """Common utilities"""
 from __future__ import absolute_import as _abs
 import logging
@@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False):
     return channels


-def infer_value(input_val, params):
+def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a """A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that portion of the relay graph. This is often needed for functions that
whose output shape depends on the value of a tensor. whose output shape depends on the value of a tensor.
""" """
-    # pylint: disable=import-outside-toplevel
-    from tvm.contrib import graph_runtime
-    # Check that all free variables have associated parameters.
-    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
-        input_val)), "All inputs to infer must be available in params."
-    func = _function.Function(analysis.free_vars(input_val), input_val)
-    with tvm.relay.build_config(opt_level=0):
-        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-        ctx = tvm.cpu(0)
-        m = graph_runtime.create(graph, lib, ctx)
-        m.set_input(**params)
-        m.run()
-        return m.get_output(0)
+    try:
+        # TODO(kevinthesun): Use VM for all cases.
+        # pylint: disable=import-outside-toplevel
+        from tvm.contrib import graph_runtime
+        # Check that all free variables have associated parameters.
+        assert all(var.name_hint in params.keys() for var in analysis.free_vars(
+            input_val)), "All inputs to infer must be available in params."
+        func = _function.Function(analysis.free_vars(input_val), input_val)
+        with tvm.relay.build_config(opt_level=0):
+            graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.cpu(0)
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            return m.get_output(0)
+    except Exception:
+        if isinstance(mod, IRModule):
+            mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
+        else:
+            mod = IRModule.from_expr(input_val)
+        exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
+        inputs = []
+        for param in mod['main'].params:
+            inputs.append(tvm.nd.array(params[param.name_hint]))
+        result = exc.evaluate()(*inputs)
+        return result


 def infer_value_simulated(input_val, params):
...
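
The effect of the new mod argument: when the graph-runtime path raises (for
example, because the expression pulls in control flow that the opt_level=0
build cannot handle), infer_value now retries on the debug interpreter over
an IRModule. A minimal usage sketch under that assumption; the expression and
values here are made up for illustration:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value

x = relay.var("x", shape=(2,), dtype="float32")
expr = relay.add(x, relay.const(1.0))
params = {"x": tvm.nd.array(np.ones((2,), "float32"))}

# Passing a module enables the interpreter fallback; the happy path still
# compiles with the graph runtime and returns a tvm.nd.NDArray either way.
val = infer_value(expr, params, mod=tvm.IRModule())
print(val.asnumpy())  # [2. 2.]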
@@ -27,14 +27,16 @@ from tvm import relay
 from tvm.relay.frontend.tensorflow import from_tensorflow


-def check_equal(graph, tf_out):
+def check_equal(graph, tf_out, input_map=None):
     mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
+    if input_map is not None:
+        params.update(input_map)
     ex = relay.create_executor('vm', mod=mod)
     relay_out = ex.evaluate()(**params)
     if isinstance(relay_out, nd.NDArray):
         np.testing.assert_allclose(tf_out, relay_out.asnumpy())
     else:
-        if not isinstance(tf_out, list):
+        if not isinstance(tf_out, (list, tuple)):
             tf_out = [tf_out]
         for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
             np.testing.assert_allclose(x, y)
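
The new input_map argument lets a test feed placeholder values that
from_tensorflow leaves out of params; the loop-bound tests below use it
exactly like this (repeated here as a usage note, with the shapes they use):

np_data = np.random.uniform(size=(2, 10)).astype("float32")
# "data" is the placeholder name; its value is merged into params before
# the VM executor runs the converted module.
check_equal(graph, tf_out, input_map={"data": np_data})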
@@ -303,9 +305,70 @@ def test_cond_in_loop():
     check_equal(graph, tf_out)


+def test_vanilla_loop_bound():
+    graph = tf.Graph()
+    with graph.as_default():
+        dshape = (2, 10)
+        dtype = "float32"
+        dname = "data"
+        np_data = np.random.uniform(size=dshape).astype(dtype)
+        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
+        x = tf.slice(data, [1, 4], [1, 4])
+        outer = x + 5.0
+        def body(x, y):
+            res = tf.cond(tf.less(y, 10), lambda: tf.add(
+                10.0, 20.0), lambda: tf.square(10.0))
+            z = tf.constant(7)
+            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
+            return tf.multiply(res, x * outer), y + 1
+        y = tf.constant(0)
+        def condition(x, y):
+            return tf.less(y, 20)
+        r = tf.while_loop(condition, body, loop_vars=[x, y])
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
+
+    check_equal(graph, tf_out, {dname: np_data})
+
+
+def test_nested_loop_bound():
+    graph = tf.Graph()
+    with graph.as_default():
+        dshape = (2, 10)
+        dtype = "float32"
+        dname = "data"
+        np_data = np.random.uniform(size=dshape).astype(dtype)
+        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
+        x = tf.slice(data, [1, 4], [1, 4])
+        outer = x + 5.0
+        def body(x, y):
+            res = tf.cond(tf.less(y, 10), lambda: tf.add(
+                10.0, 20.0), lambda: tf.square(10.0))
+            def nested_body(nx, ny):
+                return nx + 1, res + 2.0
+            def nested_cond(nx, ny):
+                return tf.less(nx, 15)
+            nx = tf.constant(0)
+            ny = tf.constant(0.0)
+            nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny])
+            res = res + nested_res[1]
+            z = tf.constant(7)
+            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
+            return tf.multiply(res, x * outer), y + 1
+        y = tf.constant(0)
+        def condition(x, y):
+            return tf.less(y, 20)
+        r = tf.while_loop(condition, body, loop_vars=[x, y])
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
+
+    check_equal(graph, tf_out, {dname: np_data})
+
+
 if __name__ == "__main__":
     # tf.while_loop
     test_vanilla_loop()
     test_loop_2_vars()
@@ -325,3 +388,5 @@ if __name__ == "__main__":
     test_nested_cond()
     test_loop_in_cond()
     test_cond_in_loop()
+    test_vanilla_loop_bound()
+    test_nested_loop_bound()
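
What the two new tests pin down: a tensor created outside the loop
(outer = x + 5.0) is captured inside body, so the parser's loop-boundary
analysis (the LoopBound fix named in the commit list above) must keep that
computation outside the generated Relay loop instead of pulling it into the
loop body. Stripped to its core, the pattern being exercised looks like this
(illustrative, not the test code itself):

import tensorflow as tf

data = tf.placeholder(tf.float32, shape=(2, 10), name="data")
outer = tf.slice(data, [1, 4], [1, 4]) + 5.0   # defined outside the loop

def body(x, y):
    return x * outer, y + 1                    # `outer` is captured here

r = tf.while_loop(lambda x, y: tf.less(y, 20), body,
                  loop_vars=[outer, tf.constant(0)])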
@@ -67,13 +67,11 @@ def test_assert_true_var_capture():
         x_value = np.random.rand()
         assert sess.run(assert_op, feed_dict={x: x_value}) is None

-    # ToDo: The frontend converter gets confused here as well, thinking
-    # that it needs to be told what x is twice. It also notes the output of
-    # the graph as a boolean, which is not correct - as you can see above,
-    # TF believes that the value of this graph is None. In addition, the
-    # arity of the translated function should be 1, not 2.
+    # TODO: The frontend converter notes the output of
+    # the graph as a boolean, which is not correct - as you can see above,
+    # TF believes that the value of this graph is None.
     np.testing.assert_allclose(True,
-                               run_relay(g, None, x_value, x_value).asnumpy())
+                               run_relay(g, None, x_value).asnumpy())


 def test_assert_false():
     g = tf.Graph()
...
@@ -1207,6 +1207,8 @@ def test_forward_stridedslice():
     '''test StridedSlice'''
     _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
+    _test_stridedslice((2, 1), [0], [1], [1], 'float32', shrink_axis_mask=1)
+    _test_stridedslice((2, 3, 4), [0], [1], [1], 'float32', shrink_axis_mask=8)
     _test_stridedslice((3, 4, 3), [1, -1, 0],
                        [4, -5, 3], [2, -1, 1], 'float32')
     _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [
...
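
For reference on the two added cases: bit i of shrink_axis_mask turns the
i-th slice spec into integer-style indexing, dropping that axis from the
output; the second case additionally sets a mask bit with no matching slice
spec on a 3-D input. A rough NumPy analogue of the first case, assuming
_test_stridedslice forwards its arguments to tf.strided_slice:

import numpy as np

a = np.arange(2, dtype="float32").reshape(2, 1)
# tf.strided_slice(a, begin=[0], end=[1], strides=[1], shrink_axis_mask=1)
# behaves like integer indexing on axis 0: the sliced axis is removed.
out = a[0]
print(out.shape)  # (1,)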