Unverified Commit 4b27cd14 by Yao Wang Committed by GitHub

[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
parent b72dd9d9
......@@ -456,22 +456,20 @@ def get_name(node):
def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _function.Function) else entry.body
if isinstance(mod, IRModule):
mod["main"] = _function.Function([], node)
mod = _transform.InferType()(mod)
entry = mod["main"]
ret = entry.body
else:
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body
def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
return ret
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
......@@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
return channels
def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
......@@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
......
......@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule
from tvm.ir import IRModule, TypeCall
from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
......@@ -24,8 +24,51 @@ from .function import Function
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op
from . import op, transform
def get_tensor_array_shape(expr, dtype, prelude):
"""Get the static shape of a tensor array if it has fixed rank shape.
By design, static ADT tensor in TVM has type name in the format
of static_tensor_dim0_dim1_..._dimN_t.
Parameters
----------
expr : Relay Expr
Input expression.
dtype : str
Data type.
prelude : Prelude
Tensor array prelude
Returns
-------
shape : tuple of (int, Any) or None
The output shape. None if input tensor array
has dynamic shape.
"""
mod = prelude.mod
mod["main"] = Function([], expr)
mod = transform.InferType()(mod)
checked_type = mod["main"].body.checked_type
assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
ta_type_str = checked_type.args[0].func.name_hint
static_ta_ty_start = "static_tensor_{}".format(dtype)
if ta_type_str.startswith(static_ta_ty_start):
shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), '') \
.replace("_t", '')
shape = []
if "scalar" not in shape_str:
for dim_str in shape_str.split("_"):
if dim_str == "?":
shape.append(Any())
else:
shape.append(int(dim_str))
return tuple(shape)
return None
def _get_name_static(canonical, dtype, shape):
"""Get name for static shape tensor array op corresponding
......
......@@ -839,63 +839,75 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
def test_tensor_array_constructor():
def run(dtype_str):
#######################################################################
# TensorArray
# -----------
def test_tensor_array_write_read():
def run(dtype_str, infer_shape, element_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
ta2 = ta1.write(0, t)
np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
in_data = [np_data, np_data]
t1 = tf.constant(np_data, dtype=dtype)
t2 = tf.constant(np_data, dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape,
element_shape=element_shape)
ta2 = ta1.write(0, t1)
ta3 = ta2.write(1, t2)
out = ta3.read(0)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, False, None)
run(dtype, False, tf.TensorShape([None, 2]))
run(dtype, True, None)
def test_tensor_array_scatter():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3,
infer_shape=False, dynamic_size=False)
infer_shape=infer_shape)
ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
# with tf.Graph().as_default():
# dtype = 'float32'
# t = tf.constant([[1.0], [2.0], [3.0]])
# scatter_indices = tf.constant([2, 1, 0])
# gather_indices = tf.constant([1, 2])
# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
# ta2 = ta1.scatter(scatter_indices, t)
# t1 = ta2.gather(gather_indices)
# g = tf.get_default_graph()
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_gather():
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
scatter_indices = tf.constant([2, 1, 0])
gather_indices = tf.constant([1, 2])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
ta2 = ta1.scatter(scatter_indices, t)
t1 = ta2.gather(gather_indices)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, True)
def test_tensor_array_split():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
ta2 = ta1.split(t, split_length)
out0 = ta2.read(0)
out1 = ta2.read(1)
......@@ -906,56 +918,76 @@ def test_tensor_array_split():
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_concat():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False)
infer_shape=infer_shape)
ta2 = ta1.split(t, split_length)
t = ta2.concat()
out = tf.identity(t)
compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_size():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
out = ta1.size()
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_stack():
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
scatter_indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
ta2 = ta1.scatter(scatter_indices, t)
t1 = ta2.stack()
print(t1)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, True)
def test_tensor_array_unstack():
def run(dtype_str, input_shape):
def run(dtype_str, input_shape, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.random.choice([0, 1, 2, 3],
size=input_shape).astype(dtype.name))
ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0])
ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
ta2 = ta1.unstack(t)
out0 = ta2.size()
out1 = ta2.read(0)
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype, (5,))
run(dtype, (5, 5))
run(dtype, (5, 5, 5))
run(dtype, (5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5, 5))
for dtype in ["float32", "int8"]:
run(dtype, (5,), False)
run(dtype, (5, 5), True)
run(dtype, (5, 5, 5), False)
run(dtype, (5, 5, 5, 5), True)
#######################################################################
# ConcatV2
......@@ -3241,6 +3273,16 @@ if __name__ == '__main__':
test_forward_reduce()
test_forward_mean()
# TensorArray
test_tensor_array_write_read()
test_tensor_array_concat()
test_tensor_array_scatter()
test_tensor_array_gather()
test_tensor_array_size()
test_tensor_array_split()
test_tensor_array_stack()
test_tensor_array_unstack()
# General
test_forward_multi_input()
test_forward_multi_output()
......
......@@ -166,12 +166,14 @@ def get_const_tuple(in_tuple):
"""
ret = []
for elem in in_tuple:
if isinstance(elem, tvm.tir.Var):
if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
ret.append(elem)
elif not isinstance(elem, (tvm.tir.IntImm, int)):
elem = tvm.tir.ir_pass.Simplify(elem)
if not isinstance(elem, tvm.tir.IntImm):
ret.append(elem)
else:
ret.append(get_const_int(elem))
else:
ret.append(get_const_int(elem))
return tuple(ret)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment