Commit 10b77ef3 by Wei Chen Committed by Haichen Shen

[TF][Relay][Op] Pass module when infer shape (#4287)

* [TF][Relay][Op] Pass module when infer shape

* Fix lint

* Improve style

* Add test
parent f823c577
...@@ -451,20 +451,24 @@ def get_name(node): ...@@ -451,20 +451,24 @@ def get_name(node):
return name return name
def infer_type(node): def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node) new_mod = _module.Module.from_expr(node)
mod = _transform.InferType()(mod) if mod is not None:
entry = mod["main"] new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _expr.Function) else entry.body
def infer_shape(inputs, mod=None):
def infer_shape(inputs): """A method to get the output type of an intermediate node in the graph."""
"""A method to get the output shape of an intermediate node in the graph.""" out_type = infer_type(inputs, mod=mod)
out_type = infer_type(inputs) checked_type = out_type.checked_type
out_shapes = get_const_tuple(out_type.checked_type.shape) if hasattr(checked_type, 'shape'):
return out_shapes # Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
def infer_channels(inputs, transpose=False): def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide """A hack for getting 'channels' or 'units' since caffe2 does not provide
......
...@@ -90,6 +90,12 @@ def _get_list_param(params, input_node): ...@@ -90,6 +90,12 @@ def _get_list_param(params, input_node):
def _get_tuple_param(params, input_node): def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node)) return tuple(_get_param(params, input_node))
def _need_module_for_shape_inference(op):
return op in ['StridedSlice']
def _need_prelude_for_shape_inference(op):
return "TensorArray" in op
def _rsqrt(): def _rsqrt():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name)) inputs.append(tvm.relay.const(-0.5, attr['T'].name))
...@@ -893,7 +899,7 @@ def _gather_nd(): ...@@ -893,7 +899,7 @@ def _gather_nd():
return _impl return _impl
def _stridedSlice(): def _stridedSlice():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
"""Strided Slice. """Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
...@@ -976,7 +982,7 @@ def _stridedSlice(): ...@@ -976,7 +982,7 @@ def _stridedSlice():
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out) out_shape = _infer_shape(out, mod=mod)
if not fshape_indices: if not fshape_indices:
fshape_indices = range(len(out_shape)) fshape_indices = range(len(out_shape))
...@@ -2169,7 +2175,8 @@ class GraphProto(object): ...@@ -2169,7 +2175,8 @@ class GraphProto(object):
# Infer shapes even without specifying "add_shapes=True" # Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]: if output_shapes == [None]:
out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]] out_shapes = [_infer_shape(node_item, self._mod)
for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape: if self._output_shapes[node.name] and shape and node.name in shape:
...@@ -2179,7 +2186,7 @@ class GraphProto(object): ...@@ -2179,7 +2186,7 @@ class GraphProto(object):
node_output = self._nodes[node.name] node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0] if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]): or -1 in self._output_shapes[node.name][0]):
out_shapes = [_infer_shape(node_item) for node_item in node_output] out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
self._output_shapes[node.name] = out_shapes self._output_shapes[node.name] = out_shapes
out = [] out = []
...@@ -2470,8 +2477,10 @@ class GraphProto(object): ...@@ -2470,8 +2477,10 @@ class GraphProto(object):
if op_name in identity_list: if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs) sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map: elif op_name in convert_map:
if 'TensorArray' in op_name: if _need_prelude_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
elif _need_module_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
else: else:
sym = convert_map[op_name](inputs, attrs, self._params) sym = convert_map[op_name](inputs, attrs, self._params)
......
...@@ -746,7 +746,8 @@ def test_tensor_array_concat(): ...@@ -746,7 +746,8 @@ def test_tensor_array_concat():
infer_shape=False, dynamic_size=False) infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length) ta2 = ta1.split(t, split_length)
t = ta2.concat() t = ta2.concat()
compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug') out = tf.identity(t)
compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
for dtype in tf_dtypes.keys(): for dtype in tf_dtypes.keys():
run(dtype) run(dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment