Commit 0bbdad4b by Siyuan Li Committed by Yao Wang

[Relay][Frontend][TF] Fix slice when begin or size is not Const (#4372)

* fix slice bug when input is param

* use _infer_value rather than _infer_value_simulated
parent 786d7998
...@@ -626,8 +626,14 @@ def _tile(): ...@@ -626,8 +626,14 @@ def _tile():
def _slice(): def _slice():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
begin = _get_list_param(params, inputs[1]) try:
size = _get_list_param(params, inputs[2]) begin = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
begin = _infer_value(inputs[1], params).asnumpy().tolist()[0]
try:
size = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
size = _infer_value(inputs[2], params).asnumpy().tolist()[0]
data_shape = attr['_input_shapes'][inputs[0]] data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape) data_dim = len(data_shape)
end = size end = size
......
...@@ -2188,6 +2188,20 @@ def test_forward_transpose(): ...@@ -2188,6 +2188,20 @@ def test_forward_transpose():
_test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2)) _test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2))
def _test_forward_slice_operation_input(input_value, begin_value, size_value):
input_data = np.array(input_value, dtype=np.float32)
with tf.Graph().as_default():
input_tensor = tf.placeholder(
shape=input_data.shape, dtype=input_data.dtype, name="input")
begin_tensor = tf.expand_dims(begin_value, axis=0)
size_tensor = tf.expand_dims(size_value, axis=0)
slice_tensor = tf.slice(input_tensor, begin_tensor, size_tensor, name='slice_output')
compare_tf_with_tvm([input_data], ['input:0'], 'slice_output:0')
def test_forward_slice():
_test_forward_slice_operation_input([1, 1], 0, 2)
def test_forward_ceil(): def test_forward_ceil():
ishape = (1, 3, 10, 10) ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(size=ishape).astype(np.float32) inp_array = np.random.uniform(size=ishape).astype(np.float32)
...@@ -2760,8 +2774,8 @@ def test_forward_add_n(): ...@@ -2760,8 +2774,8 @@ def test_forward_add_n():
# Main # Main
# ---- # ----
if __name__ == '__main__': if __name__ == '__main__':
# Transforms # Transforms
test_forward_slice()
test_forward_transpose() test_forward_transpose()
test_forward_reshape() test_forward_reshape()
test_forward_depthtospace() test_forward_depthtospace()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment