Commit 10b77ef3 by Wei Chen Committed by Haichen Shen

[TF][Relay][Op] Pass module when infer shape (#4287)

* [TF][Relay][Op] Pass module when infer shape

* Fix lint

* Improve style

* Add test
parent f823c577
......@@ -451,20 +451,24 @@ def get_name(node):
return name
def infer_type(node):
def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
mod = _transform.InferType()(mod)
entry = mod["main"]
new_mod = _module.Module.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
def infer_shape(inputs):
"""A method to get the output shape of an intermediate node in the graph."""
out_type = infer_type(inputs)
out_shapes = get_const_tuple(out_type.checked_type.shape)
return out_shapes
def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
......
......@@ -90,6 +90,12 @@ def _get_list_param(params, input_node):
def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node))
def _need_module_for_shape_inference(op):
return op in ['StridedSlice']
def _need_prelude_for_shape_inference(op):
return "TensorArray" in op
def _rsqrt():
def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
......@@ -893,7 +899,7 @@ def _gather_nd():
return _impl
def _stridedSlice():
def _impl(inputs, attr, params):
def _impl(inputs, attr, params, mod):
"""Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
......@@ -976,7 +982,7 @@ def _stridedSlice():
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out)
out_shape = _infer_shape(out, mod=mod)
if not fshape_indices:
fshape_indices = range(len(out_shape))
......@@ -2169,7 +2175,8 @@ class GraphProto(object):
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
out_shapes = [_infer_shape(node_item, self._mod)
for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape:
......@@ -2179,7 +2186,7 @@ class GraphProto(object):
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
out_shapes = [_infer_shape(node_item) for node_item in node_output]
out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
self._output_shapes[node.name] = out_shapes
out = []
......@@ -2470,8 +2477,10 @@ class GraphProto(object):
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
if 'TensorArray' in op_name:
if _need_prelude_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
elif _need_module_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
else:
sym = convert_map[op_name](inputs, attrs, self._params)
......
......@@ -746,7 +746,8 @@ def test_tensor_array_concat():
infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length)
t = ta2.concat()
compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
out = tf.identity(t)
compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment