Commit e6ca91e1 by lixiaoquan Committed by Jared Roesch

[Relay][Tensorflow] Allow an op as loop var. (#3056)

parent f88f4580
......@@ -30,6 +30,7 @@ from topi.util import get_const_tuple
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from ..expr_functor import ExprMutator
__all__ = ['from_tensorflow']
......@@ -1414,6 +1415,27 @@ class RecurrentNetworks(object):
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
class RewriteSubgraph(ExprMutator):
"""
A helper class to rewrite expr in while loop function to variable
Parameters
----------
rewrite_map : Dict[expr, expr]
A dictionay contains a set of expr to var mapping.
"""
def __init__(self, rewrite_map):
ExprMutator.__init__(self)
self.rewrite_map = rewrite_map
def visit(self, expr):
if expr in self.rewrite_map:
return self.rewrite_map[expr]
return super().visit(expr)
def rewrite_subgraph(expr, rewrites):
return RewriteSubgraph(rewrites).visit(expr)
def _in_while_loop(control_flow_node_map, op_name):
"""
Check if a given control flow operator is part of a while loop execution
......@@ -1594,14 +1616,17 @@ class Loop:
loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = tvm.relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
if not isinstance(var, _expr.Var):
var_type = ir_pass.infer_type(var).checked_type
else:
var_type = var.type_annotation
v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
loop_vars.append(v)
bind_map[var] = v
self.cond = tvm.relay.bind(self.cond, bind_map)
self.body = [tvm.relay.bind(b, bind_map) for b in self.body]
self.cond = rewrite_subgraph(self.cond, bind_map)
self.body = [rewrite_subgraph(b, bind_map) for b in self.body]
cond = tvm.relay.op.min(self.cond)
......
......@@ -51,6 +51,23 @@ def test_vanilla_loop():
check_equal(graph, tf_out)
def test_callnode_loop_vars():
graph = tf.Graph()
with graph.as_default():
i = tf.add(tf.constant(0), 1)
def c(i): return tf.less(i, 10)
def b(i): return tf.add(i, 1)
r = tf.while_loop(c, b, [i])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_loop_2_vars():
graph = tf.Graph()
with graph.as_default():
......@@ -288,6 +305,7 @@ if __name__ == "__main__":
test_loop_3_vars()
test_loop_conditions()
test_loop_bodies()
test_callnode_loop_vars()
# tf.cond
test_vanilla_cond()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment