Commit 36a96773 by Wei Chen, committed by Zhi

[Relay][Frontend][TF] Add tensor array ops (#3798)

* [Relay][Frontend][TF] Add tensor array ops

* rename

* delete test

* Move utility function

* Refactor

* fix tensor array ops

* fix test

* fix rebase

* Fix serializer bug

* Improve tf convert name lookup to use prelude api

* Fix lint

* Fix test
Parent commit: 4052de6d
@@ -22,10 +22,14 @@ from __future__ import print_function
import warnings
from collections import defaultdict
# Numpy support
import numpy as np
import tvm
from tvm.relay.prelude import Prelude
from .. import analysis
from .. import expr as _expr
from .. import op as _op
@@ -508,6 +512,69 @@ def _pack():
return _op.concatenate(inputs_reshaped, axis)
return _impl
def _tensor_array():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('dtype').name
tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
return _impl
def _tensor_array_scatter():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('T').name
values_rank = len(inputs[2].type_annotation.shape)
unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
unstack_function = prelude.get_var(unstack_name, dtype_str)
values = unstack_function(inputs[2])
tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
return tensor_array_scatter_func(inputs[0], inputs[1], values)
return _impl
def _tensor_array_gather():
def _impl(inputs, attr, params, prelude):
return prelude.tensor_array_gather(inputs[2], inputs[1])
return _impl
def _tensor_array_size():
def _impl(inputs, attr, params, prelude):
return prelude.length(inputs[0])
return _impl
def _tensor_array_write():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[2].type_annotation.shape)
dtype = attr.get('T').name
tensor_name = 'tensor{}'.format(input_rank)
tensor_func = prelude.get_var(tensor_name, dtype)
v = tensor_func(inputs[2])
write_func = prelude.get_var('tensor_array_write', dtype)
return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
return _impl
def _tensor_array_read():
def _impl(inputs, attr, params, prelude):
read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
return _impl
def _tensor_array_split():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[1].type_annotation.shape)
dtype_str = attr.get('T').name
v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
lengths = _op.cast(inputs[2], 'int32')
split_var = prelude.get_var('tensor_array_split', dtype_str)
return split_var(inputs[0], v, lengths)
return _impl
def _tensor_array_concat():
def _impl(inputs, attr, params, prelude):
concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
return concat_func(inputs[1])
return _impl
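The converters above resolve their prelude helpers purely by name and dtype. The sketch below is illustrative only (it assumes a TVM build containing this commit) and shows the lookup they rely on:

# Illustrative only: how a converter's prelude lookup resolves per dtype.
from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)
# 'tensor_array_read' + 'float32' resolves to the global function
# 'tensor_array_read_float32' registered by TensorArrayOps (see prelude.py).
read_func = p.get_var('tensor_array_read', 'float32')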
def _tile():
def _impl(inputs, attr, params):
reps = _get_list_param(params, inputs.pop())
@@ -1313,6 +1380,14 @@ _convert_map = {
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(),
'TensorArrayV3' : _tensor_array(),
'TensorArrayScatterV3' : _tensor_array_scatter(),
'TensorArrayGatherV3' : _tensor_array_gather(),
'TensorArraySizeV3' : _tensor_array_size(),
'TensorArrayWriteV3' : _tensor_array_write(),
'TensorArrayReadV3' : _tensor_array_read(),
'TensorArraySplitV3' : _tensor_array_split(),
'TensorArrayConcatV3' : _tensor_array_concat(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Pow' : _elemwise('power'),
@@ -1860,6 +1935,7 @@ class GraphProto(object):
self._loops = {}
self._branches = {}
self._mod = _module.Module({})
self._prelude = Prelude(self._mod)
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2335,7 +2411,11 @@ class GraphProto(object):
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
if 'TensorArray' in op_name:
    sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
else:
    sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
@@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target):
register_schedule("clip", schedule_elemwise)
@script
def _cast_shape_function(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out
def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]
@script
def _expand_dims_shape_func(x):
ndim = len(x.shape)
out = output_tensor((ndim+1,), "int64")
out[0] = int64(1)
for i in const_range(0, ndim):
out[i+1] = int64(x.shape[i])
return out
def expand_dims_shape_func(attrs, inputs, out_ndims):
return [_expand_dims_shape_func(*inputs)]
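For intuition, both shape functions only rearrange the input shape. A plain-Python model of the same arithmetic (illustrative only; the real versions are hybrid-script tensor programs operating on int64 shape tensors):

def expand_dims_shape(in_shape):
    # expand_dims with axis=0 prepends a unit dimension: (3, 4) -> (1, 3, 4)
    return (1,) + tuple(in_shape)

def cast_shape(in_shape):
    # cast never changes the shape; it is copied through unchanged
    return tuple(in_shape)

assert expand_dims_shape((3, 4)) == (1, 3, 4)
assert cast_shape((3, 4)) == (3, 4)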
# shape func
@script
def _broadcast_shape_func(x, y, ndim):
@@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim):
def broadcast_shape_func(attrs, inputs, out_ndims):
return [_broadcast_shape_func(*inputs, out_ndims[0])]
register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
@@ -16,8 +16,513 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, Function, GlobalVar, If, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op
from .module import Module
class TensorArrayOps(object):
"""Contains tensor array related ops"""
def __init__(self, prelude, dtype):
"""Create tensor array ops registry"""
self.prelude = prelude
self.dtype = dtype
def get_name(self, canonical):
"""Get name corresponding to the caninical name"""
return self.prelude.get_name(canonical, self.dtype)
def get_var(self, canonical):
"""Get var corresponding to the caninical name"""
return self.prelude.get_var(canonical, self.dtype)
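The canonical-to-concrete mapping is simple dtype mangling; a small function mirroring Prelude.get_name (defined further down) makes it concrete:

def get_name(canonical, dtype):
    # Mirrors Prelude.get_name below: 'tensor_t' is special-cased.
    if canonical == 'tensor_t':
        return 'tensor_{}_t'.format(dtype)
    return '{}_{}'.format(canonical, dtype)

assert get_name('tensor_t', 'float32') == 'tensor_float32_t'
assert get_name('tensor_array_read', 'int32') == 'tensor_array_read_int32'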
def define_tensor_adt(self):
"""Defines the dynamic tensor ADT, which is the container for tensors
with variable shapes."""
tensor_type_name = self.get_name('tensor_t')
tensor_type_var = GlobalTypeVar(tensor_type_name)
setattr(self.prelude, tensor_type_name, tensor_type_var)
tensor0_type = TensorType([], self.dtype)
tensor1_type = TensorType([Any()], self.dtype)
tensor2_type = TensorType([Any(), Any()], self.dtype)
tensor3_type = TensorType([Any(), Any(), Any()], self.dtype)
tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype)
tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)
tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)
tensor_nil_name = self.get_name('tensor_nil')
tensor0_name = self.get_name('tensor0')
tensor1_name = self.get_name('tensor1')
tensor2_name = self.get_name('tensor2')
tensor3_name = self.get_name('tensor3')
tensor4_name = self.get_name('tensor4')
tensor5_name = self.get_name('tensor5')
tensor6_name = self.get_name('tensor6')
tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var)
tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var)
tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var)
tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var)
tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var)
tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var)
tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var)
setattr(self.prelude, tensor_nil_name, tensor_nil_case)
setattr(self.prelude, tensor0_name, tensor0_case)
setattr(self.prelude, tensor1_name, tensor1_case)
setattr(self.prelude, tensor2_name, tensor2_case)
setattr(self.prelude, tensor3_name, tensor3_case)
setattr(self.prelude, tensor4_name, tensor4_case)
setattr(self.prelude, tensor5_name, tensor5_case)
setattr(self.prelude, tensor6_name, tensor6_case)
self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_nil_case,
tensor0_case,
tensor1_case,
tensor2_case,
tensor3_case,
tensor4_case,
tensor5_case,
tensor6_case])
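A tensor_t value is therefore one of eight constructors tagging the payload's rank (nil, or rank 0 through 6). A minimal sketch of wrapping a rank-2 tensor, assuming a TVM build with this commit:

# Illustrative only: wrap a rank-2 float32 tensor into the tensor_t ADT.
from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)
tensor2 = p.get_var('tensor2', 'float32')   # constructor for rank-2 payloads
x = relay.var('x', shape=(2, 2), dtype='float32')
wrapped = tensor2(x)                        # a tensor_t tagged with rank 2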
def define_tensor_take(self):
"""Defines a function to return a range of tensor_t on axis 0.
tensor_take(t, lower, upper) :
tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
"""
take_name = self.get_name("tensor_take")
take_var = GlobalVar(take_name)
setattr(self.prelude, take_name, take_var)
tensor_t = self.get_var('tensor_t')
tensor1_var = self.get_var('tensor1')
tensor2_var = self.get_var('tensor2')
tensor3_var = self.get_var('tensor3')
tensor4_var = self.get_var('tensor4')
tensor5_var = self.get_var('tensor5')
tensor6_var = self.get_var('tensor6')
t = Var('tensor', tensor_t())
lower = Var('lower', scalar_type('int32'))
upper = Var('upper', scalar_type('int32'))
t1 = Var('t1')
t2 = Var('t2')
t3 = Var('t3')
t4 = Var('t4')
t5 = Var('t5')
t6 = Var('t6')
tensor1_case =\
Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]),
tensor1_var(op.take(t1, op.arange(lower, upper, dtype='int32'))))
tensor2_case =\
Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]),
tensor2_var(op.take(t2, op.arange(lower, upper, dtype='int32'), axis=0)))
tensor3_case =\
Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]),
tensor3_var(op.take(t3, op.arange(lower, upper, dtype='int32'), axis=0)))
tensor4_case =\
Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]),
tensor4_var(op.take(t4, op.arange(lower, upper, dtype='int32'), axis=0)))
tensor5_case =\
Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]),
tensor5_var(op.take(t5, op.arange(lower, upper, dtype='int32'), axis=0)))
tensor6_case =\
Clause(PatternConstructor(tensor6_var, [PatternVar(t6)]),
tensor6_var(op.take(t6, op.arange(lower, upper, dtype='int32'), axis=0)))
self.prelude.mod[take_var] =\
Function([t, lower, upper],
Match(t, [tensor1_case,
tensor2_case,
tensor3_case,
tensor4_case,
tensor5_case,
tensor6_case], False),
tensor_t(), [])
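Semantically this is numpy's take of a contiguous index range along axis 0; a small reference model (illustrative):

import numpy as np

def tensor_take_ref(t, lower, upper):
    # Reference semantics of tensor_take for rank >= 1 payloads.
    return np.take(t, range(lower, upper), axis=0)

t = np.arange(12).reshape(3, 4)
assert tensor_take_ref(t, 1, 3).shape == (2, 4)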
def define_tensor_expand_dims(self):
"""Defines a function to grow a tensor_t's rank by adding one dimension in front
of the original tensor_t.
tensor_expand_dims(t) : tensor_t -> tensor_t
"""
expand_dims_name = self.get_name("tensor_expand_dims")
expand_dims_var = GlobalVar(expand_dims_name)
setattr(self.prelude, expand_dims_name, expand_dims_var)
tensor_type_var = self.get_var('tensor_t')
x = Var("x", tensor_type_var())
t0 = Var("t0")
t1 = Var("t1")
t2 = Var("t2")
t3 = Var("t3")
t4 = Var("t4")
t5 = Var("t5")
tensor0_var = self.get_var('tensor0')
tensor1_var = self.get_var('tensor1')
tensor2_var = self.get_var('tensor2')
tensor3_var = self.get_var('tensor3')
tensor4_var = self.get_var('tensor4')
tensor5_var = self.get_var('tensor5')
tensor6_var = self.get_var('tensor6')
tensor0_case = Clause(PatternConstructor(tensor0_var, [PatternVar(t0)]),
tensor1_var(op.expand_dims(t0, 0, 1)))
tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]),
tensor2_var(op.expand_dims(t1, 0, 1)))
tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]),
tensor3_var(op.expand_dims(t2, 0, 1)))
tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]),
tensor4_var(op.expand_dims(t3, 0, 1)))
tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]),
tensor5_var(op.expand_dims(t4, 0, 1)))
tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]),
tensor6_var(op.expand_dims(t5, 0, 1)))
self.prelude.mod[expand_dims_var] =\
Function([x],
Match(x, [tensor0_case,
tensor1_case,
tensor2_case,
tensor3_case,
tensor4_case,
tensor5_case], False))
def define_tensor_concat(self):
"""Defines a function to concatenate two tensor_t on the first axis
tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
"""
concat_name = self.get_name("tensor_concatenate")
concat_var = GlobalVar(concat_name)
setattr(self.prelude, concat_name, concat_var)
tensor_type_var = self.get_var('tensor_t')
x = Var("x", tensor_type_var())
y = Var("y", tensor_type_var())
tensor1_var = self.get_var('tensor1')
tensor2_var = self.get_var('tensor2')
tensor3_var = self.get_var('tensor3')
tensor4_var = self.get_var('tensor4')
t11 = Var("t11")
t12 = Var("t12")
t21 = Var("t21")
t22 = Var("t22")
t31 = Var("t31")
t32 = Var("t32")
t41 = Var("t41")
t42 = Var("t42")
tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t11)]),
Match(y, [Clause(PatternConstructor(tensor1_var, [PatternVar(t12)]),
tensor1_var(op.concatenate([t11, t12], axis=0)))],
False))
tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t21)]),
Match(y, [Clause(PatternConstructor(tensor2_var, [PatternVar(t22)]),
tensor2_var(op.concatenate([t21, t22], axis=0)))],
False))
tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t31)]),
Match(y, [Clause(PatternConstructor(tensor3_var, [PatternVar(t32)]),
tensor3_var(op.concatenate([t31, t32], axis=0)))],
False))
tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t41)]),
Match(y, [Clause(PatternConstructor(tensor4_var, [PatternVar(t42)]),
tensor4_var(op.concatenate([t41, t42], axis=0)))],
False))
# op.concatenate does not support tensor with rank higher than 4
self.prelude.mod[concat_var] =\
Function([x, y], Match(x, [tensor1_case,
tensor2_case,
tensor3_case,
tensor4_case], False))
def define_tensor_array(self):
"""Defines a function to create a tensor array with size n.
tensor_array(n) : Tensor[(), int32] -> list[tensor_t]
"""
tensor_array_constructor_name = self.get_name("tensor_array")
tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name)
setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
tensor_nil_var = self.get_var('tensor_nil')
tensor_type_var = self.get_var('tensor_t')
n = Var("x", scalar_type('int32'))
body = If(equal(n, const(0)),
self.prelude.nil(),
self.prelude.cons(tensor_nil_var(),
tensor_array_constructor_var(subtract(n, const(1)))))
self.prelude.mod[tensor_array_constructor_var] = \
Function([n], body, self.prelude.l(tensor_type_var()), [])
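Unrolling the recursion makes the result concrete: tensor_array(n) is just an n-element list of empty tensors.

# Unrolling the recursion above for n = 2 (cons/nil are the prelude's
# list constructors, tensor_nil the empty-tensor constructor):
#   tensor_array(2)
#   = cons(tensor_nil(), tensor_array(1))
#   = cons(tensor_nil(), cons(tensor_nil(), nil()))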
def define_tensor_array_read(self):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t
"""
read_name = self.get_name("tensor_array_read")
read_var = GlobalVar(read_name)
setattr(self.prelude, read_name, read_var)
tensor_type_var = self.get_var('tensor_t')
tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
n = Var("x", scalar_type('int32'))
self.prelude.mod[read_var] =\
Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [])
def define_tensor_array_write(self):
"""Defines a function to update a tensor array at index n with value v.
tensor_array_write(ta, n, v) :
list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t]
"""
write_name = self.get_name("tensor_array_write")
write_var = GlobalVar(write_name)
setattr(self.prelude, write_name, write_var)
tensor_type_var = self.get_var('tensor_t')
tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
n = Var("x", scalar_type('int32'))
v = Var("v", tensor_type_var())
self.prelude.mod[write_var] =\
Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v),
self.prelude.l(tensor_type_var()), [])
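Read and write are thin wrappers around the prelude's nth and update list functions; a plain-Python reference model (illustrative):

def tensor_array_read_ref(ta, n):
    # nth: return element n of the list.
    return ta[n]

def tensor_array_write_ref(ta, n, v):
    # update: a functional write that returns a new list.
    return ta[:n] + [v] + ta[n + 1:]

ta = [0, 0, 0]
assert tensor_array_write_ref(ta, 1, 7) == [0, 7, 0]
assert tensor_array_read_ref(ta, 2) == 0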
def define_tensor_array_unstack_tensor1(self):
"""Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array.
tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t]
"""
helper_name = self.get_name("tensor_array_unstack_tensor1_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor = Var("t", TensorType([Any()], self.dtype))
up = Var("up", scalar_type('int32'))
i = Var("i", scalar_type('int32'))
tensor_type_var = self.get_var('tensor_t')
tensor0_var = self.get_var('tensor0')
helper_body =\
If(equal(i, up),
self.prelude.nil(),
self.prelude.cons(tensor0_var(op.take(tensor, i)),
helper_var(add(i, const(1)), up, tensor)))
self.prelude.mod[helper_var] =\
Function([i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), [])
unstack_name = self.get_name("tensor_array_unstack_tensor1")
unstack_var = GlobalVar(unstack_name)
setattr(self.prelude, unstack_name, unstack_var)
tensor1 = Var("tensor", TensorType([Any()], self.dtype))
shape = op.shape_of(tensor1)
ndim = op.take(shape, const(0))
self.prelude.mod[unstack_var] =\
Function([tensor1], helper_var(const(0), ndim, tensor1),
self.prelude.l(tensor_type_var()), [])
def define_tensor_array_unstack_tensor2(self):
"""Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array.
tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t]
"""
helper_name = self.get_name("tensor_array_unstack_tensor2_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor = Var("t", TensorType([Any(), Any()], self.dtype))
up = Var("up", scalar_type('int32'))
i = Var("i", scalar_type('int32'))
helper_body = If(equal(i, up),
self.prelude.nil(),
self.prelude.cons(self.get_var('tensor1')(op.take(tensor, i, axis=0)),
helper_var(add(i, const(1)), up, tensor)))
self.prelude.mod[helper_var] =\
Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), [])
tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2")
tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name)
setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var)
tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype))
shape = op.shape_of(tensor2)
ndim = op.take(shape, const(0))
self.prelude.mod[tensor_array_unstack_tensor2_var] =\
Function([tensor2], helper_var(const(0), ndim, tensor2),
self.prelude.l(self.get_var('tensor_t')()), [])
def define_tensor_array_scatter(self):
"""Defines a function to scatter the values of a tensor_t in indices of a tensor array.
tensor_array_scatter(ta, indices, value) :
list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
"""
tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name)
tensor_t = self.get_var('tensor_t')
ta = Var("ta", self.prelude.l(tensor_t()))
current = Var("current", scalar_type('int32'))
limit = Var("limit", scalar_type('int32'))
indices_ = Var('indices_', TensorType([Any()], 'int32'))
values_ = Var('values_', self.prelude.l(tensor_t()))
write_var = self.get_var('tensor_array_write')
read_var = self.get_var('tensor_array_read')
helper_body = If(equal(current, limit),
ta,
tensor_array_scatter_helper_var(
write_var(ta, op.take(indices_, current),
read_var(values_, current)),
add(current, const(1)),
limit, indices_, values_))
self.prelude.mod[tensor_array_scatter_helper_var] =\
Function([ta, current, limit, indices_, values_],
helper_body, self.prelude.l(tensor_t()), [])
tensor_array_scatter_name = self.get_name("tensor_array_scatter")
tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name)
setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
indices = Var('indices', TensorType([Any()], 'int32'))
values = Var('values', self.prelude.l(tensor_t()))
indices_shape = op.shape_of(indices)
limit = op.take(indices_shape, const(0))
body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
self.prelude.mod[tensor_array_scatter_var] =\
Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), [])
def define_tensor_array_split(self):
"""Defines a function to split the values of a tensor_t into a tensor array.
tensor_array_split(ta, value, lengths) :
list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
"""
tensor_t = self.get_var('tensor_t')
tensor_array_split_helper_name = self.get_name("ta_split_helper")
tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name)
setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
ta1 = Var("tensor_array", self.prelude.l(tensor_t()))
value1 = Var('value1', tensor_t())
offset1 = Var('offset1', scalar_type('int32'))
current1 = Var('current1', scalar_type('int32'))
limit1 = Var('limit1', scalar_type('int32'))
lengths1 = Var('lengths', TensorType([Any()], 'int32'))
write_var = self.get_var('tensor_array_write')
take_var = self.get_var('tensor_take')
helper1_body = If(equal(current1, limit1),
ta1,
write_var(
tensor_array_split_helper_var(
ta1,
value1,
add(offset1, op.take(lengths1, current1)),
add(current1, const(1)),
limit1,
lengths1
),
current1,
take_var(value1,
offset1,
add(op.take(lengths1, current1), offset1))))
self.prelude.mod[tensor_array_split_helper_var] = \
Function([ta1, value1, offset1, current1, limit1, lengths1],
helper1_body, self.prelude.l(tensor_t()), [])
split_name = self.get_name("tensor_array_split")
split_var = GlobalVar(split_name)
setattr(self.prelude, split_name, split_var)
tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
value = Var('value', tensor_t())
lengths = Var('lengths', TensorType([Any()], 'int32'))
lengths_shape = op.shape_of(lengths)
lengths_limit = op.take(lengths_shape, const(0))
body = tensor_array_split_helper_var(
tensor_array,
value,
const(0),
const(0),
lengths_limit,
lengths)
self.prelude.mod[split_var] =\
Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), [])
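Note that the helper recurses before it writes, so slice n still lands at index n even though offsets are accumulated on the way down. Reference semantics (illustrative):

import numpy as np

def tensor_array_split_ref(value, lengths):
    # Slice `value` along axis 0; piece n gets lengths[n] rows.
    out, offset = [], 0
    for l in lengths:
        out.append(value[offset:offset + l])
        offset += l
    return out

v = np.arange(8).reshape(8, 1)
parts = tensor_array_split_ref(v, [2, 2, 2, 2])
assert [p.shape for p in parts] == [(2, 1)] * 4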
def define_tensor_array_concat(self):
"""Defines a function to return the values in the tensor array as concatenated tensor_t.
tensor_array_concat(ta) : list[tensor_t] -> tensor_t
"""
concat_name = self.get_name("tensor_array_concat")
concat_var = GlobalVar(concat_name)
setattr(self.prelude, concat_name, concat_var)
tensor_concat_var = self.get_var('tensor_concatenate')
tensor_t = self.get_var('tensor_t')
tensor_nil_var = self.get_var('tensor_nil')
tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
hd = Var("hd")
tl = Var("tl")
nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
Match(tl, [
Clause(PatternConstructor(self.prelude.nil), hd),
Clause(PatternWildcard(),
tensor_concat_var(hd, concat_var(tl)))
], False))
self.prelude.mod[concat_var] =\
Function([tensor_array],
Match(tensor_array, [nil_case, cons_case], False), tensor_t(), [])
def define_tensor_array_gather(self):
"""Defines a function to return the selected values in a tensor array as tensor_t.
tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t
"""
helper_name = self.get_name("tensor_array_gather_helper")
helper_var = GlobalVar(helper_name)
setattr(self.prelude, helper_name, helper_var)
tensor_type_var = self.get_var('tensor_t')
stack_var = self.get_var('tensor_array_stack')
read_var = self.get_var('tensor_array_read')
ta = Var("ta", self.prelude.l(tensor_type_var()))
accu = Var("accu", self.prelude.l(tensor_type_var()))
current = Var("current", scalar_type('int32'))
limit = Var("limit", scalar_type('int32'))
indices_ = Var('indices_', TensorType([Any()], 'int32'))
helper_body =\
If(equal(current, const(0)),
stack_var(accu),
helper_var(
ta,
self.prelude.cons(
read_var(
ta, op.take(indices_, subtract(current, const(1)))), accu),
subtract(current, const(1)),
limit, indices_))
self.prelude.mod[helper_var] = \
Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), [])
gather_name = self.get_name("tensor_array_gather")
gather_var = GlobalVar(gather_name)
setattr(self.prelude, gather_name, gather_var)
tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
indices = Var('indices', TensorType([Any()], 'int32'))
indices_shape = op.shape_of(indices)
limit = op.take(indices_shape, const(0))
body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
self.prelude.mod[gather_var] =\
Function([tensor_array, indices], body, tensor_type_var(), [])
def define_tensor_array_stack(self):
"""Defines a function to get the values in the tensor array as a stack tensor_t.
tensor_array_stack(l) : list[tensor_t] -> tensor_t
"""
stack_name = self.get_name("tensor_array_stack")
stack_var = GlobalVar(stack_name)
setattr(self.prelude, stack_name, stack_var)
tensor_type_var = self.get_var('tensor_t')
tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
expand_dims_var = self.get_var('tensor_expand_dims')
concat_var = self.get_var('tensor_concatenate')
tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
tensors = self.prelude.foldl(concat_var,
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims))
self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), [])
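This is the map/fold formulation of stacking: expand every element's rank by one, then fold concatenation over the list. The equivalence with numpy's stack (illustrative):

import numpy as np

ts = [np.ones((2,)), np.zeros((2,)), np.ones((2,))]
stacked = np.concatenate([np.expand_dims(t, 0) for t in ts], axis=0)
assert (stacked == np.stack(ts)).all()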
def register(self):
"""Register all tensor array ops in Prelude"""
self.define_tensor_adt()
self.define_tensor_take()
self.define_tensor_expand_dims()
self.define_tensor_concat()
self.define_tensor_array()
self.define_tensor_array_read()
self.define_tensor_array_write()
self.define_tensor_array_unstack_tensor1()
self.define_tensor_array_unstack_tensor2()
self.define_tensor_array_scatter()
self.define_tensor_array_split()
self.define_tensor_array_concat()
self.define_tensor_array_stack()
# TODO(wweic): Gather fails in PartialEvaluate
# self.define_tensor_array_gather()
class Prelude:
"""Contains standard definitions."""
@@ -27,6 +532,17 @@ class Prelude:
self.mod = mod
self.load_prelude()
def get_name(self, canonical, dtype):
"""Get name corresponding to the canonical name"""
if canonical == 'tensor_t':
return 'tensor_{}_t'.format(dtype)
return "{}_{}".format(canonical, dtype)
def get_var(self, canonical, dtype):
"""Get var corresponding to the canonical name"""
name = self.get_name(canonical, dtype)
return getattr(self, name)
def load_prelude(self):
"""Parses the Prelude from Relay's text format into a module."""
# TODO(@jroesch): we should remove this helper when we port over prelude
@@ -74,3 +590,7 @@ class Prelude:
]
for global_def in GLOBAL_DEFS:
setattr(self, global_def, self.mod.get_global_var(global_def))
for dtype in ['float32', 'int32']:
tensor_array_ops = TensorArrayOps(self, dtype)
tensor_array_ops.register()
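Only float32 and int32 are pre-registered here; a hedged sketch of how another dtype could presumably be added on an existing Prelude (not part of this commit):

# Illustrative only: register tensor array ops for an additional dtype.
from tvm import relay
from tvm.relay.prelude import Prelude, TensorArrayOps

mod = relay.Module()
p = Prelude(mod)
TensorArrayOps(p, 'int64').register()   # assumes int64 behaves like int32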
@@ -203,8 +203,12 @@ class PythonConverter(ExprFunctor):
for var, func in self.mod.functions.items():
# optimize the definition so any operators used are lowered
opt_func = self.optimize(func)
try:
    converted_func, _ = self.convert_func_node(opt_func, var)
    defs.append(converted_func)
except TypeError:
    # TODO(wweic): fix conversion for Any
    pass
return defs
@@ -309,7 +309,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.alloc_tensor_reg.shape_register);
// Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype;
// fields.assign(...) here previously replaced the vector's contents,
// dropping the shape register pushed above; appending fixes the bug.
fields.push_back(dtype.code);
fields.push_back(dtype.bits);
fields.push_back(dtype.lanes);
fields.push_back(instr.dst);
break;
}
@@ -60,13 +60,19 @@ def vmobj_to_list(o):
result.append(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()]
else:
@@ -77,14 +83,11 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
layout = "NCHW"
target_host = None
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
mod, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
@@ -581,6 +584,111 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
def test_tensor_array_constructor():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
ta2 = ta1.write(0, t)
ta3 = ta2.write(1, t2)
out = ta3.read(0)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
run('float32')
run('int32')
def test_tensor_array_scatter():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False)
ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
run('float32')
run('int32')
# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
# with tf.Graph().as_default():
# dtype = 'float32'
# t = tf.constant([[1.0], [2.0], [3.0]])
# scatter_indices = tf.constant([2, 1, 0])
# gather_indices = tf.constant([1, 2])
# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
# ta2 = ta1.scatter(scatter_indices, t)
# t1 = ta2.gather(gather_indices)
# g = tf.get_default_graph()
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
def test_tensor_array_split():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length)
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
out3 = ta2.read(3)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
run('float32')
run('int32')
def test_tensor_array_concat():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length)
t = ta2.concat()
compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
run('float32')
run('int32')
def test_tensor_array_size():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
out = ta1.size()
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
run('float32')
run('int32')
#######################################################################
# ConcatV2
# --------
@@ -21,6 +21,8 @@ from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
import numpy as np
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)
@@ -683,6 +685,146 @@ def test_iterate():
res = intrp.evaluate(relay.Function([], expr)())
assert count(res) == 12
def test_tensor_expand_dims():
def run(dtype):
x = relay.var('x')
mod = relay.Module()
p = Prelude(mod)
expand_dims_func = p.get_var('tensor_expand_dims', dtype)
tensor1 = p.get_var('tensor1', dtype)
mod["main"] = relay.Function([x], expand_dims_func(tensor1(x)))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
x_np = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(x_np)
got = vmobj_to_list(result)
expected = [np.expand_dims(x_np, axis=0)]
tvm.testing.assert_allclose(expected, got)
run('float32')
run('int32')
def test_tensor_array_constructor():
def run(dtype):
x = relay.var('x')
mod = relay.Module()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([x], tensor_array(x))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(5)
got = vmobj_to_list(result)
expected = np.array([0, 0, 0, 0, 0])
tvm.testing.assert_allclose(expected, got)
run('float32')
run('int32')
def test_tensor_array_read():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
l = relay.var('l')
i = relay.var('i')
read_func = p.get_var('tensor_array_read', dtype)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(10, 5)
got = vmobj_to_list(result)
expected = [0]
tvm.testing.assert_allclose(expected, got)
run('float32')
run('int32')
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.asnumpy()]
elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def test_tensor_array_stack():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
tensor1 = p.get_var('tensor1', dtype)
write = p.get_var('tensor_array_write', dtype)
stack = p.get_var('tensor_array_stack', dtype)
l = relay.var('l')
v = relay.var('v')
init_tensor_array = tensor_array(relay.const(3))
tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v))
tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v))
tensor_array4 = stack(tensor_array3)
mod["main"] = relay.Function([v], tensor_array4)
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(t)
res = vmobj_to_list(result)
expected = [np.stack([t, t, t])]
tvm.testing.assert_allclose(expected, res)
run('float32')
run('int32')
def test_tensor_array_unstack():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
v = relay.var('v')
mod["main"] = relay.Function([v], unstack_tensor1(v))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(t)
res = vmobj_to_list(result)
tvm.testing.assert_allclose(t, res)
run('float32')
run('int32')
def test_tensor_take():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
take = p.get_var('tensor_take', dtype)
tensor2 = p.get_var('tensor2', dtype)
v = relay.var('v')
lower = relay.var('lower')
upper = relay.var('upper')
mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(10, 10)).astype(dtype)
result = ex.evaluate()(t, 2, 5)
res = vmobj_to_list(result)
expected = [np.take(t, range(2, 5), axis=0)]
tvm.testing.assert_allclose(expected, res)
run('float32')
run('int32')
if __name__ == "__main__":
test_nat_constructor()
@@ -707,3 +849,9 @@ if __name__ == "__main__":
test_size()
test_compose()
test_iterate()
test_tensor_expand_dims()
test_tensor_array_constructor()
test_tensor_array_read()
test_tensor_array_stack()
test_tensor_array_unstack()
@@ -38,7 +38,8 @@ def test_prelude():
Feature.fLet,
Feature.fIf,
Feature.fConstructor,
Feature.fMatch,
Feature.fGraph
])