Unverified Commit fdc8b0dd by Mahesh Ambule Committed by GitHub

[Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)

* first cut unravel_index

* merge fixes

* change rates to dilations

* unravel_index op relay, topi, mxnet, tf

* doc changes

* small changes

* remove empty unravel and argwhere attrs

* remove empty unravel and argwhere attrs
parent 50b5adaa
......@@ -47,6 +47,7 @@ List of operators
topi.strided_slice
topi.expand_dims
topi.reshape
topi.unravel_index
topi.squeeze
topi.concatenate
topi.split
......@@ -147,6 +148,7 @@ topi
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.unravel_index
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
......
......@@ -242,5 +242,6 @@ Supported Ops
- Transpose
- TruncateMod
- Unpack
- UnravelIndex
- Where
- ZerosLike
......@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
tvm.relay.unravel_index
**Level 4: Broadcast and Reductions**
......@@ -217,4 +218,4 @@ This level supports dialect operators.
:nosignatures:
tvm.relay.qnn.op.requantize
tvm.relay.qnn.op.conv2d
tvm.relay.qnn.op.conv2d
\ No newline at end of file
......@@ -315,12 +315,6 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
}
}; // struct OneHotAttrs
/*! \brief Attributes for ArgWhere operator */
struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
}
}; // struct ArgWhereAttrs
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
......@@ -120,6 +120,13 @@ def _mx_compare(new_op, wrapper):
return impl
def _mx_unravel_index(inputs, attrs):
assert len(inputs) == 1
shape = attrs.get_int_tuple("shape")
shape_expr = _expr.const(list(shape))
return _op.unravel_index(inputs[0], shape_expr)
def _mx_zeros(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
......@@ -1826,6 +1833,7 @@ _convert_map = {
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"_unravel_index": _mx_unravel_index,
"SequenceMask" : _mx_sequence_mask,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
......
......@@ -627,6 +627,11 @@ def _decode_image():
return inputs[0]
return _impl
def _unravel_index():
def _impl(inputs, attr, params):
return _op.unravel_index(inputs[0], inputs[1])
return _impl
def _crop_and_resize():
def _impl(inputs, attr, params):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
......@@ -1744,6 +1749,7 @@ _convert_map = {
'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(),
'UnravelIndex' : _unravel_index(),
'Where' : _where(),
'ZerosLike' : AttrCvt('zeros_like'),
......@@ -2517,9 +2523,7 @@ class GraphProto(object):
array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array
self._nodes[name] = [tvm.relay.const(new_array)]
self._nodes[name] = [tvm.relay.const(np_array)]
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [_expr.var(name,
......
......@@ -54,6 +54,7 @@ _reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_injective_schedule("unravel_index")
# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
......
......@@ -861,3 +861,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 0, 1]]
"""
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
def unravel_index(indices, shape):
"""Convert a flat index or array of flat indices into a tuple of coordinate arrays.
Example::
- unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6],[4, 5, 1]]
Parameters
----------
indices : relay.Expr
An integer array containing indices.
shape : relay.Expr
The shape of the array.
Returns
-------
result : relay.Expr
The tuple of coordinate arrays.
"""
return _make.unravel_index(indices, shape)
......@@ -806,15 +806,13 @@ bool ArgWhereRel(const Array<Type>& types,
TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("argwhere");
auto attrs = make_object<ArgWhereAttrs>();
return CallNode::make(op, {data}, Attrs(attrs), {});
return CallNode::make(op, {data}, Attrs(), {});
});
RELAY_REGISTER_OP("argwhere")
.describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<ArgWhereAttrs>()
.add_argument("condition", "Tensor", "The input condition tensor.")
.add_type_rel("ArgWhere", ArgWhereRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
......@@ -2662,5 +2660,73 @@ RELAY_REGISTER_OP("one_hot")
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
/* relay.unravel_index */
bool UnRavelIndexRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* indices = types[0].as<TensorTypeNode>();
if (indices == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "unravel_index: expect input type to be TensorType but get "
<< types[0];
return false;
}
CHECK(indices->dtype.is_int())
<< "indices of unravel_index must be tensor of integer";
const auto* shape = types[1].as<TensorTypeNode>();
if (shape == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "unravel_index: expect input type to be TensorType but get "
<< types[1];
return false;
}
CHECK(indices->dtype.is_int())
<< "shape of unravel_index must be tensor of integer";
Array<IndexExpr> indices_shape;
Array<IndexExpr> shape_shape;
indices_shape = indices->shape;
shape_shape = shape->shape;
Array<IndexExpr> oshape;
oshape.push_back(shape_shape[0]);
if (indices_shape.size() != 0) {
oshape.push_back(indices_shape[0]);
}
reporter->Assign(types[2], TensorType(oshape, indices->dtype));
return true;
}
Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
}
Expr MakeUnRavelIndex(Expr data,
Expr shape) {
static const Op& op = Op::Get("unravel_index");
return CallNode::make(op, {data, shape}, Attrs(), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
.set_body_typed(MakeUnRavelIndex);
RELAY_REGISTER_OP("unravel_index")
.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Example::
- unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_support_level(3)
.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
......@@ -949,6 +949,32 @@ def test_forward_cond():
verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
def test_forward_unravel_index():
def verify(x, shape, dtype):
a_np = np.array(x).astype(dtype)
mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape])
ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape])
shapes = {'a': a_np.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "vm", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
for dtype in ["int32", "int64"]:
verify([0, 1, 2, 3], [2, 2], dtype)
verify([144, 13, 45], [6, 7, 10, 2], dtype)
verify([456], [6, 7, 10, 2], dtype)
# In below example, 5 is out of bound for array of size 4.
# MXNet implementation provides different result than TVM
# TVM implementation is inline with Tensorflow
# Ideally error should be thrown just like Numpy
# verify([0, 1, 2, 5], [2, 2], dtype)
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -1003,4 +1029,5 @@ if __name__ == '__main__':
test_forward_convolution()
test_forward_deconvolution()
test_forward_cond()
test_forward_make_loss()
\ No newline at end of file
test_forward_make_loss()
test_forward_unravel_index()
\ No newline at end of file
......@@ -3057,6 +3057,57 @@ def test_forward_add_n():
_test_forward_add_n(in5)
#######################################################################
# Unravel Index
# ----------------------
def _test_forward_unravel_index(inputs):
tf.reset_default_graph()
with tf.Graph().as_default():
temp = []
for each in inputs:
temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
output = tf.unravel_index(temp[0], temp[1])
compare_tf_with_tvm([each for each in inputs], [
each.name for each in temp], output.name)
def _test_forward_unravel_index_scalar(x, y, dtype="int32"):
tf.reset_default_graph()
with tf.Graph().as_default():
indices_1 = constant_op.constant(x, dtype=dtype)
dims_1 = constant_op.constant(y, dtype=dtype)
out_1 = array_ops.unravel_index(indices_1, dims_1)
compare_tf_with_tvm([], [], out_1.name)
def test_forward_unravel_index():
x = np.array([0, 1, 2, 3])
y = np.array([2, 2])
_test_forward_unravel_index([x, y])
x = np.array([0, 1, 2, 5])
y = np.array([2, 2])
_test_forward_unravel_index([x, y])
x = np.array([0, 1, 2, 5])
y = np.array([2])
_test_forward_unravel_index([x, y])
x = np.array([102, 300, 16])
y = np.array([10, 10, 9, 6])
_test_forward_unravel_index([x, y])
x = np.array([100])
y = np.array([10, 10, 9, 6])
_test_forward_unravel_index([x, y])
# Test scalar input
_test_forward_unravel_index_scalar(13, [1, 4, 5, 2])
#######################################################################
# Dilation2d
# ----------------------
def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
strides, dilations, padding):
""" One iteration of dilation2d with given shapes and attributes """
......@@ -3173,6 +3224,7 @@ if __name__ == '__main__':
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
test_forward_unravel_index()
# Reductions
test_forward_argminmax()
......
......@@ -683,6 +683,44 @@ def test_gather_nd():
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
def test_unravel_index():
def verify_unravel_index(indices, shape, dtype):
x_data = np.array(indices).astype(dtype)
y_data = np.array(shape).astype(dtype)
x = relay.var("x", relay.TensorType(x_data.shape, dtype))
y = relay.var("y", relay.TensorType(y_data.shape, dtype))
z = relay.unravel_index(x, y)
zz = run_infer_type(z)
if len(x_data.shape) == 1:
out_shape = [y_data.shape[0], x_data.shape[0]]
else:
out_shape = [y_data.shape[0]]
assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)
func = relay.Function([x, y], z)
ref_res = np.unravel_index(x_data, y_data)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
for dtype in ["int64", "int32"]:
verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
verify_unravel_index([144], [5, 5, 5, 2], dtype)
verify_unravel_index(144, [5, 5, 5, 2], dtype)
verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
# In below example, 5 is out of bound for array of size 4.
# Numpy implementation throws error for it
# TVM implementation does not throw error instead it produces
# output which is inline with Tensorflow
# verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)
if __name__ == "__main__":
test_arange()
test_cast()
......@@ -713,3 +751,4 @@ if __name__ == "__main__":
test_tile()
test_repeat()
test_gather_nd()
test_unravel_index()
......@@ -233,6 +233,54 @@ inline Tensor reshape(const Tensor& x,
}
/*!
* \brief Converts a flat index or array of flat indices into a tuple of coordinate arrays
*
* \param x The input tensor having indices.
* \param shape The shape tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor of coordinate arrays.
*/
inline Tensor unravel_index(const Tensor& x,
const Tensor& shape,
std::string name = "T_unravel",
std::string tag = kInjective) {
auto x_shape = x->shape;
auto shape_shape = shape->shape;
Array<PrimExpr> oshape;
oshape.push_back(shape_shape[0]);
if (x_shape.size() != 0) {
oshape.push_back(x_shape[0]);
}
auto func = [&](const Array<Var>& indices) {
auto i = indices[0];
std::vector<PrimExpr> indices_divs;
PrimExpr ret = 0;
PrimExpr cur_val = 0;
PrimExpr index_val = 0;
if (x_shape.size() != 0) {
index_val = x[indices[1]];
} else {
index_val = x();
}
indices_divs.push_back(index_val);
for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
cur_val = indexdiv(indices_divs.back(), shape[v]);
indices_divs.push_back(cur_val);
}
return ret;
};
return compute(oshape, func, name, tag);
}
/*!
* \brief Remove size 1 dimensions from the shape of a tensor.
* The removed dimensions must have a constant size of 1.
*
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,consider-using-enumerate
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Injective transformation operators"""
from __future__ import absolute_import as _abs
import tvm
......@@ -653,3 +653,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 0, 1]]
"""
return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype)
def unravel_index(indices, shape):
"""Convert a flat index or array of flat indices into a tuple of coordinate arrays.
Example::
- unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]
Parameters
----------
indices : relay.Expr
An integer array containing indices.
shape : relay.Expr
The shape of the array.
Returns
-------
result : relay.Expr
The tuple of coordinate arrays.
"""
return cpp.unravel_index(indices, shape)
......@@ -435,6 +435,11 @@ TVM_REGISTER_GLOBAL("topi.gather_nd")
*rv = gather_nd(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.unravel_index")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = unravel_index(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
switch ( args.size() ) {
......
......@@ -562,6 +562,40 @@ def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
for device in get_all_backend():
check_device(device)
def verify_unravel_index(indices, shape, dtype):
x_data = np.array(indices).astype(dtype)
y_data = np.array(shape).astype(dtype)
if len(x_data.shape) == 1:
dst_shape = [y_data.shape[0], x_data.shape[0]]
else:
dst_shape = [y_data.shape[0]]
X = te.placeholder(shape=x_data.shape, dtype=dtype, name="X")
Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y")
Z = topi.unravel_index(X, Y)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.testing.get_injective_schedule(device)(Z)
foo = tvm.build(s, [X, Y, Z], device, name="unravel_index")
out_npy = np.unravel_index(x_data, y_data)
datax_nd = tvm.nd.array(x_data, ctx)
datay_nd = tvm.nd.array(y_data, ctx)
out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=Z.dtype)
foo(datax_nd, datay_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in get_all_backend():
check_device(device)
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
......@@ -882,6 +916,15 @@ def test_one_hot():
verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
def test_unravel_index():
for dtype in ["int32", "int64"]:
verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
verify_unravel_index([144], [5, 5, 5, 2], dtype)
verify_unravel_index(144, [5, 5, 5, 2], dtype)
verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
......@@ -905,3 +948,4 @@ if __name__ == "__main__":
test_ndarray_size()
test_where_fusion()
test_one_hot()
test_unravel_index()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment