Commit 2df3364b by Zhi, committed by Yizhi Liu

[RELAY][Frontend][TF] decompile tf control flow (#2830)

* decompile tf control flow

* Add docs

* remove import relay

* move tests under tensorflow frontend

* minor fix
parent 8ef35dc2
@@ -5,6 +5,7 @@ from __future__ import print_function
import logging
import warnings
from collections import defaultdict
# Numpy support
import numpy as np
@@ -1270,6 +1271,220 @@ class RecurrentNetworks(object):
params, num_layers)
return sym
# An internal list containing all the control flow primitives used in
# TensorFlow 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
def _in_while_loop(control_flow_node_map, op_name):
"""
Check whether a given control flow operator is part of a while loop execution
frame. This relies on the fact that there is exactly one `LoopCond` operator
per loop execution frame and that it only appears within the loop
construct.
Parameters
----------
control_flow_node_map : Dict[str, Set[str]]
A dictionary mapping each unique control flow execution frame name to
the set of primitive operators it contains.
op_name : str
The name of a control flow primitive.
Returns
-------
ret : bool
Return true if the operator is in a while loop execution frame,
otherwise, return false.
"""
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]
class Branch:
"""A class contains the components that are used to build up a Relay if
node.
Parameters
----------
cond : tvm.relay.Expr
The condition of an if node.
true_branch : tvm.relay.Expr
The body of the true branch of an if expression.
false_branch: tvm.relay.Expr
The body of the false branch of an if expression.
_if : tvm.relay.Expr
An internal variable indicating whether an if expression has already been
created for a matched TF conditional construct.
Examples
--------
The following is a cond statement written in TensorFlow:
.. code-block:: python
def vanilla_cond():
i = tf.constant(1)
j = tf.constant(4)
def f1():
return tf.multiply(1, 17)
def f2():
return tf.add(4, 23)
r = tf.cond(tf.less(i, j), f1, f2)
This conditional statement should be converted into Relay in the following
form:
.. code-block:: python
fn (%Const: Tensor[(1,), int32],
%Const_1: Tensor[(1,), int32],
%cond/Mul/x: Tensor[(1,), int32],
%cond/Mul/y: Tensor[(1,), int32],
%cond/Add/x: Tensor[(1,), int32],
%cond/Add/y: Tensor[(1,), int32]) {
%0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
%1 = min(%0)
if (%1) {
%2 = multiply(%cond/Mul/x, %cond/Mul/y)
%2
} else {
%3 = add(%cond/Add/x, %cond/Add/y)
%3
}
}
"""
def __init__(self):
self._if = None
self.cond = None
self.true_branch = None
self.false_branch = None
def _if_node(self):
"""An internal API to create a relay if node from the matched TF
condition construct.
"""
# `cond` returns a tensor that contains boolean values. We add a `min`
# operator to check whether any value is false. If so, the condition
# does not hold.
cond = tvm.relay.op.min(self.cond)
return tvm.relay.If(cond, self.true_branch, self.false_branch)
def if_node(self):
"""Create an tvm.relay.If node if it hasn't been created yet."""
if self._if is None:
self._if = self._if_node()
return self._if
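As a minimal usage sketch (outside the converter, with hand-built and purely illustrative Relay expressions), a `Branch` is filled in field by field and then turned into an `If` node:

import tvm
from tvm import relay

# Illustrative only: build a Branch by hand and create the relay.If node.
x = relay.var("x", shape=(1,), dtype="int32")
y = relay.var("y", shape=(1,), dtype="int32")

branch = Branch()
branch.cond = relay.less(x, y)
branch.true_branch = relay.multiply(x, y)
branch.false_branch = relay.add(x, y)

if_expr = branch.if_node()              # relay.If wrapping min(cond)
func = relay.Function([x, y], if_expr)  # wrap it up for type inference/execution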
class Loop:
"""
A class containing the components used to build up a Relay
recursive call.
Parameters
----------
loop_vars : List[tvm.relay.Expr]
The loop variables that are used in a while loop.
cond : tvm.relay.Expr
The condition of a while loop.
body : tvm.relay.Expr
The body of a matched while loop.
_loop : tvm.relay.Expr
An internal variable indicating whether a recursive call has already been
created for a matched TF while loop construct.
Examples
--------
The following is a vanilla loop from TensorFlow:
.. code-block:: python
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
It will be converted to the following recursive call in Relay:
.. code-block:: python
fn (%while/Less/y: Tensor[(1,), int32],
%while/Add/y: Tensor[(1,), int32],
%Const: Tensor[(1,), int32]) {
%0 = fn(%loop_var0: Tensor[(1,), int32]) {
%1 = less(%loop_var0, %while/Less/y)
%2 = min(%1)
if (%2) {
%3 = add(%loop_var0, %while/Add/y)
free_var %while_loop
%4 = %while_loop(%3)
%4
} else {
%5 = (%loop_var0,)
%5
}
}
let %while_loop1 = %0
%6 = %while_loop1(%Const)
%6
}
"""
def __init__(self):
self.loop_vars = []
self.cond = None
self.body = []
self._loop = None
def _while_loop(self):
"""An internal API to create a Relay recurisve call for a matched TF
`while_loop` construct.
"""
wl = tvm.relay.var('while_loop')
sb = tvm.relay.scope_builder.ScopeBuilder()
loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = tvm.relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
loop_vars.append(v)
bind_map[var] = v
self.cond = tvm.relay.bind(self.cond, bind_map)
self.body = [tvm.relay.bind(b, bind_map) for b in self.body]
cond = tvm.relay.op.min(self.cond)
with sb.if_scope(cond):
sb.ret(wl(*self.body))
with sb.else_scope():
sb.ret(tvm.relay.Tuple(loop_vars))
loop_fn = tvm.relay.Function(loop_vars, sb.get())
sb = tvm.relay.scope_builder.ScopeBuilder()
sb.let(wl, loop_fn)
sb.ret(wl(*self.loop_vars))
return sb.get()
def while_loop(self):
"""Instantiate a while loop if it has not been created yet."""
if self._loop is None:
self._loop = self._while_loop()
return self._loop
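A minimal usage sketch, analogous to the `Branch` example above (again with made-up variable names): the loop variables, condition, and body are collected first, and `while_loop()` then produces the let-bound recursive call.

import tvm
from tvm import relay

# Illustrative only: build the Loop by hand for "while i < limit: i = i + 1".
i = relay.var("i", shape=(1,), dtype="int32")
limit = relay.var("limit", shape=(1,), dtype="int32")

loop = Loop()
loop.loop_vars = [i]
loop.cond = relay.less(i, limit)
loop.body = [relay.add(i, relay.const(1, "int32"))]

# A recursive function bound by `let` and applied to the initial loop variables.
# Its result is a tuple of the final loop variables; `Exit` nodes later select
# the element they need with TupleGetItem.
recursive_call = loop.while_loop()
func = relay.Function([i, limit], recursive_call)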
class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
@@ -1284,6 +1499,8 @@ class GraphProto(object):
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}
self._loops = {}
self._branches = {}
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -1332,7 +1549,10 @@ class GraphProto(object):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
control_flow_node_map = defaultdict(set)
for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder':
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
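For illustration, the `control_flow_node_map` built in the loop above groups control flow primitives by their enclosing name scope (the execution frame). A small standalone sketch with hypothetical node names:

from collections import defaultdict

# Hypothetical (node name, op) pairs as they would appear in a tf.while_loop graph.
nodes = [("while/Enter", "Enter"), ("while/Merge", "Merge"),
         ("while/LoopCond", "LoopCond"), ("while/Exit", "Exit"),
         ("Const", "Const")]

control_flow_node_map = defaultdict(set)
for name, op in nodes:
    control_flow_node_map[name.rsplit('/', 1)[0]].add(op)

# control_flow_node_map["while"] == {"Enter", "Merge", "LoopCond", "Exit"}
# control_flow_node_map["Const"] == {"Const"}  (a top-level node maps to itself)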
@@ -1447,12 +1667,17 @@ class GraphProto(object):
# This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym)
input_0d_mismatch.add(in_sym[0])
attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch
op = self._convert_operator(node.op, inputs, attr, graph)
if node.op in _control_flow_nodes:
op = self._convert_control_flow_operator(node, inputs,
attr,
control_flow_node_map)
else:
op = self._convert_operator(node.op, inputs, attr, graph)
# Check if op is converted to param
if isinstance(op, np.ndarray):
@@ -1493,7 +1718,10 @@ class GraphProto(object):
out = []
if outputs is None:
out = op
if node.op == "Exit":
out = [op[0].tuple_value]
else:
out = op
else:
for out_name in outputs:
if ":" in out_name:
@@ -1529,7 +1757,9 @@ class GraphProto(object):
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
if any([node.op in t for t in [_identity_list, _convert_map,
_convert_map_rnn,
_control_flow_nodes]]):
pass
else:
missing_operators.add(node.op)
@@ -1656,6 +1886,89 @@ class GraphProto(object):
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym
def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
"""
Convert a TensorFlow control flow primitive into the corresponding component
of a Relay control flow construct, i.e. `tf.cond` and `tf.while_loop`
are converted into a Relay `If` and a recursive call, respectively.
Parameters
----------
node: TensorFlow graph node object.
A TensorFlow graph node object.
inputs : List[tvm.relay.Expr]
List of input symbols.
attrs : Dict[tvm.Attrs]
Dict of operator attributes.
control_flow_node_map : Dict[str, Set[str]]
A dictionary mapping each execution frame name to its set of
primitive operators.
Returns
-------
op : tvm.relay.Expr
Converted relay expression.
"""
node_name_prefix = node.name.rsplit('/', 1)[0]
if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]]
self._loops[node_name_prefix] = Loop()
else:
if len(self._branches) == 0:
raise RuntimeError("Cannot find a created "
"conditional for merge node")
branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]]
true_br = self._nodes[node.input[1]]
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')
# TensorFlow uses different naming conventions for Exit nodes across
# versions (e.g. `Exit_1` vs. `Exit1`).
if '_' in exit_name:
exit_number = int("0" + exit_name[5:])
else:
exit_number = int("0" + exit_name[4:])
expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number)
elif node.op == "Enter":
op = self._nodes[node.input[0]]
elif node.op == "LoopCond":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch":
op = self._nodes[node.input[0]]
assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix):
self._loops[node_name_prefix].loop_vars.append(op[0])
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
elif node.op == "NextIteration":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))
return op
def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator.
......
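To make the `Exit` index parsing in `_convert_control_flow_operator` concrete, here is a small standalone sketch with hypothetical node names; it mirrors the two suffix styles handled above without asserting which TensorFlow version produces which.

def exit_number(node_name):
    # Same parsing as in the Exit branch above, extracted for illustration.
    exit_name = node_name.split('/')[-1]
    assert exit_name.startswith('Exit')
    return int("0" + (exit_name[5:] if '_' in exit_name else exit_name[4:]))

assert exit_number("while/Exit") == 0    # the first exit has no suffix
assert exit_number("while/Exit_1") == 1  # underscore-style suffix
assert exit_number("while/Exit2") == 2   # plain numeric suffix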
@@ -270,16 +270,30 @@ class Interpreter :
return TupleValueNode::make(values);
}
Value VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
// TODO(@jroesch): this doesn't support mutual letrec.
Value MakeClosure(const Function& func, const Var& letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
for (const auto& var : free_vars) {
captured_mod.Set(var, Eval(var));
// Evaluate the free var (which could be a function call) unless it is the
// name this function is let-bound to, i.e. a recursive self-reference.
if (!letrec_name.defined() || letrec_name != var) {
captured_mod.Set(var, Eval(var));
}
}
return ClosureNode::make(captured_mod, func);
// We must use mutation here to build a self-referential closure.
auto closure = ClosureNode::make(captured_mod, func);
auto mut_closure =
static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
mut_closure->env.Set(letrec_name, closure);
return closure;
}
Value VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
return MakeClosure(func);
}
Value InvokePrimitiveOp(Function func,
@@ -438,10 +452,16 @@ class Interpreter :
}
}
Value VisitExpr_(const LetNode* op) final {
auto value = Eval(op->value);
this->extend(op->var, value);
return Eval(op->body);
Value VisitExpr_(const LetNode* let) final {
if (auto func = let->value.as<FunctionNode>()) {
auto clo = MakeClosure(GetRef<Function>(func), let->var);
this->extend(let->var, clo);
} else {
auto value = Eval(let->value);
this->extend(let->var, value);
}
return Eval(let->body);
}
Value VisitExpr_(const TupleGetItemNode* op) final {
......
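The interpreter change above lets the closure created for a let-bound function capture a reference to itself, which is what allows the recursive functions emitted by the while-loop conversion to evaluate. A minimal sketch in the Relay Python API (the names and the summation example are illustrative only) of such a self-referential let, run on the debug interpreter:

import tvm
from tvm import relay

# A hand-built recursive let, the same shape of expression the while_loop
# conversion emits: sum the integers 0..9.
loop = relay.var("loop")
i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

sb = relay.scope_builder.ScopeBuilder()
with sb.if_scope(relay.less(i, relay.const(10, "int32"))):
    # Recursive self-reference: evaluating this call relies on the closure
    # bound to `loop` carrying itself in its environment.
    sb.ret(loop(relay.add(i, relay.const(1, "int32")), relay.add(acc, i)))
with sb.else_scope():
    sb.ret(acc)
loop_fn = relay.Function([i, acc], sb.get())

body = relay.Let(loop, loop_fn,
                 loop(relay.const(0, "int32"), relay.const(0, "int32")))
ex = relay.create_executor('debug')
print(ex.evaluate(relay.Function([], body))())  # expected result: 45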
"""Unit tests for converting TensorFlow control flow op to Relay."""
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
def check_equal(graph, tf_out):
expr, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
ex = relay.create_executor('debug')
relay_out = ex.evaluate(expr)(**params)
if isinstance(relay_out, relay.backend.interpreter.TensorValue):
np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else:
if not isinstance(tf_out, list):
tf_out = [tf_out]
for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
np.testing.assert_allclose(x, y)
def test_vanilla_loop():
graph = tf.Graph()
with graph.as_default():
i = tf.constant(0)
def c(i): return tf.less(i, 10)
def b(i): return tf.add(i, 1)
r = tf.while_loop(c, b, [i])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_loop_2_vars():
graph = tf.Graph()
with graph.as_default():
i0 = tf.constant(0)
j0 = tf.ones([2, 2])
def c(i, j): return i < 10
def b(i, j): return [tf.add(i, 1), j]
i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0])
i1 += tf.constant(1337)
with tf.Session() as sess:
tf_out = sess.run(i1)
check_equal(graph, tf_out)
def test_loop_3_vars():
graph = tf.Graph()
with graph.as_default():
i0 = tf.constant(1)
j0 = tf.constant(2)
k0 = tf.constant(4)
def c(i, j, k): return i < 10
def b(i, j, k): return [i+1, j * k, k + i]
r = tf.while_loop(c, b, loop_vars=[i0, j0, k0])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_loop_conditions():
graph = tf.Graph()
with graph.as_default():
i = tf.constant(1)
j = tf.constant(1)
k = tf.constant(5)
def c(i, j, k): return \
tf.equal(tf.not_equal(tf.less(i + j, 10),
tf.less(j * k, 100)),
tf.greater_equal(k, i + j))
def b(i, j, k): return [i+j, j+k, k+1]
r = tf.while_loop(c, b, loop_vars=[i, j, k])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_loop_bodies():
graph = tf.Graph()
with graph.as_default():
def body(x):
a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32)
b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
c = a + b
return tf.nn.relu(x + c)
def condition(x):
return tf.reduce_sum(x) < 100
x = tf.constant(0, shape=[2, 2])
r = tf.while_loop(condition, body, [x])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_nested_loop():
graph = tf.Graph()
with graph.as_default():
def body(x):
def nest_body(c):
return tf.multiply(c, 2)
def cd(c): return tf.less(c, 10)
c = tf.constant(2)
res = tf.while_loop(cd, nest_body, loop_vars=[c])
return tf.nn.relu(x + res)
def condition(x):
return tf.greater(x, 100)
x = tf.constant(3)
r = tf.while_loop(condition, body, loop_vars=[x])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_vanilla_cond():
graph = tf.Graph()
with graph.as_default():
i = tf.constant(1)
j = tf.constant(4)
def f1():
return tf.multiply(1, 17)
def f2():
return tf.add(4, 23)
r = tf.cond(tf.less(i, j), f1, f2)
with tf.Session(graph=graph) as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_multiple_cond_vars():
graph = tf.Graph()
with graph.as_default():
x1 = tf.constant(7)
x2 = tf.constant(12)
z = tf.constant(20)
r = tf.cond(tf.less(tf.add(x1, x2), 10),
lambda: tf.add(10, 2), lambda: tf.square(5))
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
def test_cond_fn_parameters():
graph = tf.Graph()
with graph.as_default():
def fn1(x, y):
return tf.multiply(5, 6)
def fn2(x, y):
return tf.add(3, 4)
i = tf.constant(1)
j = tf.constant(2)
k = tf.constant(3)
r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k))
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3})
check_equal(graph, tf_out)
def test_nested_cond():
graph = tf.Graph()
with graph.as_default():
def fn1(a, b):
def nest_fn1():
return tf.add(1, 2)
def nest_fn2():
return tf.subtract(10, 5)
res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2)
return tf.multiply(tf.add(87, res), 10)
def fn2(a, b):
return tf.add(10, 10)
x = tf.constant(5)
y = tf.constant(6)
z = tf.constant(7)
pred = tf.less(x, y)
r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
check_equal(graph, tf_out)
def test_loop_in_cond():
graph = tf.Graph()
with graph.as_default():
def fn1(a, b):
i = tf.constant(0)
def cd(i): return tf.less(i, 10)
def bd(i): return tf.add(i, 1)
res = tf.while_loop(cd, bd, [i])
return tf.multiply(tf.add(20, res), 10)
def fn2(a, b):
return tf.add(10, 20)
x = tf.constant(7)
y = tf.constant(20)
z = tf.constant(10)
pred = tf.less(x, y)
r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
check_equal(graph, tf_out)
def test_cond_in_loop():
graph = tf.Graph()
with graph.as_default():
def body(x):
x = tf.constant(7)
z = tf.constant(20)
res = tf.cond(tf.less(x, 10), lambda: tf.add(
10, 20), lambda: tf.square(10))
return tf.multiply(res, x)
x = tf.constant(21)
def condition(x):
return tf.less(x, 100)
r = tf.while_loop(condition, body, loop_vars=[x])
with tf.Session() as sess:
tf_out = sess.run(r)
check_equal(graph, tf_out)
if __name__ == "__main__":
# tf.while_loop
test_vanilla_loop()
test_loop_2_vars()
test_loop_3_vars()
test_loop_conditions()
test_loop_bodies()
# tf.cond
test_vanilla_cond()
test_multiple_cond_vars()
test_cond_fn_parameters()
# nested cases
test_nested_loop()
test_nested_cond()
test_loop_in_cond()
test_cond_in_loop()