Unverified Commit 8502691b by maheshambule Committed by GitHub

[Frontend] [Tensorflow] ReadVariableOp operator support (#4952)

* tf frontend read variable op

* pylint fix

* tf frontend freezed graph pruned ops
parent 0fb48360
......@@ -1500,6 +1500,12 @@ def _add_n():
# compatible operators that do NOT require any conversion.
_identity_list = []
# Operators that get pruned away when the complete graph is frozen.
# These operators are not needed for inference.
_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp']
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
......@@ -2187,6 +2193,11 @@ class GraphProto(object):
missing_operators = self._parse_import_prerequisites(graph)
if missing_operators:
freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
if freezed_ops:
raise Exception("Graph is not frozen. Provide a frozen graph. "
"Found operators {}".format(freezed_ops))
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
......
......@@ -22,6 +22,7 @@ This article is a test script to test tensorflow operator with Relay.
"""
from __future__ import print_function
import numpy as np
import pytest
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
......@@ -1061,6 +1062,62 @@ def test_forward_variable():
_test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
def test_read_variable_op():
""" Read Variable op test """
tf.reset_default_graph()
data = np.random.uniform(size=(32, 100)).astype('float32')
input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
size = input_tensor.shape.dims[1]
var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
input_var = tf.Variable(var_data, name='var1', use_resource=True)
math_ops.matmul(input_tensor, input_var)
out_name = ['MatMul:0']
out_node = ['MatMul']
in_name = ['Placeholder:0']
in_node = ['Placeholder']
in_data = [data]
with tf.Session() as sess:
sess.run(variables.global_variables_initializer())
final_graph_def = sess.graph.as_graph_def(add_shapes=True)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
with pytest.raises(Exception) as exexcinfo:
mod, params = relay.frontend.from_tensorflow(final_graph_def,
layout=None,
shape=shape_dict,
outputs=None)
assert exexcinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.")
# Now convert the variables to constant and run inference on the converted graph
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
out_node,
)
for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
target=device, out_names=out_name,
num_output=len(out_name))
for i in range(len(tf_output)):
tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
sess.close()
#######################################################################
# MatMul, BatchMatMul, BatchMatMulV2
# ----------------------------------
......@@ -3038,3 +3095,6 @@ if __name__ == '__main__':
test_forward_where()
test_forward_matmul()
test_forward_batch_matmul()
# Internal misc. ops
test_read_variable_op()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment