Commit 00d509d4 by Alexey Romanov Committed by Tianqi Chen

[FRONTEND][TENSORFLOW] Support Unstack and Split (#2105)

parent 4bbf96e4
...@@ -36,6 +36,7 @@ class AttrCvt(object): ...@@ -36,6 +36,7 @@ class AttrCvt(object):
self._ignores.append('_node_name') self._ignores.append('_node_name')
self._ignores.append('is_training') self._ignores.append('is_training')
self._ignores.append('_target_layout') self._ignores.append('_target_layout')
self._ignores.append('_input_0d_mismatch')
# Retain the names # Retain the names
try: try:
attrs['name'] = attrs['_node_name'] attrs['name'] = attrs['_node_name']
...@@ -319,8 +320,7 @@ def _expand_dims(): ...@@ -319,8 +320,7 @@ def _expand_dims():
dim_input = inputs.pop(1) dim_input = inputs.pop(1)
axis = params[dim_input.list_output_names()[0]] axis = params[dim_input.list_output_names()[0]]
params.pop(dim_input.list_output_names()[0]) params.pop(dim_input.list_output_names()[0])
return AttrCvt(op_name="expand_dims", ignores=['Tdim'], return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
extras={'axis': axis.asnumpy()[0]})(inputs, attr)
return _impl return _impl
def _resize_bilinear(): def _resize_bilinear():
...@@ -383,7 +383,7 @@ def _concat(): ...@@ -383,7 +383,7 @@ def _concat():
def _pack(): def _pack():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = int(attr["axis"]) axis = int(attr["axis"])
inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"]) return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])
return _impl return _impl
...@@ -787,15 +787,64 @@ def _broadcast(name): ...@@ -787,15 +787,64 @@ def _broadcast(name):
)(inputs, attr) )(inputs, attr)
return _impl return _impl
def _split(): def _split(has_size_vector):
# TF documentation https://www.tensorflow.org/api_docs/python/tf/split
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[0].list_output_names()[0]) try:
return AttrCvt( # order and number of inputs are different:
op_name="split", ignores=['T'], # if has_size_vector:
transforms={'num_split': 'indices_or_sections'}, # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
extras={'axis': axis.asnumpy()[0]})(inputs[1], attr) # else:
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
# in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
# we can only support constants
if has_size_vector:
input_node_index = 0
input_axis_index = 2
size_splits_input_name = inputs[1].list_output_names()[0]
size_splits = params[size_splits_input_name].asnumpy()
section_beginnings = np.cumsum(size_splits)[:-1]
indices_or_sections = tuple(section_beginnings)
else:
input_node_index = 1
input_axis_index = 0
indices_or_sections = attr['num_split']
input_node = inputs[input_node_index]
axis_input_name = inputs[input_axis_index].list_output_names()[0]
axis_input_value = params[axis_input_name].asnumpy()[0]
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for split: `axis` and `num_or_size_splits` " \
"should be constants")
return _sym.split(input_node,
indices_or_sections=indices_or_sections,
axis=axis_input_value)
return _impl return _impl
def _unpack():
def _impl(inputs, attr, params):
input_node = inputs[0]
axis = attr['axis']
input_shape = attr['_input_shapes'][input_node][0]
axis_length = input_shape[axis]
if axis_length < 0:
raise TypeError("Unstack with unknown axis length")
splitted = _sym.split(input_node,
indices_or_sections=axis_length,
axis=axis,
name=attr.get('_node_name', 'unstack'))
return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
return _impl
def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
if data in attr['_input_0d_mismatch']:
return data if num_newaxis == 1 else \
_sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)
return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -863,7 +912,9 @@ _convert_map = { ...@@ -863,7 +912,9 @@ _convert_map = {
'GreaterEqual' : _broadcast('greater_equal'), 'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'), 'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'), 'NotEqual' : _broadcast('not_equal'),
'Split' : _split(), 'Split' : _split(False),
'SplitV' : _split(True),
'Unpack' : _unpack(),
} }
# _convert_map_rnn defines maps of rnn operator name to # _convert_map_rnn defines maps of rnn operator name to
...@@ -1059,6 +1110,7 @@ class GraphProto(object): ...@@ -1059,6 +1110,7 @@ class GraphProto(object):
self._output_shapes = {} self._output_shapes = {}
self._num_param = 0 self._num_param = 0
self._num_rnn_layer = False self._num_rnn_layer = False
self._outputs_are_0d = {}
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef. """Construct nnvm nodes from tensorflow graph definition - GraphDef.
...@@ -1114,6 +1166,7 @@ class GraphProto(object): ...@@ -1114,6 +1166,7 @@ class GraphProto(object):
# Operator name 'Const' is treated as a parameter to build NNVM params dict. # Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {} input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr) attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr #Variable converted to Const will not have only value attr
...@@ -1133,6 +1186,9 @@ class GraphProto(object): ...@@ -1133,6 +1186,9 @@ class GraphProto(object):
else: else:
raise NotImplementedError( \ raise NotImplementedError( \
"Please freeze the graph with add_shapes=True") "Please freeze the graph with add_shapes=True")
self._outputs_are_0d[node.name] = [ \
not shape if isinstance(shape, list) else False \
for shape in self._output_shapes[node.name]]
if node.op == "Placeholder": if node.op == "Placeholder":
self._nodes[node.name] = _sym.Variable(name=node.name, self._nodes[node.name] = _sym.Variable(name=node.name,
...@@ -1162,11 +1218,13 @@ class GraphProto(object): ...@@ -1162,11 +1218,13 @@ class GraphProto(object):
# Fill shapes for all inputs in a list # Fill shapes for all inputs in a list
inputs = [] inputs = []
for i in node.input: for i in node.input:
#ToDo: Some of the tensorflow operators internaly maintain # Some TensorFlow operators internally maintain execution layers
#execution layers and its output name will the layer number along with # and their output name includes the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored. # the number has to be ignored for single-output nodes.
# On the other hand, for multi-output nodes the number is the output index,
# and the lack of the number implies 0.
tensor_name = i.split(':') tensor_name = i.split(':')
node_name = tensor_name[0] node_name = tensor_name[0]
if node_name in self._nodes: if node_name in self._nodes:
...@@ -1174,12 +1232,18 @@ class GraphProto(object): ...@@ -1174,12 +1232,18 @@ class GraphProto(object):
if len(in_sym.list_output_names()) > 1: if len(in_sym.list_output_names()) > 1:
tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0 tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
in_sym = in_sym[tensor_slot] in_sym = in_sym[tensor_slot]
input_shape = (self._output_shapes[node_name])[tensor_slot] input_shape = self._output_shapes[node_name][tensor_slot]
else: else:
tensor_slot = 0
input_shape = self._output_shapes[node_name][0] input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym) inputs.append(in_sym)
input_shapes[in_sym] = [input_shape] input_shapes[in_sym] = [input_shape]
# This means the node is 1d in NNVM and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym)
attr['_input_shapes'] = input_shapes attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch
inputs = self._fix_extranodes(node.op, attr, inputs) inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph) op = self._convert_operator(node.op, inputs, attr, graph)
...@@ -1207,7 +1271,13 @@ class GraphProto(object): ...@@ -1207,7 +1271,13 @@ class GraphProto(object):
if outputs is None: if outputs is None:
out.append(final_op) out.append(final_op)
else: else:
out = [self._nodes[out_name] for out_name in outputs] for out_name in outputs:
if ":" in out_name:
out_name, out_num = out_name.split(":")
out_num = int(out_num)
out.append(self._nodes[out_name][out_num])
else:
out.append(self._nodes[out_name])
#Add the RNN outputs also with 'head' nodes of the nnvm graph #Add the RNN outputs also with 'head' nodes of the nnvm graph
if self._num_rnn_layer: if self._num_rnn_layer:
...@@ -1215,7 +1285,7 @@ class GraphProto(object): ...@@ -1215,7 +1285,7 @@ class GraphProto(object):
out.append(out_rnn) out.append(out_rnn)
if isinstance(out, list): if isinstance(out, list):
out = _sym.Group(out) out = _sym.Group(out) if len(out) > 1 else out[0]
return out, self._params return out, self._params
......
...@@ -124,7 +124,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, ...@@ -124,7 +124,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if no_gpu and device == 'cuda': if no_gpu and device == 'cuda':
continue continue
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device) tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
num_output=len(out_node), target=device, out_names=out_name)
# since the names from tensorflow and nnvm runs are not exactly same, # since the names from tensorflow and nnvm runs are not exactly same,
# first len(tf_output) will be compared # first len(tf_output) will be compared
for i in range(len(tf_output)): for i in range(len(tf_output)):
...@@ -506,14 +507,24 @@ def test_forward_gather(): ...@@ -506,14 +507,24 @@ def test_forward_gather():
# Split # Split
# ----- # -----
def _test_split(in_shape, axis, num_split, dtype): def _test_split(in_shape, axis, num_or_size_splits, dtype):
np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
""" One iteration of a Split """ """ One iteration of a Split """
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
tf.split(in_data, num_or_size_splits, axis=axis)
with tf.Graph().as_default(): compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
in_data = tf.placeholder(dtype, in_shape, name="in_data")
tf.split(in_data, num_split, axis) # and now test together with concat
np_data = np.random.uniform(size=in_shape).astype(dtype) tf.reset_default_graph()
compare_tf_with_tvm(np_data, 'in_data:0', 'split:0') in_data = tf.placeholder(dtype, in_shape, name="in_data")
splitted = tf.split(in_data, num_or_size_splits, axis=axis)
tf.concat(splitted, axis)
compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
def test_forward_split(): def test_forward_split():
'''test split layer''' '''test split layer'''
...@@ -523,11 +534,11 @@ def test_forward_split(): ...@@ -523,11 +534,11 @@ def test_forward_split():
_test_split((6,), 0, 3, 'float32') _test_split((6,), 0, 3, 'float32')
# rank 2 # rank 2
_test_split((6, 2), 0, 3, 'float32') _test_split((6, 2), 0, 3, 'float32')
_test_split((2, 6), 1, 3, 'float32') _test_split((2, 6), 1, 6, 'float32')
# rank 3 # rank 3
_test_split((6, 2, 4), 0, 3, 'float32') _test_split((6, 2, 4), 0, 2, 'int32')
_test_split((2, 6, 4), 1, 3, 'float32') _test_split((2, 6, 4), 1, 3, 'float32')
_test_split((2, 4, 6), 2, 3, 'float32') _test_split((2, 4, 6), 2, 1, 'float32')
# rank 4 # rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32') _test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32') _test_split((1, 6, 3, 5), 1, 3, 'float32')
...@@ -538,45 +549,37 @@ def test_forward_split(): ...@@ -538,45 +549,37 @@ def test_forward_split():
_test_split((1, 6, 3, 5), -3, 3, 'float32') _test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32') _test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32') _test_split((1, 3, 5, 6), -1, 3, 'float32')
# size_splits list
_test_split((6,), 0, [1, 2, 3], 'int32')
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
####################################################################### #######################################################################
# Split followed by concat # Unstack
# ------------------------ # -------
def _test_split_concat(in_shape, axis, num_split, dtype): def _test_unstack(ip_shape, axis, dtype):
""" One iteration of a split_concat pair""" np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
with tf.Graph().as_default(): tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data") in_data = tf.placeholder(dtype, ip_shape, name="in_data")
splitted = tf.split(in_data, num_split, axis) tf.unstack(in_data, axis=axis)
tf.concat(splitted, axis)
np_data = np.random.uniform(size=in_shape).astype(dtype) compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])
compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0')
tf.reset_default_graph()
def test_forward_split_concat(): in_data = tf.placeholder(dtype, ip_shape, name="in_data")
'''test split followed by concat layers''' tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
# rank 1
_test_split_concat((3,), 0, 1, 'float32') compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
_test_split_concat((3,), 0, 3, 'float32')
_test_split_concat((6,), 0, 3, 'float32') def test_forward_unstack():
# rank 2 '''test unstack layer'''
_test_split_concat((6, 2), 0, 3, 'float32') _test_unstack((6,), 0, 'int32')
_test_split_concat((2, 6), 1, 3, 'float32') _test_unstack((2,6), 1, 'float64')
# rank 3 # negative axis
_test_split_concat((6, 2, 4), 0, 3, 'float32') _test_unstack((1,4), -1, 'int32')
_test_split_concat((2, 6, 4), 1, 3, 'float32') _test_unstack((3,6,4), -2, 'float32')
_test_split_concat((2, 4, 6), 2, 3, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
_test_split((1, 3, 6, 5), 2, 3, 'float32')
_test_split((1, 3, 5, 6), 3, 3, 'float32')
# split along negative axis
_test_split((6, 1, 3, 5), -4, 3, 'float32')
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')
####################################################################### #######################################################################
...@@ -1139,7 +1142,7 @@ if __name__ == '__main__': ...@@ -1139,7 +1142,7 @@ if __name__ == '__main__':
test_forward_gather() test_forward_gather()
test_forward_stridedslice() test_forward_stridedslice()
test_forward_split() test_forward_split()
test_forward_split_concat() test_forward_unstack()
# Activations # Activations
test_forward_sigmoid() test_forward_sigmoid()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment