Commit 036294c9 authored by lixiaoquan, committed by Yizhi Liu

[Relay][TensorFlow] Remove 'input_0d_mismatch' special handling (#3087)

* [Relay][TensorFlow] Remove 'input_0d_mismatch' special handling

* Add more tests.

* Cover the case that strided_slice outputs a scalar
parent 6a956fbc
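
For context, the "scalar" case in the last note: with shrink_axis_mask, StridedSlice drops the sliced axis and TensorFlow returns a 0-d tensor, which Relay does not accept. A minimal sketch, assuming TensorFlow 1.x as used by this frontend:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(2,), name='x')
    # shrink_axis_mask=1 collapses the only axis, leaving a 0-d result
    y = tf.strided_slice(x, [1], [2], [1], shrink_axis_mask=1)
    print(y.shape)  # () -- the converter now reshapes this case to (1,)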
@@ -99,7 +99,6 @@ class AttrCvt(object):
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
         self._ignores.append('_target_layout')
-        self._ignores.append('_input_0d_mismatch')
         # apply custom check
         if self._custom_check:
@@ -458,9 +457,9 @@ def _cast():
 def _expand_dims():
     def _impl(inputs, attr, params):
         dim_input = inputs.pop(1)
-        axis = params[dim_input.name_hint]
-        params.pop(dim_input.name_hint)
-        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
+        axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
+        return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
+                       extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
     return _impl

 def _resize_bilinear():
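
With the special case gone, _expand_dims lowers directly to Relay's expand_dims, which handles 0-d inputs on its own. A hedged sketch of the equivalent Relay expression, assuming the tvm.relay Python API:

    from tvm import relay

    x = relay.var('x', shape=(), dtype='float32')    # 0-d, i.e. a TF scalar
    y = relay.expand_dims(x, axis=0, num_newaxis=1)  # tensor of shape (1,)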
@@ -528,7 +527,7 @@ def _concat():
 def _pack():
     def _impl(inputs, attr, params):
         axis = int(attr["axis"])
-        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
+        inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
         return _op.concatenate(inputs_reshaped, axis)
     return _impl
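
_pack likewise expands every input unconditionally before concatenating. A sketch of the Relay expression this now produces for two 1-d inputs (variable names are illustrative):

    from tvm import relay

    a = relay.var('a', shape=(3,))
    b = relay.var('b', shape=(3,))
    # each input gains a leading axis, then the pieces are concatenated
    stacked = relay.concatenate(
        [relay.expand_dims(t, axis=0, num_newaxis=1) for t in (a, b)],
        axis=0)  # shape (2, 3)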
@@ -820,9 +819,9 @@ def _stridedSlice():
                 pass
             else:
                 final_output.append(out_shape[gather_index])
         # Prevent 0-dim tensors which are not accepted by Relay
         if not final_output:
-            return out
+            final_output.append(1)
         return _op.reshape(out, newshape=tuple(final_output))
     return _impl
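
Read in isolation, the new guard promotes an empty output shape to (1,) so the final reshape never yields a 0-dim tensor. A plain-Python sketch with a hypothetical helper name:

    def _final_shape(final_output):
        # an empty final_output means the TF result is a scalar;
        # Relay rejects 0-dim tensors, so emit shape (1,) instead
        if not final_output:
            final_output.append(1)
        return tuple(final_output)

    assert _final_shape([]) == (1,)
    assert _final_shape([3, 4]) == (3, 4)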
@@ -984,16 +983,6 @@ def _unpack():
                 for split_item in splitted]), len(splitted))
     return _impl

-def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
-    if data in attr['_input_0d_mismatch']:
-        return data if num_newaxis == 1 else \
-            AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
-                    extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr)
-
-    return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
-                   extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr)
-
 def _softmax():
     def _impl(inputs, attr, params):
         return AttrCvt(op_name='softmax',
@@ -1647,7 +1636,6 @@ class GraphProto(object):
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
-        self._outputs_are_0d = {}
         self._input_shapes = {}
         self._loops = {}
         self._branches = {}
@@ -1737,7 +1725,6 @@ class GraphProto(object):
             # Operator name 'Const' is treated as a parameter to build params dict.

             input_shapes = {}
-            input_0d_mismatch = set()
             attr = self._parse_attr(node.attr)

             # Variable converted to Const will not have only value attr
@@ -1753,10 +1740,6 @@ class GraphProto(object):
                 # Will infer shapes if the graph is not frozen with add_shapes=True
                 self._output_shapes[node.name] = [None]

-            self._outputs_are_0d[node.name] = [ \
-                not tshape if isinstance(tshape, list) else False \
-                for tshape in self._output_shapes[node.name]]
-
             if node.op == "Const":
                 # All Const nodes are Param nodes, lets parse
                 self._num_param += 1
@@ -1810,14 +1793,8 @@ class GraphProto(object):
                     input_shape = self._output_shapes[node_name][0]
                     inputs.append(in_sym[0])
                     input_shapes[in_sym[0]] = input_shape
-                    # This means the node is 1d in Relay and 0d in TF.
-                    # See `_expand_dims_0d_aware`.
-                    if node_name in self._outputs_are_0d \
-                            and self._outputs_are_0d[node_name][tensor_slot] and input_shape:
-                        input_0d_mismatch.add(in_sym[0])

                 attr['_input_shapes'] = input_shapes
-                attr['_input_0d_mismatch'] = input_0d_mismatch

                 if node.op in _control_flow_nodes:
                     op = self._convert_control_flow_operator(node, inputs,
......
@@ -580,6 +580,7 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
 def test_forward_stridedslice():
     '''test StridedSlice'''

+    _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
     _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
     _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
     _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
@@ -1475,6 +1476,21 @@ def test_forward_rel_ops():
     _test_forward_rel_op([t1, t2], math_ops.equal)
     _test_forward_rel_op([t1, t2], math_ops.not_equal)

+#######################################################################
+# ExpandDims
+# ----------
+def _test_forward_expand_dims(data, axis):
+    in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
+    out = tf.expand_dims(in1, axis)
+    compare_tf_with_tvm([data], [in1.name], out.name)
+
+def test_forward_expand_dims():
+    _test_forward_expand_dims(np.int32(1), 0)
+    _test_forward_expand_dims(np.array([1]), 0)
+    _test_forward_expand_dims(np.array([1]), -1)
+    _test_forward_expand_dims(np.array([[1], [2]]), 0)
+    _test_forward_expand_dims(np.array([[1], [2]]), 1)
+    _test_forward_expand_dims(np.array([[1], [2]]), -1)
+
 #######################################################################
 # Main
@@ -1509,6 +1525,7 @@ if __name__ == '__main__':
     test_forward_reverse_v2()
     test_forward_pow_exp()
     test_forward_sign()
+    test_forward_expand_dims()

     # Reductions
     test_forward_argminmax()
......