Commit f6c3f997 by Alexey Romanov Committed by Siva

[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242)

parent 6d1f4c0b
...@@ -120,7 +120,7 @@ def _pooling(name): ...@@ -120,7 +120,7 @@ def _pooling(name):
attr['data_format'] = attr['data_format'].decode("utf-8") attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False flip_layout = False
input_shape = attr['_input_shapes'][inputs[0]][0] input_shape = attr['_input_shapes'][inputs[0]]
if attr['data_format'] == 'NHWC': if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
...@@ -132,7 +132,7 @@ def _pooling(name): ...@@ -132,7 +132,7 @@ def _pooling(name):
raise TypeError("Unsupported data_format type : {}".format(attr['data_format'])) raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]][0] tmp_shape = attr['_input_shapes'][inputs[0]]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2)) inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW" attr['data_format'] = "NCHW"
...@@ -185,13 +185,13 @@ def _conv(opname): ...@@ -185,13 +185,13 @@ def _conv(opname):
# NCHW Layout require weights transpose # NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW': if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]][0] tmp_shape = attr['_input_shapes'][inputs[1]]
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1)) inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
attr['_input_shapes'][inputs[1]] = [tmp_shape] attr['_input_shapes'][inputs[1]] = tmp_shape
input_shape = attr['_input_shapes'][inputs[0]][0] input_shape = attr['_input_shapes'][inputs[0]]
weights_shape = attr['_input_shapes'][inputs[1]][0] weights_shape = attr['_input_shapes'][inputs[1]]
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
...@@ -484,7 +484,7 @@ def _relu6(): ...@@ -484,7 +484,7 @@ def _relu6():
def _shape(): def _shape():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32') return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
return _impl return _impl
def _fill(): def _fill():
...@@ -565,7 +565,7 @@ def _stridedSlice(): ...@@ -565,7 +565,7 @@ def _stridedSlice():
new_axis_mask = int(attr.get('new_axis_mask', 0)) new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]] data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape[0]) data_dim = len(data_shape)
stride_dim = len(stride) stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask): def _transform_mask(stride_dim, ellipsis_mask):
...@@ -596,7 +596,7 @@ def _stridedSlice(): ...@@ -596,7 +596,7 @@ def _stridedSlice():
+ new_axes_after_ellipsis), data_dim) + new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index): for i in range(final_index, to_index):
m_begin[final_index] = 0 m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index] m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1 m_stride[final_index] = 1
fshape_indices.append(final_index) fshape_indices.append(final_index)
final_index += 1 final_index += 1
...@@ -606,19 +606,19 @@ def _stridedSlice(): ...@@ -606,19 +606,19 @@ def _stridedSlice():
if final_index == len(m_begin): if final_index == len(m_begin):
break break
if mask & begin_mask: if mask & begin_mask:
m_begin[final_index] = data_shape[0][final_index] \ m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0 if stride[index] < 0 else 0
elif begin[index]: elif begin[index]:
m_begin[final_index] = begin[index] m_begin[final_index] = begin[index]
if mask & end_mask: if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \ m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[0][final_index] else data_shape[final_index]
elif end[index]: elif end[index]:
m_end[final_index] = end[index] m_end[final_index] = end[index]
m_stride[final_index] = stride[index] m_stride[final_index] = stride[index]
if mask & shrink_axis_mask: if mask & shrink_axis_mask:
#Tensorflow make axis with shrink_axis_mask as dimension 1 #Tensorflow make axis with shrink_axis_mask as dimension 1
m_begin[final_index] = data_shape[0][final_index] + begin[index] \ m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index] if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1 m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1 m_stride[final_index] = 1
...@@ -684,8 +684,8 @@ def _LSTMBlockCell(): ...@@ -684,8 +684,8 @@ def _LSTMBlockCell():
forget_bias = attr.pop('forget_bias') forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]] input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]] weight_shape = attr['_input_shapes'][inputs[3]]
batch_size, input_size = input_shape[0][0], input_shape[0][1] batch_size, input_size = input_shape[0], input_shape[1]
num_hidden_layers = weight_shape[0][1] num_hidden_layers = weight_shape[1]
num_hidden = num_hidden_layers // 4 num_hidden = num_hidden_layers // 4
in_data = _sym.reshape(in_data, in_data = _sym.reshape(in_data,
...@@ -741,11 +741,10 @@ def _transpose(): ...@@ -741,11 +741,10 @@ def _transpose():
def _rank(): def _rank():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
input_shapes = attr['_input_shapes'][inputs[0]] input_shape = attr['_input_shapes'][inputs[0]]
assert len(inputs) == 1
name = attr["_node_name"] name = attr["_node_name"]
params[name] = tvm.nd.array([len(input_shapes[0])]) params[name] = tvm.nd.array([len(input_shape)])
return _sym.Variable(name=name, shape=params[name].shape) return _sym.Variable(name=name, shape=params[name].shape)
return _impl return _impl
...@@ -829,7 +828,7 @@ def _unpack(): ...@@ -829,7 +828,7 @@ def _unpack():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
input_node = inputs[0] input_node = inputs[0]
axis = attr['axis'] axis = attr['axis']
input_shape = attr['_input_shapes'][input_node][0] input_shape = attr['_input_shapes'][input_node]
axis_length = input_shape[axis] axis_length = input_shape[axis]
if axis_length < 0: if axis_length < 0:
raise TypeError("Unstack with unknown axis length") raise TypeError("Unstack with unknown axis length")
...@@ -1018,8 +1017,8 @@ class RecurrentNetworks(object): ...@@ -1018,8 +1017,8 @@ class RecurrentNetworks(object):
"""LSTM cell warapper to prepare the inputs""" """LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]] input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]] weight_shape = attr['_input_shapes'][inputs[3]]
batch_size = input_shape[0][0] batch_size = input_shape[0]
num_hidden = weight_shape[0][1] // 4 num_hidden = weight_shape[1] // 4
if layer == 0: if layer == 0:
#Create initial states placeholder in case of first layer #Create initial states placeholder in case of first layer
...@@ -1240,7 +1239,7 @@ class GraphProto(object): ...@@ -1240,7 +1239,7 @@ class GraphProto(object):
tensor_slot = 0 tensor_slot = 0
input_shape = self._output_shapes[node_name][0] input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym) inputs.append(in_sym)
input_shapes[in_sym] = [input_shape] input_shapes[in_sym] = input_shape
# This means the node is 1d in NNVM and 0d in TF. # This means the node is 1d in NNVM and 0d in TF.
# See `_expand_dims_0d_aware`. # See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape: if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment