Commit 03a29da7 by Wei Chen Committed by Zhi

[Relay][Op][TF] Complete tensor array unstack with all ranks support (#4309)

parent e6806115
......@@ -40,6 +40,7 @@ from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_channels as _infer_channels
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
__all__ = ['from_tensorflow']
......@@ -1079,9 +1080,13 @@ def _rank():
def _range():
def _impl(inputs, attr, params):
start = _get_param(params, inputs[0])[0]
limit = _get_param(params, inputs[1])[0] \
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
else params.pop('Rank').asnumpy()[0]
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant):
limit = _get_param(params, inputs[1])[0]
else:
if any(['Rank' in param for param in params]):
limit = params.pop('Rank').asnumpy()[0]
else:
limit = _infer_value_simulated(inputs[1], params).asnumpy()[0]
delta = _get_param(params, inputs[2])[0]
dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
return AttrCvt(
......
......@@ -336,6 +336,122 @@ class TensorArrayOps(object):
Function([tensor2], helper_var(const(0), ndim, tensor2),
self.prelude.l(self.get_var('tensor_t')()), [])
def define_tensor_array_unstack_tensor3(self):
"""Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array.
tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t]
"""
helper_name = self.get_name("tensor_array_unstack_tensor3_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor = Var("t", TensorType([Any(), Any(), Any()], self.dtype))
up = Var("up", scalar_type('int32'))
i = Var("i", scalar_type('int32'))
helper_body = If(equal(i, up),
self.prelude.nil(),
self.prelude.cons(self.get_var('tensor2')(op.take(tensor, i, axis=0)),
helper_var(add(i, const(1)), up, tensor)))
self.prelude.mod[helper_var] =\
Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3")
tensor_array_unstack_tensor3_var = GlobalVar(tensor_array_unstack_tensor3_name)
setattr(self.prelude, tensor_array_unstack_tensor3_name, tensor_array_unstack_tensor3_var)
tensor3 = Var("tensor", TensorType([Any(), Any(), Any()], self.dtype))
shape = op.shape_of(tensor3)
ndim = op.take(shape, const(0))
self.prelude.mod[tensor_array_unstack_tensor3_var] =\
Function([tensor3], helper_var(const(0), ndim, tensor3),
self.prelude.l(self.get_var('tensor_t')()), [])
def define_tensor_array_unstack_tensor4(self):
"""Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array.
tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t]
"""
helper_name = self.get_name("tensor_array_unstack_tensor4_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor = Var("t", TensorType([Any(), Any(), Any(), Any()], self.dtype))
up = Var("up", scalar_type('int32'))
i = Var("i", scalar_type('int32'))
helper_body = If(equal(i, up),
self.prelude.nil(),
self.prelude.cons(self.get_var('tensor3')(op.take(tensor, i, axis=0)),
helper_var(add(i, const(1)), up, tensor)))
self.prelude.mod[helper_var] =\
Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4")
tensor_array_unstack_tensor4_var = GlobalVar(tensor_array_unstack_tensor4_name)
setattr(self.prelude, tensor_array_unstack_tensor4_name, tensor_array_unstack_tensor4_var)
tensor4 = Var("tensor", TensorType([Any(), Any(), Any(), Any()], self.dtype))
shape = op.shape_of(tensor4)
ndim = op.take(shape, const(0))
self.prelude.mod[tensor_array_unstack_tensor4_var] =\
Function([tensor4], helper_var(const(0), ndim, tensor4),
self.prelude.l(self.get_var('tensor_t')()), [])
def define_tensor_array_unstack_tensor5(self):
"""Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array.
tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t]
"""
helper_name = self.get_name("tensor_array_unstack_tensor5_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
up = Var("up", scalar_type('int32'))
i = Var("i", scalar_type('int32'))
helper_body = If(equal(i, up),
self.prelude.nil(),
self.prelude.cons(self.get_var('tensor4')(op.take(tensor, i, axis=0)),
helper_var(add(i, const(1)), up, tensor)))
self.prelude.mod[helper_var] =\
Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5")
tensor_array_unstack_tensor5_var = GlobalVar(tensor_array_unstack_tensor5_name)
setattr(self.prelude, tensor_array_unstack_tensor5_name, tensor_array_unstack_tensor5_var)
tensor5 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
shape = op.shape_of(tensor5)
ndim = op.take(shape, const(0))
self.prelude.mod[tensor_array_unstack_tensor5_var] =\
Function([tensor5], helper_var(const(0), ndim, tensor5),
self.prelude.l(self.get_var('tensor_t')()), [])
def define_tensor_array_unstack_tensor6(self):
"""Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array.
tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t]
"""
helper_name = self.get_name("tensor_array_unstack_tensor6_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
up = Var("up", scalar_type('int32'))
i = Var("i", scalar_type('int32'))
helper_body = If(equal(i, up),
self.prelude.nil(),
self.prelude.cons(self.get_var('tensor5')(op.take(tensor, i, axis=0)),
helper_var(add(i, const(1)), up, tensor)))
self.prelude.mod[helper_var] =\
Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6")
tensor_array_unstack_tensor6_var = GlobalVar(tensor_array_unstack_tensor6_name)
setattr(self.prelude, tensor_array_unstack_tensor6_name, tensor_array_unstack_tensor6_var)
tensor6 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
shape = op.shape_of(tensor6)
ndim = op.take(shape, const(0))
self.prelude.mod[tensor_array_unstack_tensor6_var] =\
Function([tensor6], helper_var(const(0), ndim, tensor6),
self.prelude.l(self.get_var('tensor_t')()), [])
def define_tensor_array_scatter(self):
"""Defines a function to scatter the values of a tensor_t in indices of a tensor array.
tensor_array_scatter(ta, indices, value) :
......@@ -516,6 +632,10 @@ class TensorArrayOps(object):
self.define_tensor_array_write()
self.define_tensor_array_unstack_tensor1()
self.define_tensor_array_unstack_tensor2()
self.define_tensor_array_unstack_tensor3()
self.define_tensor_array_unstack_tensor4()
self.define_tensor_array_unstack_tensor5()
self.define_tensor_array_unstack_tensor6()
self.define_tensor_array_scatter()
self.define_tensor_array_split()
self.define_tensor_array_concat()
......
......@@ -763,6 +763,26 @@ def test_tensor_array_size():
for dtype in tf_dtypes.keys():
run(dtype)
def test_tensor_array_unstack():
def run(dtype_str, input_shape):
with tf.Graph().as_default():
dtype = tf_dtypes[dtype_str]
t = tf.constant(np.random.choice([0, 1, 2, 3],
size=input_shape).astype(dtype.name))
ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0])
ta2 = ta1.unstack(t)
out0 = ta2.size()
out1 = ta2.read(0)
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
for dtype in tf_dtypes.keys():
run(dtype, (5,))
run(dtype, (5, 5))
run(dtype, (5, 5, 5))
run(dtype, (5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5))
run(dtype, (5, 5, 5, 5, 5, 5))
#######################################################################
# ConcatV2
# --------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment