Unverified commit 4b27cd14 by Yao Wang, committed by GitHub

[Frontend][TensorFlow] Improve TensorFlow Static Shape Tensor Array (#5243)

* Support TF Frontend Static TensorArray

* Fix pylint

* Fix lint

* Move get_tensor_array_shape into prelude

* Fix lint

* Fix common
parent b72dd9d9
@@ -456,22 +456,20 @@ def get_name(node):
def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
return entry if isinstance(node, _function.Function) else entry.body
if isinstance(mod, IRModule):
mod["main"] = _function.Function([], node)
mod = _transform.InferType()(mod)
entry = mod["main"]
ret = entry.body
else:
new_mod = IRModule.from_expr(node)
if mod is not None:
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
ret = entry if isinstance(node, _function.Function) else entry.body
return ret
def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(out_type.checked_type.shape)
# The return type is not a tensor, for example List
return checked_type
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
@@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False):
return channels
def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
out_type = infer_type(inputs, mod=mod)
checked_type = out_type.checked_type
if hasattr(checked_type, 'shape'):
# Regular operator that outputs tensors
return get_const_tuple(checked_type.shape)
# The return type is not a tensor, for example a list ADT
return checked_type
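As a usage illustration (not part of this diff): a minimal sketch of how these helpers behave, assuming infer_shape can be imported from tvm.relay.frontend.common and a relay build of this era is available.
from tvm import relay
from tvm.relay.frontend.common import infer_shape  # import path assumed

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
y = relay.nn.relu(x)
# Tensor-valued nodes yield a plain tuple of dimensions.
assert infer_shape(y) == (1, 3, 224, 224)
# ADT-valued nodes (e.g. a tensor array) yield the checked type itself,
# so callers must be prepared for a non-tuple result.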
def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions that
@@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None):
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
......
@@ -26,13 +26,14 @@ import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.relay.prelude import Prelude
from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape
from tvm.ir import structural_hash as s_hash
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ..ty import Any
from ..expr_functor import ExprMutator, ExprVisitor
from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
@@ -259,8 +260,6 @@ def _conv(opname):
if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
# transform to NCHW for TVM backend compatible and set 'flip_layout'
# to have output flip back to NHWC
tmp_shape = _infer_shape(inputs[2], mod)
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2]
@@ -789,25 +788,152 @@ def _pack():
def _tensor_array():
def _impl(inputs, attr, params, prelude):
try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))
dtype_str = attr.get('dtype').name
tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
assert not attr["dynamic_size"], "Dynamic size tensor array is " \
"not supported in TVM yet."
raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape'])
elem_shape = []
for dim in raw_elem_shape:
if dim < 0:
elem_shape.append(Any())
else:
elem_shape.append(dim)
if elem_shape:
# Element shape is specified.
# Directly create static tensor array with given shape.
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
elem_shape)
static_tensor_array_ops.register()
tensor_array_constructor = prelude.get_var_static('tensor_array',
dtype_str,
elem_shape)
tensor_array = tensor_array_constructor(inputs[0])
_static_tensor_array_map[tensor_array] = tensor_array
elif attr['identical_element_shapes']:
# identical_element_shapes is set but the element shape is not given.
# Create a static tensor array with a dummy shape and record it in
# _static_tensor_array_map. Later, when other tensor array ops that use
# this tensor array are created, it is reconstructed with its actual shape.
dummy_shape = ()
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
dummy_shape)
static_tensor_array_ops.register()
tensor_array_constructor = prelude.get_var_static('tensor_array',
dtype_str,
dummy_shape)
tensor_array = tensor_array_constructor(inputs[0])
_static_tensor_array_map[tensor_array] = None
else:
tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
tensor_array = tensor_array_constructor(inputs[0])
return tensor_array
return _impl
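A self-contained sketch (plain Python, hypothetical names) of the deferred-shape bookkeeping described in the comment above: a tensor array created with identical_element_shapes but no element shape is recorded with a None placeholder and rebuilt once a later op reveals the real element shape.
_shape_map = {}                                    # stand-in for _static_tensor_array_map

def make_tensor_array(name, elem_shape=None):
    # None means the element shape is not known yet (the dummy-shape case above).
    _shape_map[name] = elem_shape
    return name

def write_to_tensor_array(name, value_shape):
    if _shape_map.get(name) is None:
        _shape_map[name] = value_shape             # "recreate" with the actual shape
    assert len(_shape_map[name]) == len(value_shape), "rank mismatch"
    return _shape_map[name]

ta = make_tensor_array("ta0")                      # identical_element_shapes, no shape given
assert write_to_tensor_array(ta, (2, 3)) == (2, 3)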
def _tensor_array_scatter():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('T').name
values_rank = len(inputs[2].type_annotation.shape)
unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
unstack_function = prelude.get_var(unstack_name, dtype_str)
values = unstack_function(inputs[2])
tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
return tensor_array_scatter_func(inputs[0], inputs[1], values)
input_ta = inputs[0]
input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
values_shape = _infer_shape(inputs[2], prelude.mod)
input_t_shape = values_shape[1:]
indices_shape = _infer_shape(inputs[1], prelude.mod)
if input_shape is None:
values_rank = len(values_shape)
unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
unstack_function = prelude.get_var(unstack_name, dtype_str)
values = unstack_function(inputs[2])
tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
else:
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_t_shape)
static_tensor_array_ops.register()
# A scatter may target a newly created tensor array whose element shape
# is not yet known, so check and recreate its input tensor array if needed.
if input_ta in _static_tensor_array_map and \
_static_tensor_array_map[input_ta] is None:
ta_constructor = prelude.get_var_static('tensor_array',
dtype_str,
input_t_shape)
new_ta = ta_constructor(input_ta.args[0])
_static_tensor_array_map[input_ta] = new_ta
input_ta = new_ta
# Register static indices shape
if isinstance(indices_shape[0], int):
static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True)
tensor_array_scatter_func = prelude.get_var_static('tensor_array_scatter',
dtype_str,
input_t_shape)
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
values_shape)
static_tensor_array_ops.register()
unstack_function = prelude.get_var_static('tensor_array_unstack',
dtype_str,
values_shape)
values = unstack_function(inputs[2])
ret = tensor_array_scatter_func(input_ta, inputs[1], values)
return ret
return _impl
def _tensor_array_gather():
def _impl(inputs, attr, params, prelude):
return prelude.tensor_array_gather(inputs[2], inputs[1])
dtype_str = attr.get('dtype').name
input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
indices_shape = _infer_shape(inputs[1], prelude.mod)
if input_shape is None:
gather_func = prelude.get_var('tensor_array_gather', dtype_str)
out = gather_func(inputs[2], inputs[1])
else:
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_shape)
static_tensor_array_ops.register()
if not isinstance(indices_shape[0], int):
gather_function = prelude.get_var_static('tensor_array_gather',
dtype_str,
input_shape)
out_tensor_t = gather_function(inputs[2], inputs[1])
# Output shape is (indices_shape[0],) + input_shape
static_tensor_array_ops.define_tensor_get_data((indices_shape[0],) + input_shape)
get_data_func = prelude.get_var_static('tensor_get_data',
dtype_str,
input_shape)
out = get_data_func(out_tensor_t)
else:
# For fixed length indices, directly generate static shape output
read_func = prelude.get_var_static('tensor_array_read',
dtype_str,
input_shape)
static_tensor_array_ops.define_tensor_get_data(input_shape)
get_data_func = prelude.get_var_static('tensor_get_data',
dtype_str,
input_shape)
tensor_list = []
for i in range(indices_shape[0]):
index = _op.take(inputs[1], tvm.relay.const(i))
out_tensor = get_data_func(read_func(inputs[2], index))
tensor_list.append(_op.expand_dims(out_tensor, axis=0))
out = _op.concatenate(tensor_list, axis=0)
return out
return _impl
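The static-indices branch above unrolls the gather into per-index reads followed by expand_dims and concatenate; a small NumPy sketch (illustrative names only) of why that unrolling matches a gather along axis 0:
import numpy as np

data = np.arange(12).reshape(3, 4)                 # stand-in for the stacked tensor array
indices = [1, 2]                                   # statically known gather indices
unrolled = np.concatenate(
    [np.expand_dims(data[i], axis=0) for i in indices], axis=0)
assert (unrolled == data[indices]).all()           # same result as a direct gather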
def _tensor_array_size():
@@ -817,37 +943,163 @@ def _tensor_array_size():
def _tensor_array_write():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[2].type_annotation.shape)
dtype = attr.get('T').name
dtype_str = attr.get('T').name
input_ta = inputs[3]
input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
input_t_shape = _infer_shape(inputs[2], prelude.mod)
input_rank = len(input_t_shape)
if input_ta_shape is None:
tensor_name = 'tensor{}'.format(input_rank)
tensor_func = prelude.get_var(tensor_name, dtype_str)
v = tensor_func(inputs[2])
write_func = prelude.get_var('tensor_array_write', dtype_str)
else:
# A write may target a newly created tensor array whose element shape
# is not yet known, so check and recreate its input tensor array if needed.
if input_ta in _static_tensor_array_map and \
_static_tensor_array_map[input_ta] is None:
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_t_shape)
static_tensor_array_ops.register()
ta_constructor = prelude.get_var_static('tensor_array',
dtype_str,
input_t_shape)
new_ta = ta_constructor(input_ta.args[0])
_static_tensor_array_map[input_ta] = new_ta
input_ta = new_ta
input_ta_shape = input_t_shape
else:
input_ta_rank = len(input_ta_shape)
assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
format(input_ta_rank, input_rank)
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_ta_shape)
static_tensor_array_ops.register()
tensor_name = 'tensor{}'.format(input_rank)
tensor_func = prelude.get_var(tensor_name, dtype)
v = tensor_func(inputs[2])
write_func = prelude.get_var('tensor_array_write', dtype)
tensor_func = prelude.get_var_static("tensor_constructor",
dtype_str,
input_ta_shape)
v = tensor_func(inputs[2])
write_func = prelude.get_var_static('tensor_array_write',
dtype_str,
input_ta_shape)
return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v)
return _impl
def _tensor_array_read():
def _impl(inputs, attr, params, prelude):
read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
dtype_str = attr['dtype'].name
input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude)
if input_shape is None:
read_func = prelude.get_var('tensor_array_read', dtype_str)
out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
else:
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_shape)
static_tensor_array_ops.register()
static_tensor_array_ops.define_tensor_get_data(input_shape)
read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape)
out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
get_data_func = prelude.get_var_static('tensor_get_data',
dtype_str,
input_shape)
out = get_data_func(out_tensor)
return out
return _impl
def _tensor_array_split():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[1].type_annotation.shape)
dtype_str = attr.get('T').name
v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
input_ta = inputs[0]
input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
input_t_shape = _infer_shape(inputs[1], prelude.mod)
input_rank = len(input_t_shape)
lengths = _op.cast(inputs[2], 'int32')
split_var = prelude.get_var('tensor_array_split', dtype_str)
return split_var(inputs[0], v, lengths)
lengths_shape = _infer_shape(lengths, prelude.mod)
value_shape = _infer_shape(inputs[1], prelude.mod)
if input_ta_shape is None:
v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
split_func = prelude.get_var('tensor_array_split', dtype_str)
else:
# A split may target a newly created tensor array whose element shape
# is not yet known, so check and recreate its input tensor array if needed.
if input_ta in _static_tensor_array_map and \
_static_tensor_array_map[input_ta] is None:
input_ta_shape = (Any(),) + input_t_shape[1:]
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_ta_shape)
static_tensor_array_ops.register()
ta_constructor = prelude.get_var_static('tensor_array',
dtype_str,
input_ta_shape)
new_ta = ta_constructor(input_ta.args[0])
_static_tensor_array_map[input_ta] = new_ta
input_ta = new_ta
else:
input_ta_rank = len(input_ta_shape)
assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
format(input_ta_rank, input_rank)
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_ta_shape)
static_tensor_array_ops.register()
# Check static value/indices shape
if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int):
static_tensor_array_ops.define_tensor_array_split(value_shape,
lengths_shape,
True)
tensor_func_name = prelude.get_name_static("tensor_constructor",
dtype_str,
value_shape)
if not hasattr(prelude, tensor_func_name):
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
value_shape)
static_tensor_array_ops.register()
tensor_func = prelude.get_var_static("tensor_constructor",
dtype_str,
value_shape)
v = tensor_func(inputs[1])
split_func = prelude.get_var_static('tensor_array_split',
dtype_str,
input_ta_shape)
return split_func(input_ta, v, lengths)
return _impl
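For reference, the split semantics handled above take a value tensor and a lengths vector and cut the value along axis 0 into chunks with those row counts; a NumPy sketch using the same [2, 2, 2, 2] lengths as the tests later in this diff:
import numpy as np

value = np.arange(8, dtype="float32").reshape(8, 1)
lengths = [2, 2, 2, 2]                             # rows per chunk along axis 0
chunks = np.split(value, np.cumsum(lengths)[:-1], axis=0)
assert [c.shape[0] for c in chunks] == lengths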
def _tensor_array_concat():
def _impl(inputs, attr, params, prelude):
concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
return concat_func(inputs[1])
dtype_str = attr['dtype'].name
input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude)
if input_shape is None:
concat_func = prelude.get_var('tensor_array_concat', dtype_str)
out = concat_func(inputs[1])
else:
static_tensor_array_ops = StaticTensorArrayOps(prelude,
dtype_str,
input_shape)
static_tensor_array_ops.register()
concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape)
out_tensor = concat_func(inputs[1])
static_tensor_array_ops.define_tensor_get_data((Any(),) + input_shape[1:])
get_data_func = prelude.get_var_static('tensor_get_data',
dtype_str,
input_shape)
out = get_data_func(out_tensor)
return out
return _impl
def _tile():
@@ -1370,7 +1622,7 @@ def _range():
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
ignores=['Tidx', '_class'],
extras={'start': start,
'stop': limit,
'step': delta,
@@ -2084,6 +2336,9 @@ class RecurrentNetworks(object):
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
# A map recording tensor arrays that have a fixed-rank shape
_static_tensor_array_map = {}
class RewriteSubgraph(ExprMutator):
"""
A helper class to rewrite expr in while loop function to variable
......
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule
from tvm.ir import IRModule, TypeCall
from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
@@ -24,8 +24,51 @@ from .function import Function
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op
from . import op, transform
def get_tensor_array_shape(expr, dtype, prelude):
"""Get the static shape of a tensor array if it has fixed rank shape.
By design, static ADT tensor in TVM has type name in the format
of static_tensor_dim0_dim1_..._dimN_t.
Parameters
----------
expr : Relay Expr
Input expression.
dtype : str
Data type.
prelude : Prelude
Tensor array prelude
Returns
-------
shape : tuple of (int, Any) or None
The output shape, or None if the input tensor array
has a dynamic shape.
"""
mod = prelude.mod
mod["main"] = Function([], expr)
mod = transform.InferType()(mod)
checked_type = mod["main"].body.checked_type
assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
ta_type_str = checked_type.args[0].func.name_hint
static_ta_ty_start = "static_tensor_{}".format(dtype)
if ta_type_str.startswith(static_ta_ty_start):
shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), '') \
.replace("_t", '')
shape = []
if "scalar" not in shape_str:
for dim_str in shape_str.split("_"):
if dim_str == "?":
shape.append(Any())
else:
shape.append(int(dim_str))
return tuple(shape)
return None
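A standalone sketch (no TVM required) of the name parsing performed above; parse_static_shape is a hypothetical helper, and None stands in for relay's Any() so the example stays self-contained:
def parse_static_shape(type_name, dtype):
    prefix = "static_tensor_{}_".format(dtype)
    if not type_name.startswith(prefix):
        return None                                # not a static tensor array type
    body = type_name[len(prefix):]
    if body.endswith("_t"):
        body = body[:-2]
    if body == "scalar":
        return ()                                  # rank-0 element shape
    return tuple(None if dim == "?" else int(dim) for dim in body.split("_"))

assert parse_static_shape("static_tensor_float32_3_?_2_t", "float32") == (3, None, 2)
assert parse_static_shape("static_tensor_int8_scalar_t", "int8") == ()
assert parse_static_shape("tensor_float32_t", "float32") is None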
def _get_name_static(canonical, dtype, shape):
"""Get name for static shape tensor array op corresponding
......
@@ -839,63 +839,75 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
def test_tensor_array_constructor():
def run(dtype_str):
#######################################################################
# TensorArray
# -----------
def test_tensor_array_write_read():
def run(dtype_str, infer_shape, element_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
ta2 = ta1.write(0, t)
np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
in_data = [np_data, np_data]
t1 = tf.constant(np_data, dtype=dtype)
t2 = tf.constant(np_data, dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape,
element_shape=element_shape)
ta2 = ta1.write(0, t1)
ta3 = ta2.write(1, t2)
out = ta3.read(0)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, False, None)
run(dtype, False, tf.TensorShape([None, 2]))
run(dtype, True, None)
def test_tensor_array_scatter():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3,
infer_shape=False, dynamic_size=False)
infer_shape=infer_shape)
ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
# with tf.Graph().as_default():
# dtype = 'float32'
# t = tf.constant([[1.0], [2.0], [3.0]])
# scatter_indices = tf.constant([2, 1, 0])
# gather_indices = tf.constant([1, 2])
# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
# ta2 = ta1.scatter(scatter_indices, t)
# t1 = ta2.gather(gather_indices)
# g = tf.get_default_graph()
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_gather():
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
scatter_indices = tf.constant([2, 1, 0])
gather_indices = tf.constant([1, 2])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
ta2 = ta1.scatter(scatter_indices, t)
t1 = ta2.gather(gather_indices)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, True)
def test_tensor_array_split():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape)
ta2 = ta1.split(t, split_length)
out0 = ta2.read(0)
out1 = ta2.read(1)
@@ -906,56 +918,76 @@ def test_tensor_array_split():
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_concat():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False)
infer_shape=infer_shape)
ta2 = ta1.split(t, split_length)
t = ta2.concat()
out = tf.identity(t)
compare_tf_with_tvm([], [], ['Identity:0'], mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_size():
def run(dtype_str):
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
out = ta1.size()
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype)
for dtype in ["float32", "int8"]:
run(dtype, False)
run(dtype, True)
def test_tensor_array_stack():
def run(dtype_str, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
scatter_indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
ta2 = ta1.scatter(scatter_indices, t)
t1 = ta2.stack()
print(t1)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')
for dtype in ["float32", "int8"]:
run(dtype, True)
def test_tensor_array_unstack():
def run(dtype_str, input_shape):
def run(dtype_str, input_shape, infer_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.random.choice([0, 1, 2, 3],
size=input_shape).astype(dtype.name))
ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0])
ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
ta2 = ta1.unstack(t)
out0 = ta2.size()
out1 = ta2.read(0)
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype, (5,))
run(dtype, (5, 5))
run(dtype, (5, 5, 5))
run(dtype, (5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5, 5))
for dtype in ["float32", "int8"]:
run(dtype, (5,), False)
run(dtype, (5, 5), True)
run(dtype, (5, 5, 5), False)
run(dtype, (5, 5, 5, 5), True)
#######################################################################
# ConcatV2
@@ -3241,6 +3273,16 @@ if __name__ == '__main__':
test_forward_reduce()
test_forward_mean()
# TensorArray
test_tensor_array_write_read()
test_tensor_array_concat()
test_tensor_array_scatter()
test_tensor_array_gather()
test_tensor_array_size()
test_tensor_array_split()
test_tensor_array_stack()
test_tensor_array_unstack()
# General
test_forward_multi_input()
test_forward_multi_output()
......
@@ -166,12 +166,14 @@ def get_const_tuple(in_tuple):
"""
ret = []
for elem in in_tuple:
if isinstance(elem, tvm.tir.Var):
if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
ret.append(elem)
elif not isinstance(elem, (tvm.tir.IntImm, int)):
elem = tvm.tir.ir_pass.Simplify(elem)
if not isinstance(elem, tvm.tir.IntImm):
ret.append(elem)
else:
ret.append(get_const_int(elem))
else:
ret.append(get_const_int(elem))
return tuple(ret)
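A hedged usage sketch of the behavior this hunk enables, assuming the TVM 0.7-dev API used above and that the helper is topi.util.get_const_tuple (both assumptions, not confirmed by the diff):
import tvm
from topi.util import get_const_tuple    # import path assumed

n = tvm.tir.Var("n", "int32")
shape = (tvm.tir.IntImm("int32", 3), n, 5)
# Constant dims become Python ints; symbolic dims (Var, Any) pass through unchanged.
print(get_const_tuple(shape))             # -> (3, n, 5)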
......