Commit 36a96773 by Wei Chen, committed by Zhi

[Relay][Frontend][TF] Add tensor array ops (#3798)

* [Relay][Frontend][TF] Add tensor array ops

* rename

* delete test

* Move utility function

* Refactor

* fix tensor array ops

* fix test

* fix rebase

* Fix serializer bug

* Improve TF converter name lookup to use the Prelude API

* Fix lint

* Fix test
parent 4052de6d
@@ -22,10 +22,14 @@ from __future__ import print_function
import warnings
from collections import defaultdict
# Numpy support
import numpy as np
import tvm
from tvm.relay.prelude import Prelude
from .. import analysis
from .. import expr as _expr
from .. import op as _op
@@ -508,6 +512,69 @@ def _pack():
return _op.concatenate(inputs_reshaped, axis)
return _impl
def _tensor_array():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('dtype').name
tensor_array_constructor = prelude.get_var('tensor_array', dtype_str)
return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0)))
return _impl
def _tensor_array_scatter():
def _impl(inputs, attr, params, prelude):
dtype_str = attr.get('T').name
values_rank = len(inputs[2].type_annotation.shape)
unstack_name = "tensor_array_unstack_tensor{}".format(values_rank)
unstack_function = prelude.get_var(unstack_name, dtype_str)
values = unstack_function(inputs[2])
tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str)
return tensor_array_scatter_func(inputs[0], inputs[1], values)
return _impl
def _tensor_array_gather():
def _impl(inputs, attr, params, prelude):
return prelude.tensor_array_gather(inputs[2], inputs[1])
return _impl
def _tensor_array_size():
def _impl(inputs, attr, params, prelude):
return prelude.length(inputs[0])
return _impl
def _tensor_array_write():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[2].type_annotation.shape)
dtype = attr.get('T').name
tensor_name = 'tensor{}'.format(input_rank)
tensor_func = prelude.get_var(tensor_name, dtype)
v = tensor_func(inputs[2])
write_func = prelude.get_var('tensor_array_write', dtype)
return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v)
return _impl
def _tensor_array_read():
def _impl(inputs, attr, params, prelude):
read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name)
return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0)))
return _impl
def _tensor_array_split():
def _impl(inputs, attr, params, prelude):
input_rank = len(inputs[1].type_annotation.shape)
dtype_str = attr.get('T').name
v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1])
lengths = _op.cast(inputs[2], 'int32')
split_var = prelude.get_var('tensor_array_split', dtype_str)
return split_var(inputs[0], v, lengths)
return _impl
def _tensor_array_concat():
def _impl(inputs, attr, params, prelude):
concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name)
return concat_func(inputs[1])
return _impl
def _tile():
def _impl(inputs, attr, params):
reps = _get_list_param(params, inputs.pop())
@@ -1313,6 +1380,14 @@ _convert_map = {
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(),
'TensorArrayV3' : _tensor_array(),
'TensorArrayScatterV3' : _tensor_array_scatter(),
'TensorArrayGatherV3' : _tensor_array_gather(),
'TensorArraySizeV3' : _tensor_array_size(),
'TensorArrayWriteV3' : _tensor_array_write(),
'TensorArrayReadV3' : _tensor_array_read(),
'TensorArraySplitV3' : _tensor_array_split(),
'TensorArrayConcatV3' : _tensor_array_concat(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Pow' : _elemwise('power'),
@@ -1860,6 +1935,7 @@ class GraphProto(object):
self._loops = {}
self._branches = {}
self._mod = _module.Module({})
self._prelude = Prelude(self._mod)
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -2335,7 +2411,11 @@ class GraphProto(object):
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
-    sym = convert_map[op_name](inputs, attrs, self._params)
+    if 'TensorArray' in op_name:
+        sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+    else:
+        sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
...
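For context, a rough end-to-end sketch (not part of the commit, assuming TensorFlow 1.x) of what the new dispatch enables: a graph containing TensorArray* nodes can now be imported, with each such node converted through the Prelude that GraphProto now carries.

import numpy as np
import tensorflow as tf
from tvm import relay

# Build a small TF 1.x graph that uses a TensorArray, mirroring the tests
# added later in this commit.
with tf.Graph().as_default():
    t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]], dtype='float32'))
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False, dynamic_size=False)
    ta = ta.write(0, t).write(1, t)
    out = ta.read(0)
    graph_def = tf.get_default_graph().as_graph_def()

# TensorArrayV3 / TensorArrayWriteV3 / TensorArrayReadV3 are looked up in
# _convert_map and handed the module's Prelude, so the import succeeds.
mod, params = relay.frontend.from_tensorflow(graph_def, outputs=['TensorArrayReadV3'])
print(mod['main'])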
@@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target):
register_schedule("clip", schedule_elemwise)
@script
def _cast_shape_function(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out
def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]
@script
def _expand_dims_shape_func(x):
ndim = len(x.shape)
out = output_tensor((ndim+1,), "int64")
out[0] = int64(1)
for i in const_range(0, ndim):
out[i+1] = int64(x.shape[i])
return out
def expand_dims_shape_func(attrs, inputs, out_ndims):
return [_expand_dims_shape_func(*inputs)]
# shape func
@script
def _broadcast_shape_func(x, y, ndim):
@@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim):
def broadcast_shape_func(attrs, inputs, out_ndims):
return [_broadcast_shape_func(*inputs, out_ndims[0])]
register_shape_func("expand_dims", False, expand_dims_shape_func)
register_shape_func("cast", False, cast_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
...
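Background for these registrations: when an input shape contains Any, Relay evaluates a registered shape function at runtime to obtain the output shape. A plain-NumPy mirror of what the two new functions are meant to produce (illustration only, not the hybrid-script machinery):

import numpy as np

def cast_shape(in_shape):
    # cast changes only the dtype, so the output shape equals the input shape
    return np.asarray(in_shape, dtype='int64')

def expand_dims_shape(in_shape):
    # the Prelude's tensor ops use expand_dims to prepend a unit axis,
    # so the output shape is [1] followed by the input shape
    return np.concatenate(([1], in_shape)).astype('int64')

assert cast_shape([4, 5]).tolist() == [4, 5]
assert expand_dims_shape([2, 3]).tolist() == [1, 2, 3]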
@@ -203,8 +203,12 @@ class PythonConverter(ExprFunctor):
for var, func in self.mod.functions.items():
# optimize the definition so any operators used are lowered
opt_func = self.optimize(func)
-    converted_func, _ = self.convert_func_node(opt_func, var)
-    defs.append(converted_func)
+    try:
+        converted_func, _ = self.convert_func_node(opt_func, var)
+        defs.append(converted_func)
+    except TypeError:
+        # TODO(wweic): fix conversion for Any
+        pass
return defs
...
@@ -309,7 +309,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.alloc_tensor_reg.shape_register);
// Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype;
-    fields.assign({dtype.code, dtype.bits, dtype.lanes});
+    fields.push_back(dtype.code);
+    fields.push_back(dtype.bits);
+    fields.push_back(dtype.lanes);
fields.push_back(instr.dst);
break;
}
...
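This hunk is the "Fix serializer bug" item from the commit message: std::vector::assign replaces the vector's existing contents, so the shape register pushed just above was being discarded, whereas push_back appends after it. A rough Python analogy with made-up field values:

shape_register, code, bits, lanes, dst = 3, 2, 32, 1, 4

buggy = [shape_register]
buggy = [code, bits, lanes]       # assign-like: the field pushed earlier is lost
buggy.append(dst)

fixed = [shape_register]
fixed += [code, bits, lanes]      # push_back-like: fields are appended
fixed.append(dst)

assert fixed == [shape_register, code, bits, lanes, dst]
assert buggy != fixed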
@@ -60,13 +60,19 @@ def vmobj_to_list(o):
result.append(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
-    if o.constructor.name_hint == 'cons':
+    if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
-    elif o.constructor.name_hint == 'nil':
+    elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()]
else:
@@ -77,14 +83,11 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
layout = "NCHW"
target_host = None
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
mod, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
@@ -581,6 +584,111 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
def test_tensor_array_constructor():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
ta2 = ta1.write(0, t)
ta3 = ta2.write(1, t2)
out = ta3.read(0)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
run('float32')
run('int32')
def test_tensor_array_scatter():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False)
ta2 = ta1.scatter(indices, t)
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
run('float32')
run('int32')
# TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather():
# with tf.Graph().as_default():
# dtype = 'float32'
# t = tf.constant([[1.0], [2.0], [3.0]])
# scatter_indices = tf.constant([2, 1, 0])
# gather_indices = tf.constant([1, 2])
# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False)
# ta2 = ta1.scatter(scatter_indices, t)
# t1 = ta2.gather(gather_indices)
# g = tf.get_default_graph()
# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
def test_tensor_array_split():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length)
out0 = ta2.read(0)
out1 = ta2.read(1)
out2 = ta2.read(2)
out3 = ta2.read(3)
g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
run('float32')
run('int32')
def test_tensor_array_concat():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length)
t = ta2.concat()
compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
run('float32')
run('int32')
def test_tensor_array_size():
def run(dtype_str):
with tf.Graph().as_default():
dtype = {
'float32': tf.float32,
'int32' : tf.int32
}[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
out = ta1.size()
g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
run('float32')
run('int32')
#######################################################################
# ConcatV2
# --------
...
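A side note on the output tensor names used in the tests above: TensorFlow derives default node names from the op type and numbers repeats, which is why the tests read back 'TensorArrayReadV3:0', 'TensorArrayReadV3_1:0', and so on. A small illustration (not part of the commit, assuming TF 1.x):

import tensorflow as tf

with tf.Graph().as_default() as g:
    ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False, dynamic_size=False)
    ta = ta.write(0, tf.constant([1.0])).write(1, tf.constant([2.0]))
    first, second = ta.read(0), ta.read(1)
    read_nodes = [n.name for n in g.as_graph_def().node if n.op == 'TensorArrayReadV3']
    # read_nodes is expected to be ['TensorArrayReadV3', 'TensorArrayReadV3_1']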
@@ -21,6 +21,8 @@ from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
import numpy as np
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)
@@ -683,6 +685,146 @@ def test_iterate():
res = intrp.evaluate(relay.Function([], expr)())
assert count(res) == 12
def test_tensor_expand_dims():
def run(dtype):
x = relay.var('x')
mod = relay.Module()
p = Prelude(mod)
expand_dims_func = p.get_var('tensor_expand_dims', dtype)
tensor1 = p.get_var('tensor1', dtype)
mod["main"] = relay.Function([x], expand_dims_func(tensor1(x)))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
x_np = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(x_np)
got = vmobj_to_list(result)
expected = [np.expand_dims(x_np, axis=0)]
tvm.testing.assert_allclose(expected, got)
run('float32')
run('int32')
def test_tensor_array_constructor():
def run(dtype):
x = relay.var('x')
mod = relay.Module()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([x], tensor_array(x))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(5)
got = vmobj_to_list(result)
expected = np.array([0, 0, 0, 0, 0])
tvm.testing.assert_allclose(expected, got)
run('float32')
run('int32')
def test_tensor_array_read():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
l = relay.var('l')
i = relay.var('i')
read_func = p.get_var('tensor_array_read', dtype)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(10, 5)
got = vmobj_to_list(result)
expected = [0]
tvm.testing.assert_allclose(expected, got)
run('float32')
run('int32')
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.asnumpy()]
elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def test_tensor_array_stack():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
tensor1 = p.get_var('tensor1', dtype)
write = p.get_var('tensor_array_write', dtype)
stack = p.get_var('tensor_array_stack', dtype)
l = relay.var('l')
v = relay.var('v')
init_tensor_array = tensor_array(relay.const(3))
tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v))
tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v))
tensor_array4 = stack(tensor_array3)
mod["main"] = relay.Function([v], tensor_array4)
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(t)
res = vmobj_to_list(result)
expected = [np.stack([t, t, t])]
tvm.testing.assert_allclose(expected, res)
run('float32')
run('int32')
def test_tensor_array_unstack():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
v = relay.var('v')
mod["main"] = relay.Function([v], unstack_tensor1(v))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(t)
res = vmobj_to_list(result)
tvm.testing.assert_allclose(t, res)
run('float32')
run('int32')
def test_tensor_take():
def run(dtype):
mod = relay.Module()
p = Prelude(mod)
take = p.get_var('tensor_take', dtype)
tensor2 = p.get_var('tensor2', dtype)
v = relay.var('v')
lower = relay.var('lower')
upper = relay.var('upper')
mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(10, 10)).astype(dtype)
result = ex.evaluate()(t, 2, 5)
res = vmobj_to_list(result)
expected = [np.take(t, range(2, 5), axis=0)]
tvm.testing.assert_allclose(expected, res)
run('float32')
run('int32')
if __name__ == "__main__":
test_nat_constructor()
@@ -707,3 +849,9 @@ if __name__ == "__main__":
test_size()
test_compose()
test_iterate()
test_tensor_expand_dims()
test_tensor_array_constructor()
test_tensor_array_read()
test_tensor_array_stack()
test_tensor_array_unstack()
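The tests above rely on the Prelude's tensor_array being an ADT value: the constructor builds a list of tensor_nil placeholders, write functionally replaces one slot, and stack joins the stored tensors along a new leading axis. A plain-Python model of that behaviour (illustration only, not the Relay implementation):

import numpy as np

def tensor_array(n):
    # tensor_nil placeholders; vmobj_to_list above maps each of them to [0]
    return [None] * n

def tensor_array_write(ta, i, value):
    # functional update: returns a new array, like the Relay version
    return ta[:i] + [value] + ta[i + 1:]

def tensor_array_stack(ta):
    return np.stack(ta)

ta = tensor_array(3)
v = np.ones((1,), dtype='float32')
for i in range(3):
    ta = tensor_array_write(ta, i, v)
assert tensor_array_stack(ta).shape == (3, 1)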
@@ -38,7 +38,8 @@ def test_prelude():
Feature.fLet,
Feature.fIf,
Feature.fConstructor,
-    Feature.fMatch
+    Feature.fMatch,
+    Feature.fGraph
])
...