Commit 16b009b2 by Haichen Shen Committed by Yizhi Liu

[Relay/TOPI][OP] Add arange op in Relay and TOPI (#2621)

* Add arange op

* Update docs

* Fix bug

* add sanity check in relay and mxnet frontend mapping

* lint

* nits

* pylint

* don't allow empty output from arange

* Remove empty test for arange

* Fix bug and update doc
parent e3e8645b
...@@ -67,6 +67,7 @@ List of operators ...@@ -67,6 +67,7 @@ List of operators
topi.not_equal topi.not_equal
topi.greater_equal topi.greater_equal
topi.less_equal topi.less_equal
topi.arange
topi.image.resize topi.image.resize
...@@ -123,6 +124,7 @@ topi ...@@ -123,6 +124,7 @@ topi
.. autofunction:: topi.power .. autofunction:: topi.power
.. autofunction:: topi.greater .. autofunction:: topi.greater
.. autofunction:: topi.less .. autofunction:: topi.less
.. autofunction:: topi.arange
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -95,6 +95,7 @@ This level enables additional math and transform operators. ...@@ -95,6 +95,7 @@ This level enables additional math and transform operators.
tvm.relay.full_like tvm.relay.full_like
tvm.relay.cast tvm.relay.cast
tvm.relay.split tvm.relay.split
tvm.relay.arange
**Level 4: Broadcast and Reductions** **Level 4: Broadcast and Reductions**
...@@ -216,6 +217,7 @@ Level 3 Definitions ...@@ -216,6 +217,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full_like .. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast .. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.split .. autofunction:: tvm.relay.split
.. autofunction:: tvm.relay.arange
Level 4 Definitions Level 4 Definitions
......
...@@ -96,6 +96,25 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { ...@@ -96,6 +96,25 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
} }
}; // struct InitOpAttrs }; // struct InitOpAttrs
/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
tvm::Expr start;
tvm::Expr stop;
tvm::Expr step;
DataType dtype;
TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
.describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value.");
TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
.describe("Spacing between values.");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Target data type.");
}
}; // struct ArangeAttrs
/*! \brief Attributes used in squeeze operators */ /*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> { struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible. // use axis to make the name numpy compatible.
......
...@@ -268,6 +268,18 @@ def _mx_multibox_detection(inputs, attrs): ...@@ -268,6 +268,18 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1) return _op.vision.nms(ret[0], ret[1], **new_attrs1)
def _mx_arange(inputs, attrs):
assert len(inputs) == 0
if attrs.get_int("repeat", 1) != 1:
raise RuntimeError("arange doesn't support repeat")
new_attrs = {}
new_attrs["start"] = attrs.get_float("start", 0)
new_attrs["stop"] = attrs.get_float("stop")
new_attrs["step"] = attrs.get_float("step", 1)
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)
def _mx_roi_align(inputs, attrs): def _mx_roi_align(inputs, attrs):
new_attrs = {} new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
...@@ -362,6 +374,7 @@ _convert_map = { ...@@ -362,6 +374,7 @@ _convert_map = {
"Concat" : _mx_concat, "Concat" : _mx_concat,
"concat" : _mx_concat, "concat" : _mx_concat,
"LeakyReLU" : _mx_leaky_relu, "LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation, "SoftmaxActivation" : _mx_softmax_activation,
# vision # vision
......
...@@ -19,6 +19,7 @@ _reg.register_schedule("reshape", schedule_injective) ...@@ -19,6 +19,7 @@ _reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective) _reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective) _reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective) _reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("cast", schedule_injective) _reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
......
...@@ -166,8 +166,9 @@ def reshape_like(data, shape_like): ...@@ -166,8 +166,9 @@ def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array. """Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array. the input array into an output array with the same shape as the second input array.
.. note:: .. note::
Sizes for both array should be compatible. Sizes for both array should be compatible.
Parameters Parameters
---------- ----------
...@@ -249,10 +250,57 @@ def full_like(data, fill_value): ...@@ -249,10 +250,57 @@ def full_like(data, fill_value):
return _make.full_like(data, fill_value) return _make.full_like(data, fill_value)
def arange(start, stop=None, step=1, dtype="float32"):
"""Return evenly spaced values within a given interval.
.. note::
Similar to ``numpy.arange``, when only one argument is given, it is used
as `stop` instead of `start` while `start` takes default value 0.
Warning: Undefined behavior when dtype is incompatible with start/stop/step.
It could lead to different results compared to numpy, MXNet, pytorch, etc.
Parameters
----------
start : tvm.Expr, optional
Start of interval. The interval includes this value. The default start
value is 0.
stop : tvm.Expr
Stop of interval. The interval does not include this value.
step : tvm.Expr, optional
Spacing between values. The default step size is 1.
dtype : str, optional
The target data type.
Returns
-------
result : relay.Expr
The resulting tensor.
Examples
--------
.. code-block:: python
relay.arange(5) = [0, 1, 2, 3, 4]
relay.arange(1, 5) = [1, 2, 3, 4]
relay.arange(1, 5, 1.5) = [1, 2.5, 4]
"""
if stop is None:
stop = start
start = 0
return _make.arange(start, stop, step, dtype)
def where(condition, x, y): def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the """Selecting elements from either x or y depending on the value of the
condition. condition.
.. note::
The shape of condition, x, and y needs to be the same.
Parameters Parameters
---------- ----------
condition : relay.Expr condition : relay.Expr
...@@ -282,8 +330,6 @@ def where(condition, x, y): ...@@ -282,8 +330,6 @@ def where(condition, x, y):
condition = [1, 0] condition = [1, 0]
relay.where(conditon, x, y) = [[1, 2], [7, 8]] relay.where(conditon, x, y) = [[1, 2], [7, 8]]
Note that the shape of condition, x, and y needs to be the same.
""" """
return _make.where(condition, x, y) return _make.where(condition, x, y)
......
...@@ -880,6 +880,63 @@ and type as the input array. ...@@ -880,6 +880,63 @@ and type as the input array.
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute) .set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise); .set_attr<TOpPattern>("TOpPattern", kElemWise);
// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);
bool ArangeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 1);
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
CHECK_GT(val->value, 0)
<< "Invalid arange attributes (start, stop, step): " << param->start
<< ", " << param->stop << ", " << param->step;
}
reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
return true;
}
Array<Tensor> ArangeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
return { topi::arange(param->start, param->stop, param->step, param->dtype) };
}
Expr MakeArange(tvm::Expr start,
tvm::Expr stop,
tvm::Expr step,
DataType dtype) {
auto attrs = make_node<ArangeAttrs>();
attrs->start = std::move(start);
attrs->stop = std::move(stop);
attrs->step = std::move(step);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("arange");
return CallNode::make(op, {}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.arange")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
});
RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ArangeAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// where operator // where operator
bool WhereRel(const Array<Type>& types, bool WhereRel(const Array<Type>& types,
int num_inputs, int num_inputs,
......
...@@ -203,30 +203,51 @@ def test_forward_where(): ...@@ -203,30 +203,51 @@ def test_forward_where():
mx_cond = mx.nd.array(np_cond) mx_cond = mx.nd.array(np_cond)
mx_x = mx.nd.array(np_x) mx_x = mx.nd.array(np_x)
mx_y = mx.nd.array(np_y) mx_y = mx.nd.array(np_y)
shapes = {'cond': dshape, 'x': dshape, 'y': dshape}
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y']) mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
mod.bind(data_shapes=[('cond', dshape), ('x', dshape), ('y', dshape)], for_training=False) mod.bind(data_shapes=shapes.items(), for_training=False)
mod.init_params() mod.init_params()
args, auxs = mod.get_params() args, auxs = mod.get_params()
mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy() mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
out_shape = dshape
shape_dict = {'cond': dshape, 'x': dshape, 'y': dshape} new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs)
new_sym, params = relay.frontend.from_mxnet(mx_sym,
shape_dict,
arg_params=args,
aux_params=auxs)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
with relay.build_config(opt_level=3): for kind in ["graph", "debug"]:
graph, lib, params = relay.build(new_sym, target, params=params) intrp = relay.create_executor(kind, ctx=ctx, target=target)
m = graph_runtime.create(graph, lib, ctx) op_res = intrp.evaluate(new_sym)(np_cond, np_x, np_y)
# set inputs tvm.testing.assert_allclose(op_res.asnumpy(), mx_out)
m.set_input("cond", tvm.nd.array(np_cond))
m.set_input("x", tvm.nd.array(np_x))
m.set_input("y", tvm.nd.array(np_y)) def test_forward_arange():
m.set_input(**params) def _mx_symbol(F, start, stop, step):
m.run() if start is None and step is None:
# get outputs sym = F.arange(stop)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() elif start is None:
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) sym = F.arange(stop, step=step)
elif step is None:
sym = F.arange(start, stop)
else:
sym = F.arange(start, stop, step)
return sym
def verify(start, stop, step):
ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy()
mx_sym = _mx_symbol(mx.sym, start, stop, step)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
verify(0, 20, None)
verify(0, 20, 2)
verify(1, 20, None)
verify(1, 20, 2)
verify(1, 20, 1.5)
verify(1, 20.5, None)
verify(1, 20, 3)
verify(20, 1, -1)
verify(20, 1, -1.5)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -251,3 +272,4 @@ if __name__ == '__main__': ...@@ -251,3 +272,4 @@ if __name__ == '__main__':
test_forward_argmax() test_forward_argmax()
test_forward_argmin() test_forward_argmin()
test_forward_where() test_forward_where()
test_forward_arange()
...@@ -457,6 +457,40 @@ def test_infer_type_prelu(): ...@@ -457,6 +457,40 @@ def test_infer_type_prelu():
verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2)) verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3)) verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))
def test_arange():
def verify_arange(start, stop, step):
dtype = "float32"
if start is None and step is None:
x = relay.arange(stop)
ref_res = np.arange(stop)
elif start is None:
x = relay.arange(stop, step=step)
ref_res = np.arange(stop, step=step)
elif step is None:
x = relay.arange(start, stop)
ref_res = np.arange(start, stop)
else:
x = relay.arange(start, stop, step)
ref_res = np.arange(start, stop, step)
func = relay.Function([], x)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_arange(None, 20, None)
verify_arange(None, 20, 2)
verify_arange(1, 20, None)
verify_arange(1, 20, 2)
verify_arange(1, 20, 1.5)
verify_arange(1, 20.5, None)
verify_arange(1, 20, 3)
verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5)
if __name__ == "__main__": if __name__ == "__main__":
test_cast() test_cast()
test_zeros_ones() test_zeros_ones()
...@@ -480,3 +514,4 @@ if __name__ == "__main__": ...@@ -480,3 +514,4 @@ if __name__ == "__main__":
test_squeeze_infer_type() test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type() test_squeeze_bad_axes_infer_type()
test_split_infer_type() test_split_infer_type()
test_arange()
...@@ -868,6 +868,19 @@ inline Tensor tensordot(const Tensor& A, ...@@ -868,6 +868,19 @@ inline Tensor tensordot(const Tensor& A,
return compute(output_shape, func, name, tag); return compute(output_shape, func, name, tag);
} }
inline Tensor arange(const Expr start,
const Expr stop,
const Expr step,
Type dtype,
std::string name = "tensor",
std::string tag = kInjective) {
Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
tvm::cast(tvm::Float(32), stop - start) / step));
Array<Expr> shape;
return compute({num_elem}, [&](const Array<Var>& indices) {
return tvm::cast(dtype, start + step * indices[0]);
}, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_TRANSFORM_H_ #endif // TOPI_TRANSFORM_H_
...@@ -289,3 +289,32 @@ def tensordot(a, b, axes): ...@@ -289,3 +289,32 @@ def tensordot(a, b, axes):
if isinstance(axes[0], int): if isinstance(axes[0], int):
return cpp.tensordot(a, b, (axes[0],), (axes[1],)) return cpp.tensordot(a, b, (axes[0],), (axes[1],))
return cpp.tensordot(a, b, axes[0], axes[1]) return cpp.tensordot(a, b, axes[0], axes[1])
def arange(start, stop=None, step=1, dtype="float32"):
"""Creates a tensor with evenly spaced values within a given interval.
Parameters
----------
start : tvm.Expr, optional
Start of interval. The interval includes this value. The default start
value is 0.
stop : tvm.Expr
Stop of interval. The interval does not include this value.
step : tvm.Expr, optional
Spacing between values. The default step size is 1.
dtype : str, optional
The target data type.
Returns
-------
result : tvm.Tensor
The resulting tensor.
"""
if stop is None:
stop = start
start = 0
return cpp.arange(start, stop, step, dtype)
...@@ -290,6 +290,11 @@ TVM_REGISTER_GLOBAL("topi.where") ...@@ -290,6 +290,11 @@ TVM_REGISTER_GLOBAL("topi.where")
*rv = where(args[0], args[1], args[2]); *rv = where(args[0], args[1], args[2]);
}); });
TVM_REGISTER_GLOBAL("topi.arange")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = arange(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.gather_nd") TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = gather_nd(args[0], args[1]); *rv = gather_nd(args[0], args[1]);
......
...@@ -304,6 +304,36 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype): ...@@ -304,6 +304,36 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
for device in get_all_backend(): for device in get_all_backend():
check_device(device) check_device(device)
def verify_arange(start, stop, step):
if start is None and step is None:
A = topi.arange(stop)
a_np = np.arange(stop)
elif start is None:
A = topi.arange(stop, step=step)
a_np = np.arange(stop, step=step)
elif step is None:
A = topi.arange(start, stop)
a_np = np.arange(start, stop)
else:
A = topi.arange(start, stop, step)
a_np = np.arange(start, stop, step)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(A)
f = tvm.build(s, [A], device, name="arange")
a_nd = tvm.nd.empty(a_np.shape, dtype='float32', ctx=ctx)
f(a_nd)
tvm.testing.assert_allclose(a_nd.asnumpy(), a_np)
for device in get_all_backend():
check_device(device)
def test_strided_slice(): def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
...@@ -407,6 +437,18 @@ def test_gather_nd(): ...@@ -407,6 +437,18 @@ def test_gather_nd():
verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]], verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]],
indices_dtype) indices_dtype)
def test_arange():
verify_arange(None, 20, None)
verify_arange(None, 20, 2)
verify_arange(1, 20, None)
verify_arange(1, 20, 2)
verify_arange(1, 20, 1.5)
verify_arange(1, 20.5, None)
verify_arange(1, 20, 3)
verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5)
if __name__ == "__main__": if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_concatenate() test_concatenate()
...@@ -419,3 +461,4 @@ if __name__ == "__main__": ...@@ -419,3 +461,4 @@ if __name__ == "__main__":
test_expand_like() test_expand_like()
test_take() test_take()
test_gather_nd() test_gather_nd()
test_arange()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment