Commit 10b7757a by Albin Joy Committed by Tianqi Chen

[NNVM][TENSORFLOW] Fixed variable ops shape parsing issue (#1381)

parent 2fa0eca1
......@@ -593,11 +593,18 @@ class GraphProto(object):
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
try:
attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr
if 'value' in attr:
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
except KeyError:
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
else:
......
......@@ -14,6 +14,8 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.core.framework import graph_pb2
import nnvm.testing.tf
......@@ -393,6 +395,44 @@ def test_forward_sigmoid():
_test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
#######################################################################
# Variable
# --------
def _test_variable(data):
tf.reset_default_graph()
input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
input_tensor = array_ops.reshape(input_op, data.shape)
size = input_tensor.shape.dims[1]
with variable_scope.variable_scope("linear", reuse=None):
w = variable_scope.get_variable(
"w", shape=[size, size], dtype=input_tensor.dtype)
# pylint: disable=unused-variable
output_op = math_ops.matmul(input_tensor, w)
# pylint: enable=unused-variable
with tf.Session() as sess:
sess.run(variables.global_variables_initializer())
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['MatMul'],
)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', 'MatMul:0')
tvm_output = run_tvm_graph(final_graph_def, data,
"Placeholder", tf_output.shape, data.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
def test_forward_variable():
"""Variable type op test"""
_test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
#######################################################################
# Multi Input to graph
# --------------------
......@@ -503,3 +543,4 @@ if __name__ == '__main__':
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_variable()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment