Unverified Commit 4b27cd14 by Yao Wang Committed by GitHub

[Frontend][TensorFlow] Improve TensorFlow Static Shape Tensor Array (#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
parent b72dd9d9
...@@ -456,22 +456,20 @@ def get_name(node): ...@@ -456,22 +456,20 @@ def get_name(node):
def infer_type(node, mod=None): def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
new_mod = IRModule.from_expr(node) if isinstance(mod, IRModule):
if mod is not None: mod["main"] = _function.Function([], node)
new_mod.update(mod) mod = _transform.InferType()(mod)
new_mod = _transform.InferType()(new_mod) entry = mod["main"]
entry = new_mod["main"] ret = entry.body
return entry if isinstance(node, _function.Function) else entry.body else:
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body
def infer_shape(inputs, mod=None): return ret
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
def infer_channels(inputs, transpose=False): def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide """A hack for getting 'channels' or 'units' since caffe2 does not provide
...@@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False): ...@@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
return channels return channels
def infer_shape(inputs, mod=None):
    """Infer the output shape of an intermediate node in the graph.

    Parameters
    ----------
    inputs : relay expression
        Node whose type is inferred; passed unchanged to ``infer_type``.
    mod : optional
        Forwarded to ``infer_type`` as its ``mod`` argument.

    Returns
    -------
    The result of ``get_const_tuple`` on the checked type's ``shape`` when
    the checked type has one; otherwise the checked type itself.
    """
    node_type = infer_type(inputs, mod=mod).checked_type
    if not hasattr(node_type, 'shape'):
        # The return type is not a tensor (for example List): hand back the
        # type itself so the caller can inspect it.
        return node_type
    # Regular operator that outputs tensors.
    return get_const_tuple(node_type.shape)
def infer_value(input_val, params, mod=None): def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a """A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that portion of the relay graph. This is often needed for functions that
...@@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None): ...@@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
return m.get_output(0) return m.get_output(0)
except Exception: except Exception:
if isinstance(mod, IRModule): if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val) mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else: else:
mod = IRModule.from_expr(input_val) mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions.""" """A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule from tvm.ir import IRModule, TypeCall
from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const from .expr import Var, GlobalVar, If, const
...@@ -24,8 +24,51 @@ from .function import Function ...@@ -24,8 +24,51 @@ from .function import Function
from .op.tensor import add, subtract, equal from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op from . import op, transform
def get_tensor_array_shape(expr, dtype, prelude):
    """Get the static shape of a tensor array if it has fixed rank shape.

    By design, static ADT tensor in TVM has type name in the format
    of static_tensor_dim0_dim1_..._dimN_t, so the shape is recovered by
    parsing that name.

    Note that this assigns a new "main" function wrapping ``expr`` to
    ``prelude.mod`` before running type inference.

    Parameters
    ----------
    expr : Relay Expr
        Input expression.

    dtype : str
        Data type.

    prelude : Prelude
        Tensor array prelude

    Returns
    -------
    shape : tuple of (int, Any) or None
        The output shape. None if input tensor array
        has dynamic shape.
    """
    # Type-check expr as the body of "main" to obtain its checked type.
    module = prelude.mod
    module["main"] = Function([], expr)
    module = transform.InferType()(module)
    expr_type = module["main"].body.checked_type
    assert isinstance(expr_type, TypeCall), "Input must be a tensor array."

    # Name of the ADT used as the first type argument of the tensor array.
    type_name = expr_type.args[0].func.name_hint
    static_prefix = "static_tensor_{}".format(dtype)
    if not type_name.startswith(static_prefix):
        return None

    # Drop the "static_tensor_<dtype>_" prefix and the "_t" marker, leaving
    # only the underscore-separated dimensions (or "scalar").
    dims_str = type_name.replace("{}_".format(static_prefix), '')
    dims_str = dims_str.replace("_t", '')
    if "scalar" in dims_str:
        return ()
    # "?" marks a dimension that is not statically known.
    return tuple(Any() if dim_str == "?" else int(dim_str)
                 for dim_str in dims_str.split("_"))
def _get_name_static(canonical, dtype, shape): def _get_name_static(canonical, dtype, shape):
"""Get name for static shape tensor array op corresponding """Get name for static shape tensor array op corresponding
......
...@@ -839,63 +839,75 @@ def test_forward_squeeze(): ...@@ -839,63 +839,75 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
def test_tensor_array_constructor(): #######################################################################
def run(dtype_str): # TensorArray
# -----------
def test_tensor_array_write_read():
def run(dtype_str, infer_shape, element_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str] dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) in_data = [np_data, np_data]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) t1 = tf.constant(np_data, dtype=dtype)
ta2 = ta1.write(0, t) t2 = tf.constant(np_data, dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape,
element_shape=element_shape)
ta2 = ta1.write(0, t1)
ta3 = ta2.write(1, t2) ta3 = ta2.write(1, t2)
out = ta3.read(0) out = ta3.read(0)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm')
for dtype in tf_dtypes.keys():
run(dtype) for dtype in ["float32", "int8"]:
run(dtype, False, None)
run(dtype, False, tf.TensorShape([None, 2]))
run(dtype, True, None)
def test_tensor_array_scatter(): def test_tensor_array_scatter():
def run(dtype_str): def run(dtype_str, infer_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str] dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0]) indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, ta1 = tf.TensorArray(dtype=dtype, size=3,
infer_shape=False, dynamic_size=False) infer_shape=infer_shape)
ta2 = ta1.scatter(indices, t) ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0) out0 = ta2.read(0)
out1 = ta2.read(1) out1 = ta2.read(1)
out2 = ta2.read(2) out2 = ta2.read(2)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm')
for dtype in tf_dtypes.keys(): for dtype in ["float32", "int8"]:
run(dtype) run(dtype, False)
run(dtype, True)
# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
# with tf.Graph().as_default(): def test_tensor_array_gather():
# dtype = 'float32' def run(dtype_str, infer_shape):
# t = tf.constant([[1.0], [2.0], [3.0]]) with tf.Graph().as_default():
# scatter_indices = tf.constant([2, 1, 0]) dtype = tf_dtypes[dtype_str]
# gather_indices = tf.constant([1, 2]) t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False) scatter_indices = tf.constant([2, 1, 0])
# ta2 = ta1.scatter(scatter_indices, t) gather_indices = tf.constant([1, 2])
# t1 = ta2.gather(gather_indices) ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
# g = tf.get_default_graph() ta2 = ta1.scatter(scatter_indices, t)
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug') t1 = ta2.gather(gather_indices)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, True)
def test_tensor_array_split(): def test_tensor_array_split():
def run(dtype_str): def run(dtype_str, infer_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str] dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length) ta2 = ta1.split(t, split_length)
out0 = ta2.read(0) out0 = ta2.read(0)
out1 = ta2.read(1) out1 = ta2.read(1)
...@@ -906,56 +918,76 @@ def test_tensor_array_split(): ...@@ -906,56 +918,76 @@ def test_tensor_array_split():
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
for dtype in tf_dtypes.keys(): for dtype in ["float32", "int8"]:
run(dtype) run(dtype, False)
run(dtype, True)
def test_tensor_array_concat(): def test_tensor_array_concat():
def run(dtype_str): def run(dtype_str, infer_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str] dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False) infer_shape=infer_shape)
ta2 = ta1.split(t, split_length) ta2 = ta1.split(t, split_length)
t = ta2.concat() t = ta2.concat()
out = tf.identity(t) out = tf.identity(t)
compare_tf_with_tvm([], [], ['Identity:0'], mode='debug') compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
for dtype in tf_dtypes.keys(): for dtype in ["float32", "int8"]:
run(dtype) run(dtype, False)
run(dtype, True)
def test_tensor_array_size(): def test_tensor_array_size():
def run(dtype_str): def run(dtype_str, infer_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str] dtype = tf_dtypes[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
out = ta1.size() out = ta1.size()
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
for dtype in tf_dtypes.keys(): for dtype in ["float32", "int8"]:
run(dtype) run(dtype, False)
run(dtype, True)
def test_tensor_array_stack():
    """Compare TensorArray scatter followed by stack between TF and TVM."""
    def run(dtype_str, infer_shape):
        """Build the scatter/stack graph for one dtype and compare outputs."""
        with tf.Graph().as_default():
            dtype = tf_dtypes[dtype_str]
            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
            scatter_indices = tf.constant([2, 1, 0])
            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
            ta2 = ta1.scatter(scatter_indices, t)
            # stack() is called only to add its ops to the graph; the output
            # is looked up by node name below, so the result is not bound.
            ta2.stack()
            compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')

    for dtype in ["float32", "int8"]:
        run(dtype, True)
def test_tensor_array_unstack(): def test_tensor_array_unstack():
def run(dtype_str, input_shape): def run(dtype_str, input_shape, infer_shape):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str] dtype = tf_dtypes[dtype_str]
t = tf.constant(np.random.choice([0, 1, 2, 3], t = tf.constant(np.random.choice([0, 1, 2, 3],
size=input_shape).astype(dtype.name)) size=input_shape).astype(dtype.name))
ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0]) ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
ta2 = ta1.unstack(t) ta2 = ta1.unstack(t)
out0 = ta2.size() out0 = ta2.size()
out1 = ta2.read(0) out1 = ta2.read(0)
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
for dtype in tf_dtypes.keys(): for dtype in ["float32", "int8"]:
run(dtype, (5,)) run(dtype, (5,), False)
run(dtype, (5, 5)) run(dtype, (5, 5), True)
run(dtype, (5, 5, 5)) run(dtype, (5, 5, 5), False)
run(dtype, (5, 5, 5, 5)) run(dtype, (5, 5, 5, 5), True)
run(dtype, (5, 5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5, 5))
####################################################################### #######################################################################
# ConcatV2 # ConcatV2
...@@ -3241,6 +3273,16 @@ if __name__ == '__main__': ...@@ -3241,6 +3273,16 @@ if __name__ == '__main__':
test_forward_reduce() test_forward_reduce()
test_forward_mean() test_forward_mean()
# TensorArray
test_tensor_array_write_read()
test_tensor_array_concat()
test_tensor_array_scatter()
test_tensor_array_gather()
test_tensor_array_size()
test_tensor_array_split()
test_tensor_array_stack()
test_tensor_array_unstack()
# General # General
test_forward_multi_input() test_forward_multi_input()
test_forward_multi_output() test_forward_multi_output()
......
...@@ -166,12 +166,14 @@ def get_const_tuple(in_tuple): ...@@ -166,12 +166,14 @@ def get_const_tuple(in_tuple):
""" """
ret = [] ret = []
for elem in in_tuple: for elem in in_tuple:
if isinstance(elem, tvm.tir.Var): if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
ret.append(elem) ret.append(elem)
elif not isinstance(elem, (tvm.tir.IntImm, int)): elif not isinstance(elem, (tvm.tir.IntImm, int)):
elem = tvm.tir.ir_pass.Simplify(elem) elem = tvm.tir.ir_pass.Simplify(elem)
if not isinstance(elem, tvm.tir.IntImm): if not isinstance(elem, tvm.tir.IntImm):
ret.append(elem) ret.append(elem)
else:
ret.append(get_const_int(elem))
else: else:
ret.append(get_const_int(elem)) ret.append(get_const_int(elem))
return tuple(ret) return tuple(ret)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment