Commit e9039d04 by Yao Wang Committed by Zhi

Support reshape for dynamic shape in tf converter (#4185)

* Support reshape for dynamic shape in tf converter

* Only allow reshape directly after shape function for symbolic input shape

* Fix lint
parent 9a3d2ec9
......@@ -612,6 +612,16 @@ def _slice():
def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)
# We use reshape_like directly to deal with dynamic shape.
if isinstance(pop_node, tvm.relay.expr.Call):
if "shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])
try:
shape_arg = _get_tuple_param(params, pop_node)
except AttributeError:
......@@ -788,7 +798,18 @@ def _relu6():
def _shape():
def _impl(inputs, attr, params):
return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
is_symbolic_shape = False
for axis in attr['_input_shapes'][inputs[0]]:
if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)):
is_symbolic_shape = True
break
if is_symbolic_shape:
ret = _op.shape_of(inputs[0], dtype='int32')
else:
ret = np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
return ret
return _impl
def _fill():
......
......@@ -543,12 +543,23 @@ def _test_reshape(data, out_shape):
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
def _test_reshape_like(data, shape_like):
""" A special case for reshape. """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
in_shape_like = array_ops.placeholder(shape=shape_like.shape, dtype=data.dtype)
out_shape = array_ops.shape(in_shape_like)
array_ops.reshape(in_data, out_shape)
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
def test_forward_reshape():
_test_reshape(np.arange(6.0), [2, 3])
_test_reshape(np.arange(6), [-1, 2])
_test_reshape(np.arange(6), [3, -1])
_test_reshape(np.arange(6), [-1])
_test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2)))
#######################################################################
# DepthToSpace
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment