Commit c1157ecf by Siva Committed by Tianqi Chen

[FRONTEND][TENSORFLOW] Enable strided_slice with fix. (#2002)

parent 77869913
...@@ -569,6 +569,7 @@ def _stridedSlice(): ...@@ -569,6 +569,7 @@ def _stridedSlice():
m_begin = [0] * data_dim m_begin = [0] * data_dim
m_end = [0] * data_dim m_end = [0] * data_dim
m_stride = [0] * data_dim m_stride = [0] * data_dim
fshape_indices = []
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask. #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
ellipsis_seen = False ellipsis_seen = False
new_axes_after_ellipsis = 0 new_axes_after_ellipsis = 0
...@@ -593,7 +594,10 @@ def _stridedSlice(): ...@@ -593,7 +594,10 @@ def _stridedSlice():
m_begin[final_index] = 0 m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index] m_end[final_index] = data_shape[0][final_index]
m_stride[final_index] = 1 m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1 final_index += 1
elif mask &new_axis_mask:
fshape_indices.append(-1)
elif not mask & new_axis_mask: elif not mask & new_axis_mask:
if final_index == len(m_begin): if final_index == len(m_begin):
break break
...@@ -614,28 +618,30 @@ def _stridedSlice(): ...@@ -614,28 +618,30 @@ def _stridedSlice():
if begin[index] < 0 else begin[index] if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1 m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1 m_stride[final_index] = 1
fshape_indices.append(-2)
else:
fshape_indices.append(final_index)
final_index += 1 final_index += 1
return m_begin, m_end, m_stride return m_begin, m_end, m_stride, fshape_indices
fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride = _transform_mask(stride_dim, ellipsis_mask) begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride) out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
out_shape = _infer_out_shapes(out, params)[0] out_shape = _infer_out_shapes(out, params)[0]
if not fshape_indices:
fshape_indices = range(len(out_shape))
#Create final output shape. #Create final output shape.
final_output = [] final_output = []
out_index = 0 for gather_index in fshape_indices:
index = 0 if gather_index == -1:
while out_index != len(out_shape):
#axis with shrink_axis_mask dimension=1 and it is ignored.
mask = 1 << index
if (new_axis_mask & mask) and not ellipsis_mask & mask:
final_output.append(1) final_output.append(1)
elif (not mask & shrink_axis_mask) or index >= stride_dim: elif gather_index == -2:
#Shrink is considered till stride_dim pass
final_output.append(out_shape[out_index]) else:
out_index += 1 final_output.append(out_shape[gather_index])
index += 1
return _sym.reshape(out, shape=tuple(final_output)) return _sym.reshape(out, shape=tuple(final_output))
return _impl return _impl
......
...@@ -435,11 +435,15 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype, ...@@ -435,11 +435,15 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
def test_forward_stridedslice(): def test_forward_stridedslice():
'''test StridedSlice''' '''test StridedSlice'''
return
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8) _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5) _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4) _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4)
_test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
...@@ -1056,7 +1060,7 @@ if __name__ == '__main__': ...@@ -1056,7 +1060,7 @@ if __name__ == '__main__':
test_forward_resize_bilinear() test_forward_resize_bilinear()
test_forward_pad() test_forward_pad()
test_forward_gather() test_forward_gather()
#test_forward_stridedslice() test_forward_stridedslice()
# Activations # Activations
test_forward_sigmoid() test_forward_sigmoid()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment