Commit c1157ecf by Siva Committed by Tianqi Chen

[FRONTEND][TENSORFLOW] Enable strided_slice with fix. (#2002)

parent 77869913
......@@ -569,6 +569,7 @@ def _stridedSlice():
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
fshape_indices = []
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
......@@ -593,7 +594,10 @@ def _stridedSlice():
m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
elif mask &new_axis_mask:
fshape_indices.append(-1)
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
......@@ -614,28 +618,30 @@ def _stridedSlice():
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
fshape_indices.append(final_index)
final_index += 1
return m_begin, m_end, m_stride
return m_begin, m_end, m_stride, fshape_indices
fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride = _transform_mask(stride_dim, ellipsis_mask)
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
out_shape = _infer_out_shapes(out, params)[0]
if not fshape_indices:
fshape_indices = range(len(out_shape))
#Create final output shape.
final_output = []
out_index = 0
index = 0
while out_index != len(out_shape):
#axis with shrink_axis_mask dimension=1 and it is ignored.
mask = 1 << index
if (new_axis_mask & mask) and not ellipsis_mask & mask:
for gather_index in fshape_indices:
if gather_index == -1:
final_output.append(1)
elif (not mask & shrink_axis_mask) or index >= stride_dim:
#Shrink is considered till stride_dim
final_output.append(out_shape[out_index])
out_index += 1
index += 1
elif gather_index == -2:
pass
else:
final_output.append(out_shape[gather_index])
return _sym.reshape(out, shape=tuple(final_output))
return _impl
......
......@@ -435,11 +435,15 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
def test_forward_stridedslice():
'''test StridedSlice'''
return
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4)
_test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
......@@ -1056,7 +1060,7 @@ if __name__ == '__main__':
test_forward_resize_bilinear()
test_forward_pad()
test_forward_gather()
#test_forward_stridedslice()
test_forward_stridedslice()
# Activations
test_forward_sigmoid()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment