Unverified commit 4b27cd14 by Yao Wang, committed by GitHub

[Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
parent b72dd9d9
@@ -456,22 +456,20 @@ def get_name(node):
 def infer_type(node, mod=None):
     """A method to infer the type of an intermediate node in the relay graph."""
-    new_mod = IRModule.from_expr(node)
-    if mod is not None:
-        new_mod.update(mod)
-    new_mod = _transform.InferType()(new_mod)
-    entry = new_mod["main"]
-    return entry if isinstance(node, _function.Function) else entry.body
-
-
-def infer_shape(inputs, mod=None):
-    """A method to get the output type of an intermediate node in the graph."""
-    out_type = infer_type(inputs, mod=mod)
-    checked_type = out_type.checked_type
-    if hasattr(checked_type, 'shape'):
-        # Regular operator that outputs tensors
-        return get_const_tuple(out_type.checked_type.shape)
-    # The return type is not a tensor, for example List
-    return checked_type
+    if isinstance(mod, IRModule):
+        mod["main"] = _function.Function([], node)
+        mod = _transform.InferType()(mod)
+        entry = mod["main"]
+        ret = entry.body
+    else:
+        new_mod = IRModule.from_expr(node)
+        if mod is not None:
+            new_mod.update(mod)
+        new_mod = _transform.InferType()(new_mod)
+        entry = new_mod["main"]
+        ret = entry if isinstance(node, _function.Function) else entry.body
+    return ret


 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
@@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
     return channels


+def infer_shape(inputs, mod=None):
+    """A method to get the output type of an intermediate node in the graph."""
+    out_type = infer_type(inputs, mod=mod)
+    checked_type = out_type.checked_type
+    if hasattr(checked_type, 'shape'):
+        # Regular operator that outputs tensors
+        return get_const_tuple(checked_type.shape)
+    # The return type is not a tensor, for example List
+    return checked_type
+
+
 def infer_value(input_val, params, mod=None):
     """A hack for getting the value of an expression by evaluating a
     portion of the relay graph. This is often needed for functions that
@@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
             return m.get_output(0)
     except Exception:
         if isinstance(mod, IRModule):
-            mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
+            mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
         else:
             mod = IRModule.from_expr(input_val)
         exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
......
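For reference, a minimal usage sketch of the helpers changed above (illustrative only, not part of the diff): infer_type now reuses an existing IRModule when one is passed, which is how the TensorFlow frontend calls it with prelude.mod, and infer_shape wraps it to return a tuple of dimensions.

    # Sketch, assuming these helpers are importable from the frontend common module.
    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.frontend.common import infer_type, infer_shape

    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
    # With mod=None an IRModule is built from the expression itself.
    print(infer_shape(relay.nn.relu(x)))            # (1, 3, 224, 224)

    # With an existing IRModule, the expression is installed as mod["main"]
    # and type inference runs against that module (the new code path).
    mod = tvm.IRModule()
    c = relay.const(np.ones((2, 2), dtype="float32"))
    print(infer_type(c + c, mod=mod).checked_type)  # Tensor[(2, 2), float32]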
@@ -26,13 +26,14 @@ import numpy as np
 import tvm

 from tvm.ir import IRModule
-from tvm.relay.prelude import Prelude
+from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape
 from tvm.ir import structural_hash as s_hash

 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
+from ..ty import Any
 from ..expr_functor import ExprMutator, ExprVisitor
 from .common import AttrCvt, get_relay_op
 from .common import infer_type as _infer_type
@@ -259,8 +260,6 @@ def _conv(opname):
     if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
         # transform to NCHW for TVM backend compatible and set 'flip_layout'
         # to have output flip back to NHWC
-        tmp_shape = _infer_shape(inputs[2], mod)
-        tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
         inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
         attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
             attr['strides'][3], attr['strides'][1], attr['strides'][2]
@@ -789,25 +788,152 @@ def _pack():
 def _tensor_array():
     def _impl(inputs, attr, params, prelude):
+        try:
+            from tensorflow.python.framework import tensor_util
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import tensorflow which is required {}".format(e))
+
         dtype_str = attr.get('dtype').name
-        tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
-        return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
+        assert not attr["dynamic_size"], "Dynamic size tensor array is " \
+                                         "not supported in TVM yet."
+
+        raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape'])
+        elem_shape = []
+        for dim in raw_elem_shape:
+            if dim < 0:
+                elem_shape.append(Any())
+            else:
+                elem_shape.append(dim)
+
+        if elem_shape:
+            # Element shape is specified.
+            # Directly create a static tensor array with the given shape.
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           elem_shape)
+            static_tensor_array_ops.register()
+            tensor_array_constructor = prelude.get_var_static('tensor_array',
+                                                              dtype_str,
+                                                              elem_shape)
+            tensor_array = tensor_array_constructor(inputs[0])
+            _static_tensor_array_map[tensor_array] = tensor_array
+        elif attr['identical_element_shapes']:
+            # identical_element_shapes is set but the element shape is not given.
+            # We create a static tensor array with a dummy shape and record it in
+            # _static_tensor_array_map. Later, when creating other tensor array ops
+            # which use this tensor array, we reconstruct it with the actual shape.
+            dummy_shape = ()
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           dummy_shape)
+            static_tensor_array_ops.register()
+            tensor_array_constructor = prelude.get_var_static('tensor_array',
+                                                              dtype_str,
+                                                              dummy_shape)
+            tensor_array = tensor_array_constructor(inputs[0])
+            _static_tensor_array_map[tensor_array] = None
+        else:
+            tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
+            tensor_array = tensor_array_constructor(inputs[0])
+        return tensor_array
     return _impl
 def _tensor_array_scatter():
     def _impl(inputs, attr, params, prelude):
         dtype_str = attr.get('T').name
-        values_rank = len(inputs[2].type_annotation.shape)
-        unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
-        unstack_function = prelude.get_var(unstack_name, dtype_str)
-        values = unstack_function(inputs[2])
-        tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
-        return tensor_array_scatter_func(inputs[0], inputs[1], values)
+        input_ta = inputs[0]
+        input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+        values_shape = _infer_shape(inputs[2], prelude.mod)
+        input_t_shape = values_shape[1:]
+        indices_shape = _infer_shape(inputs[1], prelude.mod)
+
+        if input_shape is None:
+            values_rank = len(values_shape)
+            unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
+            unstack_function = prelude.get_var(unstack_name, dtype_str)
+            values = unstack_function(inputs[2])
+            tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_t_shape)
+            static_tensor_array_ops.register()
+
+            # For the scatter operation, it is possible to write to a newly created
+            # tensor array. We need to check and recreate its input tensor array.
+            if input_ta in _static_tensor_array_map and \
+                    _static_tensor_array_map[input_ta] is None:
+                ta_constructor = prelude.get_var_static('tensor_array',
+                                                        dtype_str,
+                                                        input_t_shape)
+                new_ta = ta_constructor(input_ta.args[0])
+                _static_tensor_array_map[input_ta] = new_ta
+                input_ta = new_ta
+
+            # Register static indices shape
+            if isinstance(indices_shape[0], int):
+                static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
+            tensor_array_scatter_func = prelude.get_var_static('tensor_array_scatter',
+                                                               dtype_str,
+                                                               input_t_shape)
+
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           values_shape)
+            static_tensor_array_ops.register()
+            unstack_function = prelude.get_var_static('tensor_array_unstack',
+                                                      dtype_str,
+                                                      values_shape)
+            values = unstack_function(inputs[2])
+        ret = tensor_array_scatter_func(input_ta, inputs[1], values)
+        return ret
     return _impl
 def _tensor_array_gather():
     def _impl(inputs, attr, params, prelude):
-        return prelude.tensor_array_gather(inputs[2], inputs[1])
+        dtype_str = attr.get('dtype').name
+        input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
+        indices_shape = _infer_shape(inputs[1], prelude.mod)
+
+        if input_shape is None:
+            gather_func = prelude.get_var('tensor_array_gather', dtype_str)
+            out = gather_func(inputs[2], inputs[1])
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_shape)
+            static_tensor_array_ops.register()
+            if not isinstance(indices_shape[0], int):
+                gather_function = prelude.get_var_static('tensor_array_gather',
+                                                         dtype_str,
+                                                         input_shape)
+                out_tensor_t = gather_function(inputs[2], inputs[1])
+
+                # Output shape is (indices_shape[0],) + input_shape
+                static_tensor_array_ops.define_tensor_get_data((indices_shape[0],) + input_shape)
+                get_data_func = prelude.get_var_static('tensor_get_data',
+                                                       dtype_str,
+                                                       input_shape)
+                out = get_data_func(out_tensor_t)
+            else:
+                # For fixed length indices, directly generate static shape output
+                read_func = prelude.get_var_static('tensor_array_read',
+                                                   dtype_str,
+                                                   input_shape)
+                static_tensor_array_ops.define_tensor_get_data(input_shape)
+                get_data_func = prelude.get_var_static('tensor_get_data',
+                                                       dtype_str,
+                                                       input_shape)
+                tensor_list = []
+                for i in range(indices_shape[0]):
+                    index = _op.take(inputs[1], tvm.relay.const(i))
+                    out_tensor = get_data_func(read_func(inputs[2], index))
+                    tensor_list.append(_op.expand_dims(out_tensor, axis=0))
+
+                out = _op.concatenate(tensor_list, axis=0)
+        return out
     return _impl


 def _tensor_array_size():
@@ -817,37 +943,163 @@ def _tensor_array_size():
 def _tensor_array_write():
     def _impl(inputs, attr, params, prelude):
-        input_rank = len(inputs[2].type_annotation.shape)
-        dtype = attr.get('T').name
-        tensor_name = 'tensor{}'.format(input_rank)
-        tensor_func = prelude.get_var(tensor_name, dtype)
-        v = tensor_func(inputs[2])
-        write_func = prelude.get_var('tensor_array_write', dtype)
-        return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
+        dtype_str = attr.get('T').name
+        input_ta = inputs[3]
+        input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+        input_t_shape = _infer_shape(inputs[2], prelude.mod)
+        input_rank = len(input_t_shape)
+
+        if input_ta_shape is None:
+            tensor_name = 'tensor{}'.format(input_rank)
+            tensor_func = prelude.get_var(tensor_name, dtype_str)
+            v = tensor_func(inputs[2])
+            write_func = prelude.get_var('tensor_array_write', dtype_str)
+        else:
+            # For the write operation, it is possible to write to a newly created
+            # tensor array. We need to check and recreate its input tensor array.
+            if input_ta in _static_tensor_array_map and \
+                    _static_tensor_array_map[input_ta] is None:
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_t_shape)
+                static_tensor_array_ops.register()
+                ta_constructor = prelude.get_var_static('tensor_array',
+                                                        dtype_str,
+                                                        input_t_shape)
+                new_ta = ta_constructor(input_ta.args[0])
+                _static_tensor_array_map[input_ta] = new_ta
+                input_ta = new_ta
+                input_ta_shape = input_t_shape
+            else:
+                input_ta_rank = len(input_ta_shape)
+                assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
+                    format(input_ta_rank, input_rank)
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_ta_shape)
+                static_tensor_array_ops.register()
+
+            tensor_func = prelude.get_var_static("tensor_constructor",
+                                                 dtype_str,
+                                                 input_ta_shape)
+            v = tensor_func(inputs[2])
+            write_func = prelude.get_var_static('tensor_array_write',
+                                                dtype_str,
+                                                input_ta_shape)
+
+        return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v)
     return _impl
 def _tensor_array_read():
     def _impl(inputs, attr, params, prelude):
-        read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
-        return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+        dtype_str = attr['dtype'].name
+        input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
+
+        if input_shape is None:
+            read_func = prelude.get_var('tensor_array_read', dtype_str)
+            out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_shape)
+            static_tensor_array_ops.register()
+            static_tensor_array_ops.define_tensor_get_data(input_shape)
+            read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape)
+            out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
+            get_data_func = prelude.get_var_static('tensor_get_data',
+                                                   dtype_str,
+                                                   input_shape)
+            out = get_data_func(out_tensor)
+        return out
     return _impl
 def _tensor_array_split():
     def _impl(inputs, attr, params, prelude):
-        input_rank = len(inputs[1].type_annotation.shape)
         dtype_str = attr.get('T').name
-        v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
+        input_ta = inputs[0]
+        input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+        input_t_shape = _infer_shape(inputs[1], prelude.mod)
+        input_rank = len(input_t_shape)
         lengths = _op.cast(inputs[2], 'int32')
-        split_var = prelude.get_var('tensor_array_split', dtype_str)
-        return split_var(inputs[0], v, lengths)
+        lengths_shape = _infer_shape(lengths, prelude.mod)
+        value_shape = _infer_shape(inputs[1], prelude.mod)
+
+        if input_ta_shape is None:
+            v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
+            split_func = prelude.get_var('tensor_array_split', dtype_str)
+        else:
+            # For the split operation, it is possible to write to a newly created
+            # tensor array. We need to check and recreate its input tensor array.
+            if input_ta in _static_tensor_array_map and \
+                    _static_tensor_array_map[input_ta] is None:
+                input_ta_shape = (Any(),) + input_t_shape[1:]
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_ta_shape)
+                static_tensor_array_ops.register()
+                ta_constructor = prelude.get_var_static('tensor_array',
+                                                        dtype_str,
+                                                        input_ta_shape)
+                new_ta = ta_constructor(input_ta.args[0])
+                _static_tensor_array_map[input_ta] = new_ta
+                input_ta = new_ta
+            else:
+                input_ta_rank = len(input_ta_shape)
+                assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
+                    format(input_ta_rank, input_rank)
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               input_ta_shape)
+                static_tensor_array_ops.register()
+
+            # Check static value/indices shape
+            if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int):
+                static_tensor_array_ops.define_tensor_array_split(value_shape,
+                                                                  lengths_shape,
+                                                                  True)
+
+            tensor_func_name = prelude.get_name_static("tensor_constructor",
+                                                       dtype_str,
+                                                       value_shape)
+            if not hasattr(prelude, tensor_func_name):
+                static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                               dtype_str,
+                                                               value_shape)
+                static_tensor_array_ops.register()
+            tensor_func = prelude.get_var_static("tensor_constructor",
+                                                 dtype_str,
+                                                 value_shape)
+            v = tensor_func(inputs[1])
+            split_func = prelude.get_var_static('tensor_array_split',
+                                                dtype_str,
+                                                input_ta_shape)
+
+        return split_func(input_ta, v, lengths)
     return _impl
 def _tensor_array_concat():
     def _impl(inputs, attr, params, prelude):
-        concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
-        return concat_func(inputs[1])
+        dtype_str = attr['dtype'].name
+        input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude)
+
+        if input_shape is None:
+            concat_func = prelude.get_var('tensor_array_concat', dtype_str)
+            out = concat_func(inputs[1])
+        else:
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_shape)
+            static_tensor_array_ops.register()
+            concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape)
+            out_tensor = concat_func(inputs[1])
+            static_tensor_array_ops.define_tensor_get_data((Any(),) + input_shape[1:])
+            get_data_func = prelude.get_var_static('tensor_get_data',
+                                                   dtype_str,
+                                                   input_shape)
+            out = get_data_func(out_tensor)
+        return out
     return _impl


 def _tile():
@@ -1370,7 +1622,7 @@ def _range():
         return AttrCvt(
             op_name="arange",
-            ignores=['Tidx'],
+            ignores=['Tidx', '_class'],
             extras={'start': start,
                     'stop': limit,
                     'step': delta,
@@ -2084,6 +2336,9 @@ class RecurrentNetworks(object):
 # 1.x.
 _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

+# A map to record tensor arrays with a fixed rank shape
+_static_tensor_array_map = {}
+
 class RewriteSubgraph(ExprMutator):
     """
     A helper class to rewrite expr in while loop function to variable
......
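As a rough TF 1.x illustration of the dummy-shape path documented in the _tensor_array comments above (a sketch, not part of the diff): when infer_shape=True is set on the TensorArray but no element_shape is given, the converter first builds the static tensor array with a dummy shape and recreates it with the real element shape at the first write/scatter/split that uses it.

    # Mirrors test_tensor_array_write_read's run(dtype, True, None) case below.
    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default():
        data = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
        # infer_shape=True, element_shape left unset: the converter takes the
        # dummy-shape path described above and fixes the shape at the first write.
        ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=True)
        ta = ta.write(0, data).write(1, data)
        out = ta.read(0)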
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """A prelude containing useful global functions and ADT definitions."""
-from tvm.ir import IRModule
+from tvm.ir import IRModule, TypeCall
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
 from .expr import Var, GlobalVar, If, const
@@ -24,8 +24,51 @@ from .function import Function
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
-from . import op
+from . import op, transform
+
+
+def get_tensor_array_shape(expr, dtype, prelude):
+    """Get the static shape of a tensor array if it has a fixed rank shape.
+
+    By design, a static ADT tensor in TVM has a type name in the format
+    of static_tensor_dim0_dim1_..._dimN_t.
+
+    Parameters
+    ----------
+    expr : Relay Expr
+        Input expression.
+
+    dtype : str
+        Data type.
+
+    prelude : Prelude
+        Tensor array prelude.
+
+    Returns
+    -------
+    shape : tuple of (int, Any) or None
+        The output shape. None if the input tensor array
+        has a dynamic shape.
+    """
+    mod = prelude.mod
+    mod["main"] = Function([], expr)
+    mod = transform.InferType()(mod)
+    checked_type = mod["main"].body.checked_type
+    assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
+    ta_type_str = checked_type.args[0].func.name_hint
+    static_ta_ty_start = "static_tensor_{}".format(dtype)
+    if ta_type_str.startswith(static_ta_ty_start):
+        shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), '') \
+            .replace("_t", '')
+        shape = []
+        if "scalar" not in shape_str:
+            for dim_str in shape_str.split("_"):
+                if dim_str == "?":
+                    shape.append(Any())
+                else:
+                    shape.append(int(dim_str))
+        return tuple(shape)
+    return None
+
+
 def _get_name_static(canonical, dtype, shape):
     """Get name for static shape tensor array op corresponding
......
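To make the type-name convention above concrete, here is a small standalone sketch (illustrative; _parse_static_ta_name is a hypothetical helper mirroring the parsing in get_tensor_array_shape): a static float32 tensor array whose elements have shape (Any, 2) is backed by the ADT type name static_tensor_float32_?_2_t, from which the parser recovers (Any(), 2), while a scalar suffix maps to an empty shape tuple.

    # Hypothetical helper, not part of the diff.
    from tvm.relay.ty import Any

    def _parse_static_ta_name(ta_type_str, dtype):
        prefix = "static_tensor_{}".format(dtype)
        if not ta_type_str.startswith(prefix):
            return None  # not a static tensor array type name
        shape_str = ta_type_str.replace("{}_".format(prefix), '').replace("_t", '')
        if "scalar" in shape_str:
            return ()
        return tuple(Any() if d == "?" else int(d) for d in shape_str.split("_"))

    print(_parse_static_ta_name("static_tensor_float32_?_2_t", "float32"))     # (?, 2)
    print(_parse_static_ta_name("static_tensor_float32_scalar_t", "float32"))  # ()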
@@ -839,63 +839,75 @@ def test_forward_squeeze():
     _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])


-def test_tensor_array_constructor():
-    def run(dtype_str):
+#######################################################################
+# TensorArray
+# -----------
+def test_tensor_array_write_read():
+    def run(dtype_str, infer_shape, element_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
-            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
-            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
-            ta2 = ta1.write(0, t)
+            np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
+            in_data = [np_data, np_data]
+            t1 = tf.constant(np_data, dtype=dtype)
+            t2 = tf.constant(np_data, dtype=dtype)
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape,
+                                 element_shape=element_shape)
+            ta2 = ta1.write(0, t1)
             ta3 = ta2.write(1, t2)
             out = ta3.read(0)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm')
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, False, None)
+        run(dtype, False, tf.TensorShape([None, 2]))
+        run(dtype, True, None)
 def test_tensor_array_scatter():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
             indices = tf.constant([2, 1, 0])
             ta1 = tf.TensorArray(dtype=dtype, size=3,
-                                 infer_shape=False, dynamic_size=False)
+                                 infer_shape=infer_shape)
             ta2 = ta1.scatter(indices, t)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
             out2 = ta2.read(2)
             g = tf.get_default_graph()
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm')
+            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm')
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
-# TODO(wweic): Fix gather issue with PartialEvaluate
-# def test_tensor_array_gather():
-#     with tf.Graph().as_default():
-#         dtype = 'float32'
-#         t = tf.constant([[1.0], [2.0], [3.0]])
-#         scatter_indices = tf.constant([2, 1, 0])
-#         gather_indices = tf.constant([1, 2])
-#         ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
-#         ta2 = ta1.scatter(scatter_indices, t)
-#         t1 = ta2.gather(gather_indices)
-#         g = tf.get_default_graph()
-#         compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
+def test_tensor_array_gather():
+    def run(dtype_str, infer_shape):
+        with tf.Graph().as_default():
+            dtype = tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
+            scatter_indices = tf.constant([2, 1, 0])
+            gather_indices = tf.constant([1, 2])
+            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
+            ta2 = ta1.scatter(scatter_indices, t)
+            t1 = ta2.gather(gather_indices)
+            g = tf.get_default_graph()
+            compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm')
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, True)
 def test_tensor_array_split():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
-            ta1 = tf.TensorArray(dtype=dtype, size=4,
-                                 infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
             ta2 = ta1.split(t, split_length)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
@@ -906,56 +918,76 @@ def test_tensor_array_split():
             compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
             compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
             compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
 def test_tensor_array_concat():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
             ta1 = tf.TensorArray(dtype=dtype, size=4,
-                                 infer_shape=False, dynamic_size=False)
+                                 infer_shape=infer_shape)
             ta2 = ta1.split(t, split_length)
             t = ta2.concat()
             out = tf.identity(t)
             compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
 def test_tensor_array_size():
-    def run(dtype_str):
+    def run(dtype_str, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
             out = ta1.size()
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype)
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, False)
+        run(dtype, True)
+def test_tensor_array_stack():
+    def run(dtype_str, infer_shape):
+        with tf.Graph().as_default():
+            dtype = tf_dtypes[dtype_str]
+            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
+            scatter_indices = tf.constant([2, 1, 0])
+            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
+            ta2 = ta1.scatter(scatter_indices, t)
+            t1 = ta2.stack()
+            print(t1)
+            g = tf.get_default_graph()
+
+            compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, True)
 def test_tensor_array_unstack():
-    def run(dtype_str, input_shape):
+    def run(dtype_str, input_shape, infer_shape):
         with tf.Graph().as_default():
             dtype = tf_dtypes[dtype_str]
             t = tf.constant(np.random.choice([0, 1, 2, 3],
                                              size=input_shape).astype(dtype.name))
-            ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0])
+            ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
             ta2 = ta1.unstack(t)
             out0 = ta2.size()
             out1 = ta2.read(0)
             compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
             compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
-    for dtype in tf_dtypes.keys():
-        run(dtype, (5,))
-        run(dtype, (5, 5))
-        run(dtype, (5, 5, 5))
-        run(dtype, (5, 5, 5, 5))
-        run(dtype, (5, 5, 5, 5, 5))
-        run(dtype, (5, 5, 5, 5, 5, 5))
+
+    for dtype in ["float32", "int8"]:
+        run(dtype, (5,), False)
+        run(dtype, (5, 5), True)
+        run(dtype, (5, 5, 5), False)
+        run(dtype, (5, 5, 5, 5), True)
 #######################################################################
 # ConcatV2
@@ -3241,6 +3273,16 @@ if __name__ == '__main__':
     test_forward_reduce()
     test_forward_mean()

+    # TensorArray
+    test_tensor_array_write_read()
+    test_tensor_array_concat()
+    test_tensor_array_scatter()
+    test_tensor_array_gather()
+    test_tensor_array_size()
+    test_tensor_array_split()
+    test_tensor_array_stack()
+    test_tensor_array_unstack()
+
     # General
     test_forward_multi_input()
     test_forward_multi_output()
......
@@ -166,12 +166,14 @@ def get_const_tuple(in_tuple):
     """
     ret = []
     for elem in in_tuple:
-        if isinstance(elem, tvm.tir.Var):
+        if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
            ret.append(elem)
         elif not isinstance(elem, (tvm.tir.IntImm, int)):
             elem = tvm.tir.ir_pass.Simplify(elem)
             if not isinstance(elem, tvm.tir.IntImm):
                 ret.append(elem)
+            else:
+                ret.append(get_const_int(elem))
         else:
             ret.append(get_const_int(elem))
     return tuple(ret)
......
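Finally, a hedged sketch of why the get_const_tuple change matters for the static tensor array path (the topi import path and the relay.Any() usage are assumptions about this TVM revision): shapes inferred for static tensor arrays can mix concrete integers with Any dimensions, and the helper now passes Any through unchanged instead of trying to simplify it.

    # Illustrative only.
    from tvm import relay
    from topi.util import get_const_tuple

    shape = (relay.Any(), 2)       # e.g. unknown leading dim, fixed element dim
    print(get_const_tuple(shape))  # the Any dim is kept as-is, ints pass through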