Commit 8ef22176 by Xingjian Shi, committed by Haichen Shen

[RELAY] [OP] [MXNet Frontend] Add sequence_mask (#3437)

* Add sequence_mask

use exactly the same arguments as mxnet

fix

* fix lint

* fix lint

* add mxnet conversion + relay

* update

* update doc

* fix pylint

* fix doc

* address comment

* try to address comments

* try to enable shape check for valid_length

* fix

* try to fix

* fix bug

* try to fix

* address comment

* address comment
parent be9275c9
@@ -101,6 +101,7 @@ List of operators
   topi.image.resize
   topi.argsort
   topi.topk
   topi.sequence_mask

List of schedules
@@ -167,6 +168,7 @@ topi
.. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask

topi.nn
~~~~~~~
...
@@ -190,6 +190,7 @@ This level support backpropagation of broadcast operators. It is temporary.
   tvm.relay.device_copy
   tvm.relay.annotation.on_device
   tvm.relay.reverse_reshape
   tvm.relay.sequence_mask
   tvm.relay.nn.batch_matmul
   tvm.relay.contrib.adaptive_max_pool2d
   tvm.relay.contrib.adaptive_avg_pool2d
@@ -323,6 +324,7 @@ Level 10 Definitions
.. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.reverse_reshape
.. autofunction:: tvm.relay.sequence_mask
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
@@ -275,6 +275,18 @@ struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
  }
};
struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
  double mask_value;
  int axis;

  TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") {
    TVM_ATTR_FIELD(mask_value).set_default(0)
      .describe("The masking value.");
    TVM_ATTR_FIELD(axis).set_default(0)
      .describe("The axis of the length dimension. Can only be 0 or 1.");
  }
};  // struct SequenceMaskAttrs.
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, import-self, len-as-condition
+# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
@@ -709,6 +709,18 @@ def _mx_topk(inputs, attrs):
    return _op.topk(inputs[0], **new_attrs)
def _mx_SequenceMask(inputs, attrs):
    assert len(inputs) == 1 or len(inputs) == 2
    new_attrs = {}
    use_sequence_length = attrs.get_bool('use_sequence_length', False)
    new_attrs['mask_value'] = attrs.get_float('value', 0.0)
    new_attrs['axis'] = attrs.get_int('axis', 0)
    if use_sequence_length:
        return _op.sequence_mask(*inputs, **new_attrs)
    else:
        return inputs[0]
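For reference, a minimal sketch of what this hook handles end to end (the shapes here are made up for illustration; the call pattern mirrors the frontend test added further down in this diff):

import mxnet as mx
from tvm import relay

mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
                             sequence_length=mx.sym.var('valid_length'),
                             use_sequence_length=True, value=0.0, axis=0)
# from_mxnet routes the MXNet symbol through _mx_SequenceMask above,
# producing a relay sequence_mask call in the module body
mod, _ = relay.frontend.from_mxnet(mx_sym,
                                   {'data': (5, 2), 'valid_length': (2,)},
                                   dtype={'data': 'float32',
                                          'valid_length': 'float32'})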
def _mx_rnn_param_concat(inputs, _):
    # We don't need to concatenate RNN params because we will unravel the RNN op
    return [inputs]
@@ -994,6 +1006,7 @@ _convert_map = {
    "Embedding" : _mx_embedding,
    "argsort" : _mx_argsort,
    "topk" : _mx_topk,
    "SequenceMask" : _mx_SequenceMask,
    "SoftmaxOutput" : _mx_softmax_output,
    "SoftmaxActivation" : _mx_softmax_activation,
    "LinearRegressionOutput" : _mx_linear_regression_output,
...
@@ -19,7 +19,7 @@
from __future__ import absolute_import
from . import op as _reg
from ._reduce import _schedule_reduce
-from .op import schedule_injective, OpPattern
+from .op import OpPattern

schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
@@ -50,6 +50,8 @@ _reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_concatenate)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
...
@@ -678,3 +678,49 @@ def gather_nd(data, indices):
        relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
    """
    return _make.gather_nd(data, indices)
def sequence_mask(data, valid_length, mask_value=0, axis=0):
    """Sets all elements outside the expected length of the sequence to a constant value.

    This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
    [batch_size, MAX_LENGTH, ...] and returns an array of the same shape.

    Parameters
    ----------
    data : relay.Expr
        The input data.

    valid_length : relay.Expr
        The expected (valid) length of each sequence in the tensor.

    mask_value : float
        The masking value.

    axis : int
        The axis of the length dimension.

    Returns
    -------
    ret : relay.Expr
        The computed result.

    Examples
    --------

    .. code-block:: python

       x = [[[  1.,   2.,   3.], [  4.,   5.,   6.]],
            [[  7.,   8.,   9.], [ 10.,  11.,  12.]],
            [[ 13.,  14.,  15.], [ 16.,  17.,  18.]]]

       relay.sequence_mask(x, valid_length=[1, 1]) =
            [[[  1.,   2.,   3.], [  4.,   5.,   6.]],
             [[  0.,   0.,   0.], [  0.,   0.,   0.]],
             [[  0.,   0.,   0.], [  0.,   0.,   0.]]]

       relay.sequence_mask(x, valid_length=[2, 3], mask_value=0.1) =
            [[[  1.,   2.,   3.], [  4.,   5.,   6.]],
             [[  7.,   8.,   9.], [ 10.,  11.,  12.]],
             [[ 0.1,  0.1,  0.1], [ 16.,  17.,  18.]]]
    """
    return _make.sequence_mask(data, valid_length, mask_value, axis)
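A minimal usage sketch at the Relay level; the input reuses the docstring example above, and the executor API is the same one exercised by the tests later in this diff:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", relay.TensorType((3, 2, 3), "float32"))
valid_length = relay.var("valid_length", relay.TensorType((2,), "int32"))
out = relay.sequence_mask(data, valid_length, mask_value=0.0, axis=0)
func = relay.Function([data, valid_length], out)

# np.arange(1, 19) reshaped to (3, 2, 3) is exactly the x from the docstring
x_np = np.arange(1, 19, dtype="float32").reshape(3, 2, 3)
intrp = relay.create_executor("debug", ctx=tvm.cpu(), target="llvm")
# every time step >= the valid length of its batch element is zeroed
res = intrp.evaluate(func)(x_np, np.array([1, 1], dtype="int32"))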
@@ -805,7 +805,7 @@ Examples::
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
-.set_support_level(2)
+.set_support_level(3)
.add_type_rel("Take", TakeRel)
.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
@@ -2218,5 +2218,108 @@ output shape will simply be (Y_0, ..., Y_{K-1}).
.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.sequence_mask
TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs);

bool SequenceMaskRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  // `types` contains: [data, valid_length, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* valid_length = types[1].as<TensorTypeNode>();
  CHECK(data);
  CHECK(valid_length);
  const auto* param = attrs.as<SequenceMaskAttrs>();
  Array<IndexExpr> valid_length_shape;
  CHECK(param->axis == 0 || param->axis == 1);
  valid_length_shape.push_back(data->shape[1 - param->axis]);
  reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype));
  reporter->Assign(types[2], types[0]);
  return true;
}

Array<Tensor> SequenceMaskCompute(const Attrs& attrs,
                                  const Array<Tensor>& inputs,
                                  const Type& out_type,
                                  const Target& target) {
  const auto* param = attrs.as<SequenceMaskAttrs>();
  CHECK(param != nullptr);
  return Array<Tensor>{
    topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
}

Expr MakeSequenceMask(Expr data,
                      Expr valid_length,
                      double mask_value,
                      int axis) {
  auto attrs = make_node<SequenceMaskAttrs>();
  // plain assignment: std::move on primitive types is a no-op
  attrs->mask_value = mask_value;
  attrs->axis = axis;
  static const Op& op = Op::Get("sequence_mask");
  return CallNode::make(op, {data, valid_length}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.sequence_mask")
.set_body_typed(MakeSequenceMask);
RELAY_REGISTER_OP("sequence_mask")
.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value.
This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
[batch_size, MAX_LENGTH, ...] and returns an array of the same shape.
`axis` means the axis of the length dimension and can only be 0 or 1. If axis is 0,
the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must have
shape [batch_size, MAX_LENGTH, ...].
`valid_length` gives the length of each sequence. `valid_length` should be
a 1D int array with positive ints and has dimension [batch_size,].
Examples::
x = [[[ 1., 2., 3.],
[ 4., 5., 6.]],
[[ 7., 8., 9.],
[ 10., 11., 12.]],
[[ 13., 14., 15.],
[ 16., 17., 18.]]]
// valid_length [1, 1] means only the first block of each batch will be kept
// and other blocks are masked with default mask value = 0
sequence_mask(x, valid_length=[1, 1]) =
[[[ 1., 2., 3.],
[ 4., 5., 6.]],
[[ 0., 0., 0.],
[ 0., 0., 0.]],
[[ 0., 0., 0.],
[ 0., 0., 0.]]]
// valid_length [2, 3] means the first 2 blocks of the 1st batch will be kept
// and the first 3 blocks of the 2nd batch will be kept
// the masked values are set to be the specified mask value = 0.1
sequence_mask(x, valid_length=[2, 3], mask_value=0.1) =
[[[ 1., 2., 3.],
[ 4., 5., 6.]],
[[ 7., 8., 9.],
[ 10., 11., 12.]],
[[ 0.1, 0.1, 0.1],
[ 16., 17., 18.]]]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SequenceMaskAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.")
.set_support_level(10)
.add_type_rel("SequenceMask", SequenceMaskRel)
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
}  // namespace relay
}  // namespace tvm
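A small sketch of what SequenceMaskRel enforces, observable from Python (the infer_type usage follows the Relay test later in this diff): the result keeps the type of `data`, and `valid_length` is constrained to a 1-D tensor of length batch_size.

from tvm import relay

data = relay.var("data", relay.TensorType((4, 3), "float32"))
valid_length = relay.var("valid_length", relay.TensorType((3,), "int32"))
out = relay.sequence_mask(data, valid_length, 0.0, 0)
# prints Tensor[(4, 3), float32]: same shape and dtype as `data`
print(relay.ir_pass.infer_type(out).checked_type)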
@@ -666,6 +666,51 @@ def test_forward_topk():
    verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
    verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
def test_forward_sequence_mask():
    def verify(shape, use_sequence_length, value, axis, dtype, itype):
        data_np = np.random.uniform(size=shape).astype(dtype)
        valid_length_np = np.random.randint(0, shape[axis], size=shape[1 - axis]).astype(itype)
        if use_sequence_length:
            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype),
                                         sequence_length=mx.nd.array(valid_length_np, dtype=itype),
                                         use_sequence_length=use_sequence_length,
                                         value=value,
                                         axis=axis)
            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
                                         sequence_length=mx.sym.var('valid_length'),
                                         use_sequence_length=use_sequence_length,
                                         value=value,
                                         axis=axis)
            mod, _ = relay.frontend.from_mxnet(mx_sym,
                                               {"data": shape,
                                                'valid_length': valid_length_np.shape},
                                               dtype={"data": dtype,
                                                      "valid_length": itype})
        else:
            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype),
                                         use_sequence_length=use_sequence_length,
                                         value=value,
                                         axis=axis)
            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
                                         use_sequence_length=use_sequence_length,
                                         value=value,
                                         axis=axis)
            mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}, dtype={"data": dtype})
        for target, ctx in ctx_list():
            for kind in ['graph', 'debug']:
                if use_sequence_length is False and kind == 'graph':
                    # Disable the test for 'graph' when the op reduces to an identity.
                    continue
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                if use_sequence_length:
                    op_res = intrp.evaluate()(data_np, valid_length_np)
                else:
                    op_res = intrp.evaluate()(data_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((5, 10), True, 0.0, 0, 'float32', 'float32')
    verify((5, 4, 3), True, 1.0, 1, 'float32', 'float32')
    verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
    verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
if __name__ == '__main__':
    test_forward_mlp()
@@ -710,3 +755,4 @@ if __name__ == '__main__':
    test_forward_Crop()
    test_forward_argsort()
    test_forward_topk()
    test_forward_sequence_mask()
@@ -249,6 +249,27 @@ def test_adaptive_pool2d():
    verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max")
    verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg")
def test_sequence_mask():
    def _verify(data_shape, mask_value, axis, dtype, itype):
        max_length = data_shape[axis]
        nbatch = data_shape[1 - axis]
        data = relay.var("data", relay.TensorType(data_shape, dtype))
        valid_length = relay.var("valid_length", relay.TensorType((nbatch,), itype))
        out = relay.sequence_mask(data, valid_length, mask_value, axis)
        assert relay.ir_pass.infer_type(out).checked_type == relay.ty.TensorType(data_shape, dtype)
        func = relay.Function([data, valid_length], out)
        data_np = np.random.uniform(size=data_shape).astype(dtype)
        valid_length_np = np.random.randint(0, max_length, size=nbatch).astype(itype)
        gt_out_np = topi.testing.sequence_mask(data_np, valid_length_np, mask_value, axis)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                out_relay = intrp.evaluate(func)(data_np, valid_length_np)
                tvm.testing.assert_allclose(out_relay.asnumpy(), gt_out_np)
    _verify((5, 10), 0.0, 1, 'float32', 'int32')
    _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
    _verify((5, 8, 3), 0.1, 1, 'float64', 'float32')
if __name__ == "__main__": if __name__ == "__main__":
test_adaptive_pool2d() test_adaptive_pool2d()
...@@ -258,3 +279,4 @@ if __name__ == "__main__": ...@@ -258,3 +279,4 @@ if __name__ == "__main__":
test_reverse_reshape() test_reverse_reshape()
test_batch_matmul() test_batch_matmul()
test_shape_of() test_shape_of()
test_sequence_mask()
@@ -657,6 +657,43 @@ inline Tensor take(const Tensor& a,
  }
}
/*!
 * \brief Mask the out-of-boundary elements of each sequence.
 *
 * \param data The source array.
 * \param valid_length The real length of each sequence.
 * \param mask_value The masking value.
 * \param axis The axis of the temporal dimension of the sequence.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the sequence_mask operation.
 */
inline Tensor sequence_mask(const Tensor& data,
                            const Tensor& valid_length,
                            double mask_value,
                            int axis,
                            std::string name = "T_sequence_mask",
                            std::string tag = kInjective) {
  CHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
  CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
  auto length_dim = data->shape[axis];
  auto batch_dim = data->shape[1 - axis];
  Array<Expr> out_shape = data->shape;
  Tensor out = compute(
    out_shape, [&](const Array<Var>& out_index) {
      Array<Expr> len_index;
      auto tid = out_index[axis];
      auto bid = out_index[1 - axis];
      len_index.push_back(bid);
      Expr ret = tvm::if_then_else(
          tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
          tvm::cast(data->dtype, Expr(mask_value)), data(out_index));
      return ret;
    }, name, tag);
  return out;
}
/*!
 * \brief Take elements from an array along an axis.
 *
...
@@ -23,3 +23,4 @@ from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Sequence mask in python"""
import numpy as np
def sequence_mask(data, valid_length, mask_value, axis):
    """sequence_mask operator implemented in numpy.

    Parameters
    ----------
    data : numpy.ndarray
        N-D with shape [batch_size, MAX_LENGTH, ...] or [MAX_LENGTH, batch_size, ...]

    valid_length : numpy.ndarray
        1-D with shape [batch_size,]

    mask_value : float
        Masking value

    axis : int
        The axis of the length dimension

    Returns
    -------
    out : numpy.ndarray
        N-D with the same shape as data
    """
    in_shape = data.shape
    max_length = data.shape[axis]
    val_len_expand_shape = [1 for _ in range(len(in_shape))]
    val_len_expand_shape[1 - axis] = in_shape[1 - axis]
    seq_len_expand_shape = [1 for _ in range(len(in_shape))]
    seq_len_expand_shape[axis] = in_shape[axis]
    # mask[..., t, ..., b, ...] is True wherever the time index t reaches or
    # exceeds the valid length of batch element b
    mask = np.broadcast_to(np.arange(max_length).reshape(seq_len_expand_shape),
                           in_shape) >= valid_length.reshape(val_len_expand_shape)
    out = data * (1 - mask) + mask_value * mask
    return out
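A quick sanity check of this reference implementation against the example from the Relay docstring earlier in the diff (the input values are chosen to match it):

import numpy as np

x = np.arange(1., 19.).reshape(3, 2, 3)
out = sequence_mask(x, np.array([1, 1]), mask_value=0.0, axis=0)
# only time step 0 of each batch element survives; steps 1 and 2 are zeroed
assert (out[0] == x[0]).all() and (out[1:] == 0).all()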
@@ -436,3 +436,44 @@ def shape(array, dtype="int32"):
        The resulting tensor.
    """
    return cpp.shape(array, dtype)
def sequence_mask(data, valid_length, mask_value=0, axis=0):
    """Sets all elements outside the expected length of the sequence to a constant value.

    This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
    [batch_size, MAX_LENGTH, ...] and returns an array of the same shape.

    `axis` means the axis of the length dimension and can only be 0 or 1. If `axis` is 0,
    the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must
    have shape [batch_size, MAX_LENGTH, ...].

    `valid_length` gives the length of each sequence. `valid_length` should be
    a 1-D int array with positive values and dimension [batch_size,].

    Parameters
    ----------
    data : tvm.Tensor
        N-D with shape [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...]
        depending on the value of `axis`.

    valid_length : tvm.Tensor
        1-D with shape [batch_size,]

    mask_value : float, optional
        The masking value, default 0

    axis : int, optional
        axis of the length dimension, must be 0 or 1, default 0

    Returns
    -------
    output : tvm.Tensor
        N-D with shape [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...]
        depending on the value of `axis`.
    """
    assert len(data.shape) >= 2, \
        "only support data.ndim >= 2, received data.shape = {}".format(data.shape)
    assert axis == 0 or axis == 1, "only support axis = 0, 1, received axis = {}".format(axis)
    return cpp.sequence_mask(data, valid_length, mask_value, axis)
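A minimal sketch of lowering this op with the injective schedule on the CPU; it mirrors the TOPI test at the end of this diff, assuming the llvm target is enabled:

import numpy as np
import tvm
import topi

A = tvm.placeholder((5, 2), dtype="float32", name="A")
B = tvm.placeholder((2,), dtype="int32", name="B")
C = topi.sequence_mask(A, B, mask_value=0.0, axis=0)
with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective(C)
f = tvm.build(s, [A, B, C], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(5, 2)).astype("float32"), ctx)
b = tvm.nd.array(np.array([2, 4], dtype="int32"), ctx)
c = tvm.nd.empty((5, 2), ctx=ctx, dtype="float32")
f(a, b, c)  # rows 2.. of column 0 and row 4 of column 1 are now 0.0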
@@ -337,6 +337,14 @@ TVM_REGISTER_GLOBAL("topi.take")
  }
});
TVM_REGISTER_GLOBAL("topi.sequence_mask")
.set_body([](TVMArgs args, TVMRetValue *rv) {
double pad_val = args[2];
int axis = args[3];
*rv = sequence_mask(args[0], args[1], pad_val, axis);
});
TVM_REGISTER_GLOBAL("topi.where") TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]); *rv = where(args[0], args[1], args[2]);
......
@@ -619,6 +619,36 @@ def test_shape():
        check_device(backend)
def test_sequence_mask():
    for in_shape in (5, 10), (3, 4, 5, 4):
        for axis in [0, 1]:
            for mask_value in [0.0, 1.0]:
                max_length = in_shape[axis]
                batch_size = in_shape[1 - axis]
                A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
                B = tvm.placeholder(shape=(batch_size,), dtype="int32", name="B")
                C = topi.sequence_mask(A, B, axis=axis, mask_value=mask_value)
                A_data = np.random.normal(0, 1, in_shape).astype(np.float32)
                B_data = np.random.randint(1, max_length, (batch_size,)).astype(np.int32)
                C_gt_data = topi.testing.sequence_mask(A_data, B_data, mask_value, axis)

                def check_device(device):
                    ctx = tvm.context(device, 0)
                    if not ctx.exist:
                        print("Skip because %s is not enabled" % device)
                        return
                    tvm_A = tvm.nd.array(A_data, ctx)
                    tvm_B = tvm.nd.array(B_data, ctx)
                    tvm_C = tvm.nd.empty(in_shape, ctx=ctx, dtype="float32")
                    print("Running on target: %s" % device)
                    with tvm.target.create(device):
                        s = topi.generic.schedule_injective(C)
                    f = tvm.build(s, [A, B, C], device, name="SequenceMask")
                    f(tvm_A, tvm_B, tvm_C)
                    tvm.testing.assert_allclose(tvm_C.asnumpy(), C_gt_data)
                for backend in get_all_backend():
                    check_device(backend)
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_concatenate() test_concatenate()
...@@ -637,3 +667,4 @@ if __name__ == "__main__": ...@@ -637,3 +667,4 @@ if __name__ == "__main__":
test_repeat() test_repeat()
test_tile() test_tile()
test_shape() test_shape()
test_sequence_mask()