Commit e6ca91e1 authored by lixiaoquan, committed by Jared Roesch

[Relay][Tensorflow] Allow an op as loop var. (#3056)

parent f88f4580
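For context, the case this change enables looks like the following (a minimal sketch mirroring the test added below, not part of the diff itself): the value fed to tf.while_loop is produced by an op, so after conversion the loop variable is a Relay CallNode rather than a Var, which the frontend previously rejected with an assertion.

    import tensorflow as tf

    # The initial loop value is the result of tf.add, not a constant or a
    # placeholder, so the converted Relay expression has no type_annotation.
    i = tf.add(tf.constant(0), 1)
    c = lambda i: tf.less(i, 10)
    b = lambda i: tf.add(i, 1)
    r = tf.while_loop(c, b, [i])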
@@ -30,6 +30,7 @@ from topi.util import get_const_tuple
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from ..expr_functor import ExprMutator
__all__ = ['from_tensorflow']
@@ -1414,6 +1415,27 @@ class RecurrentNetworks(object):
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

class RewriteSubgraph(ExprMutator):
    """
    A helper class to rewrite expressions in a while-loop function to variables.

    Parameters
    ----------
    rewrite_map : Dict[expr, expr]
        A dictionary mapping each expression to the variable that replaces it.
    """
    def __init__(self, rewrite_map):
        ExprMutator.__init__(self)
        self.rewrite_map = rewrite_map

    def visit(self, expr):
        if expr in self.rewrite_map:
            return self.rewrite_map[expr]
        return super().visit(expr)


def rewrite_subgraph(expr, rewrites):
    return RewriteSubgraph(rewrites).visit(expr)
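A minimal standalone sketch of how this helper behaves (assuming rewrite_subgraph and the imports above are in scope; the shapes and names below are invented for illustration): every sub-expression found in the rewrite map is swapped for the variable it maps to, and everything else is rebuilt unchanged by ExprMutator.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,), dtype="float32")
    call = relay.add(x, relay.const(1.0))          # a CallNode, not a Var
    loop_var = relay.var("loop_var0")              # fresh variable standing in for it
    body = relay.multiply(call, relay.const(2.0))  # expression containing the call

    # Every occurrence of `call` inside `body` is replaced by `loop_var`.
    rewritten = rewrite_subgraph(body, {call: loop_var})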
def _in_while_loop(control_flow_node_map, op_name):
    """
    Check if a given control flow operator is part of a while loop execution
@@ -1594,14 +1616,17 @@ class Loop:
        loop_vars = []
        bind_map = {}
        for i, var in enumerate(self.loop_vars):
-            assert isinstance(var, _expr.Var), repr(var)
-            v = tvm.relay.var("loop_var" + str(i),
-                              type_annotation=var.type_annotation)
+            if not isinstance(var, _expr.Var):
+                var_type = ir_pass.infer_type(var).checked_type
+            else:
+                var_type = var.type_annotation
+
+            v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
            loop_vars.append(v)
            bind_map[var] = v

-        self.cond = tvm.relay.bind(self.cond, bind_map)
-        self.body = [tvm.relay.bind(b, bind_map) for b in self.body]
+        self.cond = rewrite_subgraph(self.cond, bind_map)
+        self.body = [rewrite_subgraph(b, bind_map) for b in self.body]

        cond = tvm.relay.op.min(self.cond)
...
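Why the loop-variable handling above now falls back to type inference: a loop value produced by an op (a CallNode) has no type_annotation to read, so its type is recovered with the ir_pass interface this frontend already imports. A rough standalone illustration (shapes and names are invented):

    import tvm
    from tvm import relay
    from tvm.relay import ir_pass   # older pass interface, as used in this file

    x = relay.var("x", shape=(2, 2), dtype="float32")

    # A plain Var carries its type directly: a TensorType of shape (2, 2), float32.
    var_type = x.type_annotation

    # An op result does not, so it is inferred instead, as the frontend now does.
    op_result = relay.add(x, relay.const(1.0))
    op_type = ir_pass.infer_type(op_result).checked_type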
@@ -51,6 +51,23 @@ def test_vanilla_loop():
    check_equal(graph, tf_out)
def test_callnode_loop_vars():
    graph = tf.Graph()
    with graph.as_default():
        i = tf.add(tf.constant(0), 1)

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
def test_loop_2_vars():
    graph = tf.Graph()
    with graph.as_default():
@@ -288,6 +305,7 @@ if __name__ == "__main__":
    test_loop_3_vars()
    test_loop_conditions()
    test_loop_bodies()
    test_callnode_loop_vars()

    # tf.cond
    test_vanilla_cond()
...