Commit 554df211 by Jon Soifer Committed by Haichen Shen

[TOPI][Relay][TensorFlow] Add OneHot operator (#3781)

* Add one-hot to Relay

* topi implementation

* Working

* add topi test

* Add TF test

* Fix check

* fix linting issues

* fix documentation

* Fix documentation

* Add support for on_value, off_value, axis, dtype

* Add full support for axis

* Fix compute and update test_forward

* Move on_value and off_value to inputs

* Add topi test

* Update tests

* Update docs

* Fix style

* re-enable tests

* Add one_hot to mxnet converter
parent 7264cb6a
...@@ -104,6 +104,7 @@ List of operators ...@@ -104,6 +104,7 @@ List of operators
topi.argsort topi.argsort
topi.topk topi.topk
topi.sequence_mask topi.sequence_mask
topi.one_hot
List of schedules List of schedules
...@@ -173,6 +174,7 @@ topi ...@@ -173,6 +174,7 @@ topi
.. autofunction:: topi.argsort .. autofunction:: topi.argsort
.. autofunction:: topi.topk .. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask .. autofunction:: topi.sequence_mask
.. autofunction:: topi.one_hot
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -200,6 +200,7 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -200,6 +200,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.nn.batch_matmul tvm.relay.nn.batch_matmul
tvm.relay.contrib.adaptive_max_pool2d tvm.relay.contrib.adaptive_max_pool2d
tvm.relay.contrib.adaptive_avg_pool2d tvm.relay.contrib.adaptive_avg_pool2d
tvm.relay.one_hot
**Level 11: Dialect Operators** **Level 11: Dialect Operators**
...@@ -350,6 +351,7 @@ Level 10 Definitions ...@@ -350,6 +351,7 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul .. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
.. autofunction:: tvm.relay.one_hot
Level 11 Definitions Level 11 Definitions
......
...@@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> { ...@@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
} }
}; };
/*! \brief Attributes used in one-hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
int depth;
int axis;
DataType dtype;
TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
TVM_ATTR_FIELD(depth).set_default(1)
.describe("Depth of the one hot dimension.");
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis to fill.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Output data type.");
}
}; // struct OneHotAttrs
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_ #endif // TVM_RELAY_ATTRS_TRANSFORM_H_
...@@ -896,6 +896,14 @@ def _mx_rnn_layer(inputs, attrs): ...@@ -896,6 +896,14 @@ def _mx_rnn_layer(inputs, attrs):
ret.append(_op.stack(inputs, axis=0)) ret.append(_op.stack(inputs, axis=0))
return ret return ret
def _mx_one_hot(inputs, attrs):
indices = inputs[0].astype('int32')
depth = attrs.get_int('depth', 0)
dtype = attrs.get_str('dtype', 'int32')
on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype)
off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype)
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)
# Note: due to attribute conversion constraint # Note: due to attribute conversion constraint
# ops in the identity set must be attribute free # ops in the identity set must be attribute free
...@@ -1041,6 +1049,7 @@ _convert_map = { ...@@ -1041,6 +1049,7 @@ _convert_map = {
"LinearRegressionOutput" : _mx_linear_regression_output, "LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1, "smooth_l1" : _mx_smooth_l1,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim, "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
"one_hot" : _mx_one_hot,
# vision # vision
"_contrib_BilinearResize2D" : _mx_resize, "_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior, "_contrib_MultiBoxPrior" : _mx_multibox_prior,
......
...@@ -1212,6 +1212,21 @@ def _log1p(): ...@@ -1212,6 +1212,21 @@ def _log1p():
return get_relay_op('log')(add_out) return get_relay_op('log')(add_out)
return _impl return _impl
def _one_hot():
def _impl(inputs, attr, params):
depth = int(_get_num_param(params, inputs[1]))
dtype = attr['T'].name
on_value = _get_num_param(params, inputs[2])
off_value = _get_num_param(params, inputs[3])
new_inputs = [inputs[0], \
tvm.relay.const(on_value, dtype), \
tvm.relay.const(off_value, dtype)]
return AttrCvt('one_hot',
ignores=['TI'],
extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
return _impl
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -1284,6 +1299,7 @@ _convert_map = { ...@@ -1284,6 +1299,7 @@ _convert_map = {
'Mul' : _elemwise('multiply'), 'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'), 'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'), 'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
'Pack' : _pack(), 'Pack' : _pack(),
'Pad' : _pad('Pad'), 'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'), 'PadV2' : _pad('PadV2'),
......
...@@ -52,6 +52,7 @@ _reg.register_schedule("concatenate", schedule_concatenate) ...@@ -52,6 +52,7 @@ _reg.register_schedule("concatenate", schedule_concatenate)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective) _reg.register_schedule("sequence_mask", schedule_injective)
_reg.register_schedule("one_hot", schedule_injective)
# layout_transform # layout_transform
......
...@@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): ...@@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
[[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]] [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]]
""" """
return _make.sequence_mask(data, valid_length, mask_value, axis) return _make.sequence_mask(data, valid_length, mask_value, axis)
def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
Parameters
----------
indices : relay.Expr
Locations to set to on_value.
on_value : relay.Expr
Value to fill at indices.
off_value : relay.Expr
Value to fill at all other positions besides indices.
depth : int
Depth of the one-hot dimension.
axis : int
Axis to fill.
dtype : str
Data type of the output tensor.
Returns
-------
ret : relay.Expr
The one-hot tensor.
Examples
--------
.. code-block:: python
indices = [0, 1, 2]
relay.one_hot(indices, 3) =
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
"""
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
...@@ -2482,5 +2482,94 @@ Examples:: ...@@ -2482,5 +2482,94 @@ Examples::
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute) .set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.one_hot
TVM_REGISTER_NODE_TYPE(OneHotAttrs);
bool OneHotRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [indices, on_value, off_value, result]
CHECK_EQ(types.size(), 4);
const auto* indices = types[0].as<TensorTypeNode>();
CHECK(indices);
const auto param = attrs.as<OneHotAttrs>();
CHECK_GT(param->depth, 0);
Array<IndexExpr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(param->depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
}
}
reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
return true;
}
Array<Tensor> OneHotCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<OneHotAttrs>();
CHECK(param != nullptr);
return Array<Tensor> {
topi::one_hot(inputs[0],
inputs[1](),
inputs[2](),
param->depth,
param->axis,
param->dtype)
};
}
Expr MakeOneHot(Expr indices,
Expr on_value,
Expr off_value,
int depth,
int axis,
DataType dtype) {
auto attrs = make_node<OneHotAttrs>();
attrs->depth = std::move(depth);
attrs->axis = axis;
attrs->dtype = dtype;
static const Op& op = Op::Get("one_hot");
return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.one_hot")
.set_body_typed(MakeOneHot);
RELAY_REGISTER_OP("one_hot")
.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1,
other locations take value 0. Final dimension is <indices dimensions> x depth.
**indices** Locations to set to 1.
**on_value** Value to fill at indices.
**off_value** Value to fill at all other positions besides indices.
**depth** Depth of the one-hot dimension.
**axis** Axis to fill.
**dtype**)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3)
.add_argument("indices", "Tensor", "Locations to set to on_value.")
.add_argument("on_value", "Expr", "Value to fill at indices.")
.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
.set_support_level(10)
.add_type_rel("OneHot", OneHotRel)
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -778,6 +778,25 @@ def test_forward_layer_norm(): ...@@ -778,6 +778,25 @@ def test_forward_layer_norm():
verify((2, 5), axis=0) verify((2, 5), axis=0)
verify((2, 5, 6)) verify((2, 5, 6))
def test_forward_one_hot():
def verify(indices_shape, depth, on_value, off_value, dtype):
x = np.random.randint(0, 5, size=indices_shape)
ref_res = mx.nd.one_hot(mx.nd.array(x), depth, on_value, off_value, dtype)
mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype)
shape_dict = {"x": x.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x.astype("float32"))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify((3,), 3, 1, 0, "int32")
verify((3,), 3, 1.0, 0.0, "float32")
verify((2, 2), 5, 2, -2, "int32")
verify((2, 2), 5, 0.5, -0.5, "float32")
verify((3, 2, 4, 5), 6, 1, 0, "int32")
verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32")
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -825,3 +844,4 @@ if __name__ == '__main__': ...@@ -825,3 +844,4 @@ if __name__ == '__main__':
test_forward_contrib_div_sqrt_dim() test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm() test_forward_batch_norm()
test_forward_layer_norm() test_forward_layer_norm()
test_forward_one_hot()
...@@ -2158,6 +2158,24 @@ def test_placeholder(): ...@@ -2158,6 +2158,24 @@ def test_placeholder():
compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
init_global_variables=True) init_global_variables=True)
#######################################################################
# OneHot
# ----------------------
def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
inp_array1 = np.random.randint(0, 5, size=indices_shape)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
compare_tf_with_tvm(inp_array1, in1.name, out.name)
def test_forward_one_hot():
_test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
_test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
_test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32")
_test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
####################################################################### #######################################################################
# Main # Main
...@@ -2193,6 +2211,7 @@ if __name__ == '__main__': ...@@ -2193,6 +2211,7 @@ if __name__ == '__main__':
test_forward_right_shift() test_forward_right_shift()
test_forward_left_shift() test_forward_left_shift()
test_forward_truncatemod() test_forward_truncatemod()
test_forward_one_hot()
# Activations # Activations
test_forward_sigmoid() test_forward_sigmoid()
......
...@@ -296,6 +296,45 @@ def test_sequence_mask(): ...@@ -296,6 +296,45 @@ def test_sequence_mask():
_verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64') _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
_verify((5, 8, 3), 0.1, 1, 'float64', 'float32') _verify((5, 8, 3), 0.1, 1, 'float64', 'float32')
def test_one_hot():
def _get_oshape(indices_shape, depth, axis):
oshape = []
true_axis = len(indices_shape) if axis == -1 else axis
ndim = len(indices_shape) + 1
indices_index = 0
for i in range(0, ndim):
if i == true_axis:
oshape.append(depth)
else:
oshape.append(indices_shape[indices_index])
indices_index += 1
return oshape
def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
on_value_const = relay.const(on_value)
off_value_const = relay.const(off_value)
out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
checked = run_infer_type(out)
assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype)
func = relay.Function([indices], out)
indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
out_relay = intrp.evaluate(func)(indices_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
_verify((3,), 3, 1, 0, -1, "int32")
_verify((3,), 3, 1.0, 0.0, -1, "float32")
_verify((2, 2), 5, 2, -2, 0, "int32")
_verify((2, 2), 5, 0.5, -0.5, 1, "float32")
_verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
_verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
if __name__ == "__main__": if __name__ == "__main__":
test_adaptive_pool2d() test_adaptive_pool2d()
test_collapse_sum_like() test_collapse_sum_like()
...@@ -306,4 +345,5 @@ if __name__ == "__main__": ...@@ -306,4 +345,5 @@ if __name__ == "__main__":
test_shape_of() test_shape_of()
test_sequence_mask() test_sequence_mask()
test_ndarray_size() test_ndarray_size()
test_one_hot()
...@@ -1247,5 +1247,55 @@ inline Tensor ndarray_size(const Tensor& src, ...@@ -1247,5 +1247,55 @@ inline Tensor ndarray_size(const Tensor& src,
}, name, tag); }, name, tag);
} }
/*!
* \brief Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
* \param indices locations to set to on_value.
* \param on_value value that locations represented by indices take on.
* \param off_value value that other locations take on.
* \param depth depth of the one-hot dimension.
* \param axis axis to fill.
* \param dtype data type of the output tensor.
* \param name output tensor name.
* \param tag output tensor tag.
* \return one-hot tensor.
*/
inline Tensor one_hot(const Tensor& indices,
const Expr on_value,
const Expr off_value,
int depth,
int axis,
const Type& dtype,
const std::string name = "T_one_hot",
const std::string tag = kInjective) {
Array<Expr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (axis == -1) ? indices->shape.size() : axis;
for (int i = 0; i < ndim; i++) {
if (i == true_axis) {
oshape.push_back(Integer(depth));
} else {
oshape.push_back(indices->shape[indices_index++]);
}
}
Expr on_value_cast = cast(dtype, on_value);
Expr off_value_cast = cast(dtype, off_value);
return compute(oshape, [&](const Array<Var>& iter_vars) {
Array<Var> indices_indices;
for (size_t i = 0; i < iter_vars.size(); i++) {
if (static_cast<int>(i) == true_axis) {
continue;
}
indices_indices.push_back(iter_vars[i]);
}
auto idx = iter_vars[true_axis];
return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
}, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_TRANSFORM_H_ #endif // TOPI_TRANSFORM_H_
...@@ -25,3 +25,4 @@ from .batch_matmul import batch_matmul ...@@ -25,3 +25,4 @@ from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask from .sequence_mask_python import sequence_mask
from .pool_grad_python import pool_grad_nchw from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""OneHot in python"""
import numpy as np
def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""one_hot operator implemented in numpy.
Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
Parameters
----------
indices : numpy.ndarray
Locations to set to on_value.
on_value : int/float
Value to fill at indices.
off_value : int/float
Value to fill at all other positions besides indices.
depth : int
Depth of the one-hot dimension.
axis : int
Axis to fill.
dtype : str
Data type of the output tensor.
Returns
-------
ret : relay.Expr
The one-hot tensor.
"""
oshape = []
true_axis = len(indices.shape) if axis == -1 else axis
ndim = len(indices.shape) + 1
indices_index = 0
for i in range(0, ndim):
if i == true_axis:
oshape.append(depth)
else:
oshape.append(indices.shape[indices_index])
indices_index += 1
out = np.empty(oshape)
output_indices = [index for index in np.ndindex(out.shape)]
for output_index in output_indices:
indices_indices = []
for i, out_idx in enumerate(output_index):
if i == true_axis:
continue
indices_indices.append(out_idx)
index = output_index[true_axis]
if indices[tuple(indices_indices)] == index:
out[output_index] = on_value
else:
out[output_index] = off_value
return out.astype(dtype)
...@@ -518,3 +518,47 @@ def where(condition, x, y): ...@@ -518,3 +518,47 @@ def where(condition, x, y):
A Tensor selected from x or y depending on condition. A Tensor selected from x or y depending on condition.
""" """
return cpp.where(condition, x, y) return cpp.where(condition, x, y)
def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
Parameters
----------
indices : tvm.Tensor
Locations to set to on_value.
on_value : tvm.Tensor
Value to fill at indices.
off_value : tvm.Tensor
Value to fill at all other positions besides indices.
depth : int
Depth of the one-hot dimension.
axis : int
Axis to fill.
dtype : relay.DataType
Data type of the output tensor.
Returns
-------
ret : relay.Expr
The one-hot tensor.
Examples
--------
.. code-block:: python
indices = [0, 1, 2]
relay.one_hot(indices, 3) =
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
"""
return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype)
...@@ -417,6 +417,14 @@ TVM_REGISTER_GLOBAL("topi.strided_slice") ...@@ -417,6 +417,14 @@ TVM_REGISTER_GLOBAL("topi.strided_slice")
*rv = strided_slice(args[0], args[1], args[2], args[3]); *rv = strided_slice(args[0], args[1], args[2], args[3]);
}); });
TVM_REGISTER_GLOBAL("topi.one_hot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int depth = args[3];
int axis = args[4];
DataType dtype = args[5];
*rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
});
/* Ops from nn/upsampling.h */ /* Ops from nn/upsampling.h */
TVM_REGISTER_GLOBAL("topi.nn.upsampling") TVM_REGISTER_GLOBAL("topi.nn.upsampling")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -473,6 +473,31 @@ def verify_where(in_shape): ...@@ -473,6 +473,31 @@ def verify_where(in_shape):
for device in get_all_backend(): for device in get_all_backend():
check_device(device) check_device(device)
def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32")
on_value_const = tvm.const(on_value, dtype)
off_value_const = tvm.const(off_value, dtype)
one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(one_hot_result)
fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot")
indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype)
out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype)
indices_nd = tvm.nd.array(indices_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx)
fn(indices_nd, out_nd)
out_topi = out_nd.asnumpy()
tvm.testing.assert_allclose(out_topi, out_npy)
for device in get_all_backend():
check_device(device)
def test_strided_slice(): def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
...@@ -770,6 +795,13 @@ def test_where_fusion(): ...@@ -770,6 +795,13 @@ def test_where_fusion():
for backend in get_all_backend(): for backend in get_all_backend():
check_device(backend) check_device(backend)
def test_one_hot():
verify_one_hot((3,), 3, 1, 0, -1, "int32")
verify_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
verify_one_hot((2, 2), 5, 2, -2, 0, "int32")
verify_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
...@@ -793,3 +825,4 @@ if __name__ == "__main__": ...@@ -793,3 +825,4 @@ if __name__ == "__main__":
test_sequence_mask() test_sequence_mask()
test_ndarray_size() test_ndarray_size()
test_where_fusion() test_where_fusion()
test_one_hot()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment