Commit f6c3f997 by Alexey Romanov, committed by Siva

[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242)

parent 6d1f4c0b
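
The change is mechanical but wide: the frontend's `_input_shapes` attribute (and the `input_shapes` dict that `GraphProto` populates) previously mapped each input symbol to a one-element list wrapping the shape, so every consumer had to peel it off with an extra `[0]`. A minimal sketch of the two representations, for orientation only (the key name and shape values below are hypothetical, not TVM code):

# Illustration only; 'input_0' and the shape are made-up values.
shapes_before = {'input_0': [[1, 224, 224, 3]]}  # shape wrapped in a 1-element list
shapes_after = {'input_0': [1, 224, 224, 3]}     # shape stored directly

# The old code peeled off the wrapper with [0]; the new code looks up the shape directly.
assert shapes_before['input_0'][0] == shapes_after['input_0']
assert len(shapes_after['input_0']) == 4         # rank now falls out of len() directly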
@@ -120,7 +120,7 @@ def _pooling(name):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         flip_layout = False
-        input_shape = attr['_input_shapes'][inputs[0]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -132,7 +132,7 @@ def _pooling(name):
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
-            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            tmp_shape = attr['_input_shapes'][inputs[0]]
             input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
             attr['data_format'] = "NCHW"
@@ -185,13 +185,13 @@ def _conv(opname):
         # NCHW Layout require weights transpose
         if attr['data_format'] == 'NCHW':
-            tmp_shape = attr['_input_shapes'][inputs[1]][0]
+            tmp_shape = attr['_input_shapes'][inputs[1]]
             tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
             inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
-            attr['_input_shapes'][inputs[1]] = [tmp_shape]
+            attr['_input_shapes'][inputs[1]] = tmp_shape
-        input_shape = attr['_input_shapes'][inputs[0]][0]
-        weights_shape = attr['_input_shapes'][inputs[1]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
+        weights_shape = attr['_input_shapes'][inputs[1]]
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
@@ -484,7 +484,7 @@ def _relu6():
 def _shape():
     def _impl(inputs, attr, params):
-        return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
+        return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
     return _impl
 def _fill():
@@ -565,7 +565,7 @@ def _stridedSlice():
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)
         def _transform_mask(stride_dim, ellipsis_mask):
@@ -596,7 +596,7 @@ def _stridedSlice():
                                     + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         fshape_indices.append(final_index)
                         final_index += 1
@@ -606,19 +606,19 @@ def _stridedSlice():
                 if final_index == len(m_begin):
                     break
                 if mask & begin_mask:
-                    m_begin[final_index] = data_shape[0][final_index] \
+                    m_begin[final_index] = data_shape[final_index] \
                                            if stride[index] < 0 else 0
                 elif begin[index]:
                     m_begin[final_index] = begin[index]
                 if mask & end_mask:
                     m_end[final_index] = 0 if stride[index] < 0 \
-                                           else data_shape[0][final_index]
+                                           else data_shape[final_index]
                 elif end[index]:
                     m_end[final_index] = end[index]
                 m_stride[final_index] = stride[index]
                 if mask & shrink_axis_mask:
                     #Tensorflow make axis with shrink_axis_mask as dimension 1
-                    m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                    m_begin[final_index] = data_shape[final_index] + begin[index] \
                                            if begin[index] < 0 else begin[index]
                     m_end[final_index] = begin[index] + 1
                     m_stride[final_index] = 1
@@ -684,8 +684,8 @@ def _LSTMBlockCell():
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4
         in_data = _sym.reshape(in_data,
@@ -741,11 +741,10 @@ def _transpose():
 def _rank():
     def _impl(inputs, attr, params):
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        assert len(inputs) == 1
+        input_shape = attr['_input_shapes'][inputs[0]]
         name = attr["_node_name"]
-        params[name] = tvm.nd.array([len(input_shapes[0])])
+        params[name] = tvm.nd.array([len(input_shape)])
         return _sym.Variable(name=name, shape=params[name].shape)
     return _impl
@@ -829,7 +828,7 @@ def _unpack():
     def _impl(inputs, attr, params):
         input_node = inputs[0]
         axis = attr['axis']
-        input_shape = attr['_input_shapes'][input_node][0]
+        input_shape = attr['_input_shapes'][input_node]
         axis_length = input_shape[axis]
         if axis_length < 0:
             raise TypeError("Unstack with unknown axis length")
@@ -1018,8 +1017,8 @@ class RecurrentNetworks(object):
             """LSTM cell warapper to prepare the inputs"""
             input_shape = attr['_input_shapes'][inputs[0]]
             weight_shape = attr['_input_shapes'][inputs[3]]
-            batch_size = input_shape[0][0]
-            num_hidden = weight_shape[0][1] // 4
+            batch_size = input_shape[0]
+            num_hidden = weight_shape[1] // 4
             if layer == 0:
                 #Create initial states placeholder in case of first layer
@@ -1240,7 +1239,7 @@ class GraphProto(object):
                 tensor_slot = 0
                 input_shape = self._output_shapes[node_name][0]
                 inputs.append(in_sym)
-                input_shapes[in_sym] = [input_shape]
+                input_shapes[in_sym] = input_shape
                 # This means the node is 1d in NNVM and 0d in TF.
                 # See `_expand_dims_0d_aware`.
                 if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
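
The final hunk is the producer side: `GraphProto` is where `input_shapes` gets populated, so dropping the list wrapper there is what lets every consumer hunk above drop its `[0]`. As a quick sanity check of the consumer pattern, here is the `_rank` computation reduced to plain dict operations (a simplified stand-in with a hypothetical key, not the real `attr` plumbing):

# Simplified stand-in for the _rank consumer after this commit; 'input_0' is hypothetical.
attr = {'_input_shapes': {'input_0': [32, 10]}}
input_shape = attr['_input_shapes']['input_0']
rank = len(input_shape)   # previously len(input_shapes[0]) on the wrapped form
assert rank == 2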