Unverified Commit 7bc0b27e by Yao Wang Committed by GitHub

[Frontend][TensorFlow]TensorFlow Parser Control Flow Enhancement (#5020)

* Improve TF control flow major logic

* Pass mod into operator convert function

* Fix LoopBound

* Add more control flow tests

* Add two test cases for stridedslice

* Fix docstring

* Fix lint

* Fix import

* Fix test assert

* Minor fix conv3d

* Add more comments

* Fix for dilation2d

* Change newly added atan

* Change newly added unravel
parent a422589c
......@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=broad-except
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
......@@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False):
return channels
def infer_value(input_val, params):
def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
whose output shape depends on the value of a tensor.
"""
# pylint: disable=import-outside-toplevel
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
try:
# TODO(kevinthesun): Use VM for all cases.
# pylint: disable=import-outside-toplevel
from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = exc.evaluate()(*inputs)
return result
def infer_value_simulated(input_val, params):
......
......@@ -27,14 +27,16 @@ from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
def check_equal(graph, tf_out):
def check_equal(graph, tf_out, input_map=None):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
if input_map is not None:
params.update(input_map)
ex = relay.create_executor('vm', mod=mod)
relay_out = ex.evaluate()(**params)
if isinstance(relay_out, nd.NDArray):
np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else:
if not isinstance(tf_out, list):
if not isinstance(tf_out, (list, tuple)):
tf_out = [tf_out]
for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
np.testing.assert_allclose(x, y)
......@@ -303,9 +305,70 @@ def test_cond_in_loop():
check_equal(graph, tf_out)
def test_vanilla_loop_bound():
graph = tf.Graph()
with graph.as_default():
dshape = (2, 10)
dtype = "float32"
dname = "data"
np_data = np.random.uniform(size=dshape).astype(dtype)
data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
x = tf.slice(data, [1, 4], [1, 4])
outer = x + 5.0
def body(x, y):
res = tf.cond(tf.less(y, 10), lambda: tf.add(
10.0, 20.0), lambda: tf.square(10.0))
z = tf.constant(7)
res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
return tf.multiply(res, x * outer), y + 1
y = tf.constant(0)
def condition(x, y):
return tf.less(y, 20)
r = tf.while_loop(condition, body, loop_vars=[x, y])
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
if __name__ == "__main__":
check_equal(graph, tf_out, {dname: np_data})
def test_nested_loop_bound():
graph = tf.Graph()
with graph.as_default():
dshape = (2, 10)
dtype = "float32"
dname = "data"
np_data = np.random.uniform(size=dshape).astype(dtype)
data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
x = tf.slice(data, [1, 4], [1, 4])
outer = x + 5.0
def body(x, y):
res = tf.cond(tf.less(y, 10), lambda: tf.add(
10.0, 20.0), lambda: tf.square(10.0))
def nested_body(nx, ny):
return nx + 1, res + 2.0
def nested_cond(nx, ny):
return tf.less(nx, 15)
nx = tf.constant(0)
ny = tf.constant(0.0)
nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny])
res = res + nested_res[1]
z = tf.constant(7)
res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
return tf.multiply(res, x * outer), y + 1
y = tf.constant(0)
def condition(x, y):
return tf.less(y, 20)
r = tf.while_loop(condition, body, loop_vars=[x, y])
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
check_equal(graph, tf_out, {dname: np_data})
if __name__ == "__main__":
# tf.while_loop
test_vanilla_loop()
test_loop_2_vars()
......@@ -325,3 +388,5 @@ if __name__ == "__main__":
test_nested_cond()
test_loop_in_cond()
test_cond_in_loop()
test_vanilla_loop_bound()
test_nested_loop_bound()
......@@ -67,13 +67,11 @@ def test_assert_true_var_capture():
x_value = np.random.rand()
assert sess.run(assert_op, feed_dict={x: x_value}) is None
# ToDo: The frontend converter gets confused here as well, thinking
# that it needs to be told what x is twice. It also notes the output of
# TODO: The frontend converter notes the output of
# the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the
# arity of the translated function should be 1, not 2.
# TF believes that the value of this graph is None.
np.testing.assert_allclose(True,
run_relay(g, None, x_value, x_value).asnumpy())
run_relay(g, None, x_value).asnumpy())
def test_assert_false():
g = tf.Graph()
......
......@@ -1207,6 +1207,8 @@ def test_forward_stridedslice():
'''test StridedSlice'''
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((2, 1), [0], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((2, 3, 4), [0], [1], [1], 'float32', shrink_axis_mask=8)
_test_stridedslice((3, 4, 3), [1, -1, 0],
[4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment