Commit 36a96773 by Wei Chen, committed by Zhi

[Relay][Frontend][TF] Add tensor array ops (#3798)

* [Relay][Frontend][TF] Add tensor array ops

* rename

* delete test

* Move utility function

* Refactor

* fix tensor array ops

* fix test

* fix rebase

* Fix serializer bug

* Improve tf convert name lookup to use prelude api

* Fix lint

* Fix test
parent 4052de6d
......@@ -22,10 +22,14 @@ from __future__ import print_function
import warnings
from collections import defaultdict
# Numpy support
import numpy as np
import tvm
from tvm.relay.prelude import Prelude
from .. import analysis
from .. import expr as _expr
from .. import op as _op
......@@ -508,6 +512,69 @@ def _pack():
        return _op.concatenate(inputs_reshaped, axis)
    return _impl

def _tensor_array():
    def _impl(inputs, attr, params, prelude):
        dtype_str = attr.get('dtype').name
        tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
        return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
    return _impl

def _tensor_array_scatter():
    def _impl(inputs, attr, params, prelude):
        dtype_str = attr.get('T').name
        values_rank = len(inputs[2].type_annotation.shape)
        unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
        unstack_function = prelude.get_var(unstack_name, dtype_str)
        values = unstack_function(inputs[2])
        tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
        return tensor_array_scatter_func(inputs[0], inputs[1], values)
    return _impl

def _tensor_array_gather():
    def _impl(inputs, attr, params, prelude):
        return prelude.tensor_array_gather(inputs[2], inputs[1])
    return _impl

def _tensor_array_size():
    def _impl(inputs, attr, params, prelude):
        return prelude.length(inputs[0])
    return _impl

def _tensor_array_write():
    def _impl(inputs, attr, params, prelude):
        input_rank = len(inputs[2].type_annotation.shape)
        dtype = attr.get('T').name
        tensor_name = 'tensor{}'.format(input_rank)
        tensor_func = prelude.get_var(tensor_name, dtype)
        v = tensor_func(inputs[2])
        write_func = prelude.get_var('tensor_array_write', dtype)
        return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
    return _impl

def _tensor_array_read():
    def _impl(inputs, attr, params, prelude):
        read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
        return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
    return _impl

def _tensor_array_split():
    def _impl(inputs, attr, params, prelude):
        input_rank = len(inputs[1].type_annotation.shape)
        dtype_str = attr.get('T').name
        v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
        lengths = _op.cast(inputs[2], 'int32')
        split_var = prelude.get_var('tensor_array_split', dtype_str)
        return split_var(inputs[0], v, lengths)
    return _impl

def _tensor_array_concat():
    def _impl(inputs, attr, params, prelude):
        concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
        return concat_func(inputs[1])
    return _impl
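All of these converters bottom out in the Prelude's dtype-specialized tensor-array ADT. As a rough sketch of the target API (not part of the diff; the same prelude calls are exercised by the ADT tests further down), a TensorArrayV3 / TensorArrayWriteV3 / TensorArrayReadV3 sequence lowers to something like:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)
ta_new   = p.get_var('tensor_array', 'float32')          # TensorArrayV3
tensor1  = p.get_var('tensor1', 'float32')               # wrap a rank-1 value
ta_write = p.get_var('tensor_array_write', 'float32')    # TensorArrayWriteV3
ta_read  = p.get_var('tensor_array_read', 'float32')     # TensorArrayReadV3

v = relay.var('v')
body = ta_read(ta_write(ta_new(relay.const(1)), relay.const(0), tensor1(v)),
               relay.const(0))
mod["main"] = relay.Function([v], body)
ex = relay.create_executor('debug', mod=mod, ctx=tvm.cpu(), target='llvm')
print(ex.evaluate()(np.array([3.0], dtype='float32')))   # tensor1 value holding [3.0]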
def _tile():
    def _impl(inputs, attr, params):
        reps = _get_list_param(params, inputs.pop())
......@@ -1313,6 +1380,14 @@ _convert_map = {
    'NotEqual'                          : _broadcast('not_equal'),
    'OneHot'                            : _one_hot(),
    'Pack'                              : _pack(),
    'TensorArrayV3'                     : _tensor_array(),
    'TensorArrayScatterV3'              : _tensor_array_scatter(),
    'TensorArrayGatherV3'               : _tensor_array_gather(),
    'TensorArraySizeV3'                 : _tensor_array_size(),
    'TensorArrayWriteV3'                : _tensor_array_write(),
    'TensorArrayReadV3'                 : _tensor_array_read(),
    'TensorArraySplitV3'                : _tensor_array_split(),
    'TensorArrayConcatV3'               : _tensor_array_concat(),
    'Pad'                               : _pad('Pad'),
    'PadV2'                             : _pad('PadV2'),
    'Pow'                               : _elemwise('power'),
......@@ -1860,6 +1935,7 @@ class GraphProto(object):
        self._loops = {}
        self._branches = {}
        self._mod = _module.Module({})
        self._prelude = Prelude(self._mod)

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct relay nodes from tensorflow graph definition - GraphDef.
......@@ -2335,7 +2411,11 @@ class GraphProto(object):
        if op_name in identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
-           sym = convert_map[op_name](inputs, attrs, self._params)
+           if 'TensorArray' in op_name:
+               sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+           else:
+               sym = convert_map[op_name](inputs, attrs, self._params)
        elif op_name in convert_map_rnn:
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
......
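With the prelude threaded through the dispatch above, a TF 1.x graph that uses TensorArray ops imports end to end. A minimal sketch (mirroring what compare_tf_with_tvm does in the tests below):

import numpy as np
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default():
    t = tf.constant(np.array([[1.0, 2.0]], dtype='float32'))
    ta = tf.TensorArray(dtype=tf.float32, size=1, infer_shape=False, dynamic_size=False)
    out = ta.write(0, t).read(0)
    graph_def = tf.get_default_graph().as_graph_def()

mod, params = relay.frontend.from_tensorflow(graph_def)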
......@@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target):
register_schedule("clip", schedule_elemwise)
@script
def _cast_shape_function(x):
    out_ndim = len(x)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = x[i]
    return out

def cast_shape_func(attrs, inputs, out_ndims):
    return [_cast_shape_function(*inputs)]

@script
def _expand_dims_shape_func(x):
    ndim = len(x.shape)
    out = output_tensor((ndim+1,), "int64")
    out[0] = int64(1)
    for i in const_range(0, ndim):
        out[i+1] = int64(x.shape[i])
    return out

def expand_dims_shape_func(attrs, inputs, out_ndims):
    return [_expand_dims_shape_func(*inputs)]
# shape func
@script
def _broadcast_shape_func(x, y, ndim):
......@@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim):
def broadcast_shape_func(attrs, inputs, out_ndims):
    return [_broadcast_shape_func(*inputs, out_ndims[0])]
register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
......
......@@ -203,8 +203,12 @@ class PythonConverter(ExprFunctor):
        for var, func in self.mod.functions.items():
            # optimize the definition so any operators used are lowered
            opt_func = self.optimize(func)
-           converted_func, _ = self.convert_func_node(opt_func, var)
-           defs.append(converted_func)
+           try:
+               converted_func, _ = self.convert_func_node(opt_func, var)
+               defs.append(converted_func)
+           except TypeError:
+               # TODO(wweic): fix conversion for Any
+               pass
        return defs
......
......@@ -309,7 +309,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
      fields.push_back(instr.alloc_tensor_reg.shape_register);
      // Save `DLDataType` and the dst register.
      const auto& dtype = instr.alloc_tensor.dtype;
-     fields.assign({dtype.code, dtype.bits, dtype.lanes});
+     // assign() would overwrite the fields already pushed above; append instead.
+     fields.push_back(dtype.code);
+     fields.push_back(dtype.bits);
+     fields.push_back(dtype.lanes);
      fields.push_back(instr.dst);
      break;
    }
......
......@@ -60,13 +60,19 @@ def vmobj_to_list(o):
            result.append(vmobj_to_list(f))
        return result
    elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
-       if o.constructor.name_hint == 'cons':
+       if o.constructor.name_hint == 'Cons':
            tl = vmobj_to_list(o.fields[1])
            hd = vmobj_to_list(o.fields[0])
            hd.extend(tl)
            return hd
-       elif o.constructor.name_hint == 'nil':
+       elif o.constructor.name_hint == 'Nil':
            return []
+       elif 'tensor_nil' in o.constructor.name_hint:
+           return [0]
+       elif 'tensor' in o.constructor.name_hint:
+           return [o.fields[0].asnumpy()]
        else:
            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
        return [o.data.asnumpy()]
    else:
......@@ -77,14 +83,11 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
layout = "NCHW"
target_host = None
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
mod, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
......@@ -581,6 +584,111 @@ def test_forward_squeeze():
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])

def test_tensor_array_constructor():
    def run(dtype_str):
        with tf.Graph().as_default():
            dtype = {
                'float32': tf.float32,
                'int32' : tf.int32
            }[dtype_str]
            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
            ta2 = ta1.write(0, t)
            ta3 = ta2.write(1, t2)
            out = ta3.read(0)
            g = tf.get_default_graph()
            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
    run('float32')
    run('int32')

def test_tensor_array_scatter():
    def run(dtype_str):
        with tf.Graph().as_default():
            dtype = {
                'float32': tf.float32,
                'int32' : tf.int32
            }[dtype_str]
            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
            indices = tf.constant([2, 1, 0])
            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False)
            ta2 = ta1.scatter(indices, t)
            out0 = ta2.read(0)
            out1 = ta2.read(1)
            out2 = ta2.read(2)
            g = tf.get_default_graph()
            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
    run('float32')
    run('int32')

# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
#     with tf.Graph().as_default():
#         dtype = 'float32'
#         t = tf.constant([[1.0], [2.0], [3.0]])
#         scatter_indices = tf.constant([2, 1, 0])
#         gather_indices = tf.constant([1, 2])
#         ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
#         ta2 = ta1.scatter(scatter_indices, t)
#         t1 = ta2.gather(gather_indices)
#         g = tf.get_default_graph()
#         compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')

def test_tensor_array_split():
    def run(dtype_str):
        with tf.Graph().as_default():
            dtype = {
                'float32': tf.float32,
                'int32' : tf.int32
            }[dtype_str]
            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
            split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
            ta2 = ta1.split(t, split_length)
            out0 = ta2.read(0)
            out1 = ta2.read(1)
            out2 = ta2.read(2)
            out3 = ta2.read(3)
            g = tf.get_default_graph()
            compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
            compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
    run('float32')
    run('int32')

def test_tensor_array_concat():
    def run(dtype_str):
        with tf.Graph().as_default():
            dtype = {
                'float32': tf.float32,
                'int32' : tf.int32
            }[dtype_str]
            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
            split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
            ta2 = ta1.split(t, split_length)
            t = ta2.concat()
            compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
    run('float32')
    run('int32')

def test_tensor_array_size():
    def run(dtype_str):
        with tf.Graph().as_default():
            dtype = {
                'float32': tf.float32,
                'int32' : tf.int32
            }[dtype_str]
            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
            out = ta1.size()
            g = tf.get_default_graph()
            compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
    run('float32')
    run('int32')

#######################################################################
# ConcatV2
# --------
......
......@@ -21,6 +21,8 @@ from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
import numpy as np
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)
......@@ -683,6 +685,146 @@ def test_iterate():
    res = intrp.evaluate(relay.Function([], expr)())
    assert count(res) == 12

def test_tensor_expand_dims():
    def run(dtype):
        x = relay.var('x')
        mod = relay.Module()
        p = Prelude(mod)
        expand_dims_func = p.get_var('tensor_expand_dims', dtype)
        tensor1 = p.get_var('tensor1', dtype)
        mod["main"] = relay.Function([x], expand_dims_func(tensor1(x)))
        for kind in ["debug"]:
            ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
            x_np = np.random.uniform(size=(1,)).astype(dtype)
            result = ex.evaluate()(x_np)
            got = vmobj_to_list(result)
            expected = [np.expand_dims(x_np, axis=0)]
            tvm.testing.assert_allclose(expected, got)
    run('float32')
    run('int32')
def test_tensor_array_constructor():
    def run(dtype):
        x = relay.var('x')
        mod = relay.Module()
        p = Prelude(mod)
        tensor_array = p.get_var('tensor_array', dtype)
        mod["main"] = relay.Function([x], tensor_array(x))
        for kind in ["debug"]:
            ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
            result = ex.evaluate()(5)
            got = vmobj_to_list(result)
            expected = np.array([0, 0, 0, 0, 0])
            tvm.testing.assert_allclose(expected, got)
    run('float32')
    run('int32')
def test_tensor_array_read():
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)
        l = relay.var('l')
        i = relay.var('i')
        read_func = p.get_var('tensor_array_read', dtype)
        tensor_array = p.get_var('tensor_array', dtype)
        mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i))
        for kind in ["debug"]:
            ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
            result = ex.evaluate()(10, 5)
            got = vmobj_to_list(result)
            expected = [0]
            tvm.testing.assert_allclose(expected, got)
    run('float32')
    run('int32')
def vmobj_to_list(o):
    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
        return [o.asnumpy().tolist()]
    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
        return [o.asnumpy()]
    elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
        result = []
        for f in o:
            result.extend(vmobj_to_list(f))
        return result
    elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
        if o.constructor.name_hint == 'Cons':
            tl = vmobj_to_list(o.fields[1])
            hd = vmobj_to_list(o.fields[0])
            hd.extend(tl)
            return hd
        elif o.constructor.name_hint == 'Nil':
            return []
        elif 'tensor_nil' in o.constructor.name_hint:
            return [0]
        elif 'tensor' in o.constructor.name_hint:
            return [o.fields[0].asnumpy()]
        else:
            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
    else:
        raise RuntimeError("Unknown object type: %s" % type(o))
def test_tensor_array_stack():
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)
        tensor_array = p.get_var('tensor_array', dtype)
        tensor1 = p.get_var('tensor1', dtype)
        write = p.get_var('tensor_array_write', dtype)
        stack = p.get_var('tensor_array_stack', dtype)
        l = relay.var('l')
        v = relay.var('v')
        init_tensor_array = tensor_array(relay.const(3))
        tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
        tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v))
        tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v))
        tensor_array4 = stack(tensor_array3)
        mod["main"] = relay.Function([v], tensor_array4)
        for kind in ["debug"]:
            ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
            t = np.random.uniform(size=(1,)).astype(dtype)
            result = ex.evaluate()(t)
            res = vmobj_to_list(result)
            expected = [np.stack([t, t, t])]
            tvm.testing.assert_allclose(expected, res)
    run('float32')
    run('int32')
def test_tensor_array_unstack():
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)
        unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
        v = relay.var('v')
        mod["main"] = relay.Function([v], unstack_tensor1(v))
        for kind in ["debug"]:
            ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
            t = np.random.uniform(size=(1,)).astype(dtype)
            result = ex.evaluate()(t)
            res = vmobj_to_list(result)
            tvm.testing.assert_allclose(t, res)
    run('float32')
    run('int32')
def test_tensor_take():
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)
        take = p.get_var('tensor_take', dtype)
        tensor2 = p.get_var('tensor2', dtype)
        v = relay.var('v')
        lower = relay.var('lower')
        upper = relay.var('upper')
        mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper))
        for kind in ["debug"]:
            ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
            t = np.random.uniform(size=(10, 10)).astype(dtype)
            result = ex.evaluate()(t, 2, 5)
            res = vmobj_to_list(result)
            expected = [np.take(t, range(2, 5), axis=0)]
            tvm.testing.assert_allclose(expected, res)
    run('float32')
    run('int32')
if __name__ == "__main__":
    test_nat_constructor()
......@@ -707,3 +849,9 @@ if __name__ == "__main__":
    test_size()
    test_compose()
    test_iterate()
    test_tensor_expand_dims()
    test_tensor_array_constructor()
    test_tensor_array_read()
    test_tensor_array_stack()
    test_tensor_array_unstack()
......@@ -38,7 +38,8 @@ def test_prelude():
        Feature.fLet,
        Feature.fIf,
        Feature.fConstructor,
-       Feature.fMatch
+       Feature.fMatch,
+       Feature.fGraph
    ])
......