Commit 51e2e31f by Yong Wu Committed by MORITA Kazutaka

[Frontend][TF] Fix Placeholder issue (#2834)

* [Frontend][TF] Fix Placeholder issue

* Add test cases
parent 7e34988e
...@@ -126,7 +126,7 @@ def _argx(func, func_name): ...@@ -126,7 +126,7 @@ def _argx(func, func_name):
def _elemwise(name): def _elemwise(name):
def _impl(inputs, attr, *args): def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
op_name = _math_name_picker(name)(attr) op_name = _math_name_picker(name)(attr)
return get_nnvm_op(op_name)(*inputs) return get_nnvm_op(op_name)(*inputs)
return _impl return _impl
...@@ -1217,16 +1217,24 @@ class GraphProto(object): ...@@ -1217,16 +1217,24 @@ class GraphProto(object):
for node in graph.node: for node in graph.node:
if node.op == 'Placeholder': if node.op == 'Placeholder':
# Give priority to user argument.
if shape and node.name in shape: if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name]) self._input_shapes[node.name] = list(shape[node.name])
continue else:
self._input_shapes[node.name] = \ self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]): for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0: if dim < 0:
self._input_shapes[node.name][idx] = 1 self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s." warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name) % node.name)
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._input_shapes[node.name])
self._output_shapes[node.name] = [self._input_shapes[node.name]]
self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]
# Ignore user's input shape for Non placeholder # Ignore user's input shape for Non placeholder
elif node.op == 'Const': elif node.op == 'Const':
...@@ -1250,11 +1258,6 @@ class GraphProto(object): ...@@ -1250,11 +1258,6 @@ class GraphProto(object):
# Variable converted to Const will not have only value attr # Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const': if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]] self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr: elif '_output_shapes' in attr:
self._output_shapes[node.name] = \ self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \ [tensor_util.TensorShapeProtoToList(tshape) \
...@@ -1269,11 +1272,7 @@ class GraphProto(object): ...@@ -1269,11 +1272,7 @@ class GraphProto(object):
not tshape if isinstance(tshape, list) else False \ not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]] for tshape in self._output_shapes[node.name]]
if node.op == "Placeholder": if node.op == "Const":
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._input_shapes[node.name])
elif node.op == "Const":
# All Const nodes are Param nodes, lets parse # All Const nodes are Param nodes, lets parse
self._num_param += 1 self._num_param += 1
for key, value in node.attr.items(): for key, value in node.attr.items():
...@@ -1284,7 +1283,7 @@ class GraphProto(object): ...@@ -1284,7 +1283,7 @@ class GraphProto(object):
attr = self._parse_attr(node.attr) attr = self._parse_attr(node.attr)
else: elif node.op != "Placeholder":
# Pass the parsed shapes instead # Pass the parsed shapes instead
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
......
...@@ -941,6 +941,29 @@ def test_forward_resnetv2(): ...@@ -941,6 +941,29 @@ def test_forward_resnetv2():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Placeholder
# -----------
def test_forward_placeholder():
'''test a simple pb with Placeholder node in the end of GraphDef'''
with tf.Graph().as_default():
graph_def = tf_testing.get_workload("Custom/placeholder.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'mul'
with tf.Session() as sess:
# Add shapes to the graph.
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB # PTB
# --- # ---
dir(tf.contrib) dir(tf.contrib)
...@@ -1261,6 +1284,7 @@ if __name__ == '__main__': ...@@ -1261,6 +1284,7 @@ if __name__ == '__main__':
test_forward_inception_v1() test_forward_inception_v1()
test_forward_mobilenet() test_forward_mobilenet()
test_forward_resnetv2() test_forward_resnetv2()
test_forward_placeholder()
test_forward_ptb() test_forward_ptb()
# RNN # RNN
......
...@@ -239,7 +239,7 @@ def _argx(func, func_name): ...@@ -239,7 +239,7 @@ def _argx(func, func_name):
def _elemwise(name): def _elemwise(name):
def _impl(inputs, attr, *args): def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
return _get_relay_op(name)(*inputs) return _get_relay_op(name)(*inputs)
return _impl return _impl
...@@ -1704,16 +1704,23 @@ class GraphProto(object): ...@@ -1704,16 +1704,23 @@ class GraphProto(object):
node_name_prefix = node.name.rsplit('/', 1)[0] node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op) control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder': if node.op == 'Placeholder':
# Give priority to user argument.
if shape and node.name in shape: if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name]) self._input_shapes[node.name] = list(shape[node.name])
continue else:
self._input_shapes[node.name] = \ self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]): for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0: if dim < 0:
self._input_shapes[node.name][idx] = 1 self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s." warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name) % node.name)
self._output_shapes[node.name] = [self._input_shapes[node.name]]
attr = self._parse_attr(node.attr)
self._nodes[node.name] = [_expr.var(node.name,
shape=self._input_shapes[node.name],
dtype=attr['dtype'].name)]
# Ignore user's input shape for Non placeholder # Ignore user's input shape for Non placeholder
elif node.op == 'Const': elif node.op == 'Const':
...@@ -1736,11 +1743,6 @@ class GraphProto(object): ...@@ -1736,11 +1743,6 @@ class GraphProto(object):
# Variable converted to Const will not have only value attr # Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const': if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]] self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr: elif '_output_shapes' in attr:
self._output_shapes[node.name] = \ self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \ [tensor_util.TensorShapeProtoToList(tshape) \
...@@ -1755,13 +1757,7 @@ class GraphProto(object): ...@@ -1755,13 +1757,7 @@ class GraphProto(object):
not shape if isinstance(tshape, list) else False \ not shape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]] for tshape in self._output_shapes[node.name]]
if node.op == "Placeholder": if node.op == "Const":
self._output_shapes[node.name] = [self._input_shapes[node.name]]
self._nodes[node.name] = [_expr.var(node.name,
shape=self._input_shapes[node.name],
dtype=attr['dtype'].name)]
elif node.op == "Const":
# All Const nodes are Param nodes, lets parse # All Const nodes are Param nodes, lets parse
self._num_param += 1 self._num_param += 1
for key, value in node.attr.items(): for key, value in node.attr.items():
...@@ -1772,7 +1768,7 @@ class GraphProto(object): ...@@ -1772,7 +1768,7 @@ class GraphProto(object):
attr = self._parse_attr(node.attr) attr = self._parse_attr(node.attr)
else: elif node.op != "Placeholder":
# Pass the parsed shapes instead # Pass the parsed shapes instead
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
...@@ -1816,7 +1812,8 @@ class GraphProto(object): ...@@ -1816,7 +1812,8 @@ class GraphProto(object):
input_shapes[in_sym[0]] = input_shape input_shapes[in_sym[0]] = input_shape
# This means the node is 1d in Relay and 0d in TF. # This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`. # See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape: if node_name in self._outputs_are_0d \
and self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym[0]) input_0d_mismatch.add(in_sym[0])
attr['_input_shapes'] = input_shapes attr['_input_shapes'] = input_shapes
......
...@@ -1134,6 +1134,27 @@ def test_forward_resnetv2(): ...@@ -1134,6 +1134,27 @@ def test_forward_resnetv2():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
# Placeholder
# -----------
def test_forward_placeholder():
'''test a simple pb with Placeholder node in the end of GraphDef'''
with tf.Graph().as_default():
graph_def = tf_testing.get_workload("Custom/placeholder.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'mul'
with tf.Session() as sess:
# Add shapes to the graph.
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB # PTB
# --- # ---
dir(tf.contrib) dir(tf.contrib)
...@@ -1514,6 +1535,7 @@ if __name__ == '__main__': ...@@ -1514,6 +1535,7 @@ if __name__ == '__main__':
test_forward_inception_v1() test_forward_inception_v1()
test_forward_mobilenet() test_forward_mobilenet()
test_forward_resnetv2() test_forward_resnetv2()
test_forward_placeholder()
test_forward_ptb() test_forward_ptb()
# RNN # RNN
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment