Commit 18f8581b by Yao Wang Committed by Yizhi Liu

Fix tf reshape (#4285)

* Fix tf reshape

* Fix test

* Fix pylint

* Fix pylint
parent cff62bdb
......@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
......@@ -613,22 +613,24 @@ def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)
# We use reshape_like directly to deal with dynamic shape.
if isinstance(pop_node, tvm.relay.expr.Call):
if "shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])
try:
shape_arg = _get_tuple_param(params, pop_node)
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
try:
params_new = _infer_value(pop_node, params)
shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
except Exception:
# Deal with symbolic shape case.
# Currently only shape_of can be the direct ancestor.
if not isinstance(pop_node, tvm.relay.expr.Call) or \
"shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])
return AttrCvt(
op_name="reshape",
extras={'newshape': shape_arg},
......
......@@ -551,6 +551,17 @@ def _test_reshape(data, out_shape):
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
def _test_reshape_with_call():
""" relay.expr.Call as shape """
data = np.zeros((6, 4, 2))
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out_shape = tf.constant([1, 2, 3], dtype="int32")
out_shape = tf.multiply(out_shape, 2)
array_ops.reshape(in_data, out_shape)
compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
def _test_reshape_like(data, shape_like):
""" A special case for reshape. """
......@@ -567,6 +578,7 @@ def test_forward_reshape():
_test_reshape(np.arange(6), [-1, 2])
_test_reshape(np.arange(6), [3, -1])
_test_reshape(np.arange(6), [-1])
_test_reshape_with_call()
_test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2)))
#######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment