Commit c4439a80 by Siva Committed by Tianqi Chen

[TENSORLFOW] PlaceholderWithDefault (limited) implementation. (#3184)

parent 76ae2dc6
......@@ -1740,7 +1740,7 @@ class GraphProto(object):
for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder':
if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
# Give priority to user argument.
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
......@@ -1800,7 +1800,7 @@ class GraphProto(object):
attr = self._parse_attr(node.attr)
elif node.op != "Placeholder":
elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault':
# Pass the parsed shapes instead
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
......@@ -1925,7 +1925,7 @@ class GraphProto(object):
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
pass
elif node.op == "Const":
pass
......
......@@ -1541,6 +1541,24 @@ def test_forward_reduce_prod():
_test_forward_reduce_prod((5, 5), 0, True)
_test_forward_reduce_prod((5, 5), 1, True)
#######################################################################
# PlaceholderWithDefault
# ----------------------
def test_placeholder():
with tf.Graph().as_default():
in_data1 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
var1 = tf.Variable(in_data1, name='in1')
var2 = array_ops.placeholder_with_default(var1, None, name='place1')
in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2')
out1 = tf.math.add(var1, var2, name='out1')
out2 = tf.math.add(out1, place1, name='out2')
compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True)
#######################################################################
# Main
# ----
......@@ -1590,6 +1608,7 @@ if __name__ == '__main__':
test_forward_multi_input()
test_forward_multi_output()
test_forward_variable()
test_placeholder()
# NN
test_forward_convolution()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment