Commit 16b009b2 by Haichen Shen, committed by Yizhi Liu

[Relay/TOPI][OP] Add arange op in Relay and TOPI (#2621)

* Add arange op

* Update docs

* Fix bug

* add sanity check in relay and mxnet frontend mapping

* lint

* nits

* pylint

* don't allow empty output from arange

* Remove empty test for arange

* Fix bug and update doc
parent e3e8645b
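
The op follows `numpy.arange` semantics; per the "don't allow empty output from arange" bullet above, constant (start, stop, step) triples that would yield a non-positive length fail type checking instead of producing an empty tensor. A quick NumPy illustration of the intended behavior (editorial sketch, not part of the commit):

    import numpy as np

    print(np.arange(5))            # [0 1 2 3 4]  -- a single argument is `stop`; start defaults to 0
    print(np.arange(1, 5))         # [1 2 3 4]
    print(np.arange(1, 5, 1.5))    # [1.  2.5 4. ]
    print(np.arange(20, 1, -1.5))  # negative steps count down (13 elements here)
    # np.arange(5, 1) returns an empty array; relay.arange(5, 1) instead
    # fails during type inference (see ArangeRel below).
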
@@ -67,6 +67,7 @@ List of operators
   topi.not_equal
   topi.greater_equal
   topi.less_equal
   topi.arange
   topi.image.resize
@@ -123,6 +124,7 @@
.. autofunction:: topi.power
.. autofunction:: topi.greater
.. autofunction:: topi.less
.. autofunction:: topi.arange
topi.nn
~~~~~~~
......
@@ -95,6 +95,7 @@ This level enables additional math and transform operators.
   tvm.relay.full_like
   tvm.relay.cast
   tvm.relay.split
   tvm.relay.arange
**Level 4: Broadcast and Reductions**
@@ -216,6 +217,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.split
.. autofunction:: tvm.relay.arange
Level 4 Definitions
......
@@ -96,6 +96,25 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
  }
};  // struct InitOpAttrs

/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
  tvm::Expr start;
  tvm::Expr stop;
  tvm::Expr step;
  DataType dtype;

  TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
    TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
        .describe("Start of interval. The interval includes this value.");
    TVM_ATTR_FIELD(stop)
        .describe("Stop of interval. The interval does not include this value.");
    TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
        .describe("Spacing between values.");
    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
        .describe("Target data type.");
  }
};  // struct ArangeAttrs

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
  // use axis to make the name numpy compatible.
......
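
In these attrs, `start` and `step` default to float32 constants 0 and 1, and `dtype` defaults to a null value; the Python wrapper below fills in "float32". So integer arguments still yield a float tensor unless `dtype` is overridden. A small sketch, assuming the pre-0.6 `relay.ir_pass.infer_type` API:

    from tvm import relay

    f32 = relay.arange(1, 5)                 # float32 output by default
    i32 = relay.arange(1, 5, dtype="int32")  # explicit integer output
    # ceil((5 - 1) / 1) = 4 elements
    print(relay.ir_pass.infer_type(i32).checked_type)  # Tensor[(4,), int32]
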
@@ -268,6 +268,18 @@ def _mx_multibox_detection(inputs, attrs):
    return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_arange(inputs, attrs):
    assert len(inputs) == 0
    if attrs.get_int("repeat", 1) != 1:
        raise RuntimeError("arange doesn't support repeat")
    new_attrs = {}
    new_attrs["start"] = attrs.get_float("start", 0)
    new_attrs["stop"] = attrs.get_float("stop")
    new_attrs["step"] = attrs.get_float("step", 1)
    new_attrs["dtype"] = attrs.get_str("dtype", "float32")
    return _op.arange(**new_attrs)


def _mx_roi_align(inputs, attrs):
    new_attrs = {}
    new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -362,6 +374,7 @@ _convert_map = {
    "Concat"            : _mx_concat,
    "concat"            : _mx_concat,
    "LeakyReLU"         : _mx_leaky_relu,
    "_arange"           : _mx_arange,
    "SoftmaxOutput"     : _mx_softmax_output,
    "SoftmaxActivation" : _mx_softmax_activation,
    # vision
......
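
The frontend hook reads only scalar attributes and rejects MXNet's `repeat` option. How the defaults resolve, sketched with a plain dict standing in for MXNet's attribute wrapper (`mx_arange_attrs` is a hypothetical helper, not part of the commit):

    def mx_arange_attrs(attrs):
        # Mirrors _mx_arange: start defaults to 0, step to 1, dtype to "float32".
        if int(attrs.get("repeat", 1)) != 1:
            raise RuntimeError("arange doesn't support repeat")
        return {"start": float(attrs.get("start", 0)),
                "stop": float(attrs["stop"]),
                "step": float(attrs.get("step", 1)),
                "dtype": attrs.get("dtype", "float32")}

    assert mx_arange_attrs({"stop": "20", "step": "2"}) == \
           {"start": 0.0, "stop": 20.0, "step": 2.0, "dtype": "float32"}
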
@@ -19,6 +19,7 @@ _reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
......
@@ -166,8 +166,9 @@ def reshape_like(data, shape_like):
    """Reshapes the input array by the size of another array.
    For an input array with shape ``(d1, d2, ..., dk)``, the `reshape_like`
    operation reshapes the input array into an output array with the same
    shape as the second input array.

    .. note::
        Sizes for both arrays should be compatible.

    Parameters
    ----------
@@ -249,10 +250,57 @@ def full_like(data, fill_value):
    return _make.full_like(data, fill_value)


def arange(start, stop=None, step=1, dtype="float32"):
    """Return evenly spaced values within a given interval.

    .. note::
        Similar to ``numpy.arange``, when only one argument is given, it is
        used as `stop` instead of `start`, while `start` takes the default
        value 0.

        Warning: behavior is undefined when `dtype` is incompatible with
        start/stop/step; results may differ from numpy, MXNet, PyTorch, etc.

    Parameters
    ----------
    start : tvm.Expr, optional
        Start of interval. The interval includes this value. The default start
        value is 0.

    stop : tvm.Expr
        Stop of interval. The interval does not include this value.

    step : tvm.Expr, optional
        Spacing between values. The default step size is 1.

    dtype : str, optional
        The target data type.

    Returns
    -------
    result : relay.Expr
        The resulting tensor.

    Examples
    --------
    .. code-block:: python

        relay.arange(5) = [0, 1, 2, 3, 4]
        relay.arange(1, 5) = [1, 2, 3, 4]
        relay.arange(1, 5, 1.5) = [1, 2.5, 4]
    """
    if stop is None:
        stop = start
        start = 0
    return _make.arange(start, stop, step, dtype)


def where(condition, x, y):
    """Selecting elements from either x or y depending on the value of the
    condition.

    .. note::
        The shape of condition, x, and y needs to be the same.

    Parameters
    ----------
    condition : relay.Expr
@@ -282,8 +330,6 @@ def where(condition, x, y):
        condition = [1, 0]
        relay.where(condition, x, y) = [[1, 2], [7, 8]]
    """
    return _make.where(condition, x, y)
......
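
A minimal end-to-end sketch of the new Python API, mirroring the executor harness used by the tests in this commit (assumes an LLVM-enabled build):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.arange(1, 5, 1.5)    # start=1, stop=5, step=1.5
    func = relay.Function([], x)   # arange takes no tensor inputs
    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    op_res = intrp.evaluate(func)()
    np.testing.assert_allclose(op_res.asnumpy(),
                               np.array([1.0, 2.5, 4.0], dtype="float32"))
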
@@ -880,6 +880,63 @@ and type as the input array.
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);

bool ArangeRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 1);
  const ArangeAttrs* param = attrs.as<ArangeAttrs>();
  IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
      tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
  if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
    CHECK_GT(val->value, 0)
        << "Invalid arange attributes (start, stop, step): " << param->start
        << ", " << param->stop << ", " << param->step;
  }
  reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
  return true;
}

Array<Tensor> ArangeCompute(const Attrs& attrs,
                            const Array<Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
  const ArangeAttrs* param = attrs.as<ArangeAttrs>();
  return { topi::arange(param->start, param->stop, param->step, param->dtype) };
}

Expr MakeArange(tvm::Expr start,
                tvm::Expr stop,
                tvm::Expr step,
                DataType dtype) {
  auto attrs = make_node<ArangeAttrs>();
  attrs->start = std::move(start);
  attrs->stop = std::move(stop);
  attrs->step = std::move(step);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("arange");
  return CallNode::make(op, {}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.arange")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
});

RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ArangeAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// where operator
bool WhereRel(const Array<Type>& types,
              int num_inputs,
......
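
ArangeRel derives the single output dimension as ceil((stop - start) / step), computed in float32 and cast to int32, and rejects a non-positive count whenever the expression folds to a constant. The same arithmetic in plain Python, for checking by hand:

    import math

    def arange_num_elem(start, stop, step):
        # Mirrors ArangeRel: float divide, ceil, cast to int.
        n = int(math.ceil(float(stop - start) / step))
        if n <= 0:
            raise ValueError("Invalid arange attributes (start, stop, step): "
                             "{}, {}, {}".format(start, stop, step))
        return n

    assert arange_num_elem(0, 20, 2) == 10
    assert arange_num_elem(1, 20, 1.5) == 13
    assert arange_num_elem(20, 1, -1.5) == 13  # negative range / negative step
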
@@ -203,30 +203,51 @@ def test_forward_where():
    mx_cond = mx.nd.array(np_cond)
    mx_x = mx.nd.array(np_x)
    mx_y = mx.nd.array(np_y)
    shapes = {'cond': dshape, 'x': dshape, 'y': dshape}
    mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
    mod.bind(data_shapes=shapes.items(), for_training=False)
    mod.init_params()
    args, auxs = mod.get_params()
    mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()
    new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs)
    for target, ctx in ctx_list():
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, ctx=ctx, target=target)
            op_res = intrp.evaluate(new_sym)(np_cond, np_x, np_y)
            tvm.testing.assert_allclose(op_res.asnumpy(), mx_out)


def test_forward_arange():
    def _mx_symbol(F, start, stop, step):
        if start is None and step is None:
            sym = F.arange(stop)
        elif start is None:
            sym = F.arange(stop, step=step)
        elif step is None:
            sym = F.arange(start, stop)
        else:
            sym = F.arange(start, stop, step)
        return sym

    def verify(start, stop, step):
        ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy()
        mx_sym = _mx_symbol(mx.sym, start, stop, step)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)

    verify(0, 20, None)
    verify(0, 20, 2)
    verify(1, 20, None)
    verify(1, 20, 2)
    verify(1, 20, 1.5)
    verify(1, 20.5, None)
    verify(1, 20, 3)
    verify(20, 1, -1)
    verify(20, 1, -1.5)
if __name__ == '__main__':
@@ -251,3 +272,4 @@ if __name__ == '__main__':
    test_forward_argmax()
    test_forward_argmin()
    test_forward_where()
    test_forward_arange()
@@ -457,6 +457,40 @@ def test_infer_type_prelu():
    verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
    verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))


def test_arange():
    def verify_arange(start, stop, step):
        dtype = "float32"
        if start is None and step is None:
            x = relay.arange(stop)
            ref_res = np.arange(stop)
        elif start is None:
            x = relay.arange(stop, step=step)
            ref_res = np.arange(stop, step=step)
        elif step is None:
            x = relay.arange(start, stop)
            ref_res = np.arange(start, stop)
        else:
            x = relay.arange(start, stop, step)
            ref_res = np.arange(start, stop, step)

        func = relay.Function([], x)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

    verify_arange(None, 20, None)
    verify_arange(None, 20, 2)
    verify_arange(1, 20, None)
    verify_arange(1, 20, 2)
    verify_arange(1, 20, 1.5)
    verify_arange(1, 20.5, None)
    verify_arange(1, 20, 3)
    verify_arange(20, 1, -1)
    verify_arange(20, 1, -1.5)


if __name__ == "__main__":
    test_cast()
    test_zeros_ones()
@@ -480,3 +514,4 @@ if __name__ == "__main__":
    test_squeeze_infer_type()
    test_squeeze_bad_axes_infer_type()
    test_split_infer_type()
    test_arange()
@@ -868,6 +868,19 @@ inline Tensor tensordot(const Tensor& A,
  return compute(output_shape, func, name, tag);
}

/*! \brief Creates a tensor with evenly spaced values in the interval [start, stop). */
inline Tensor arange(const Expr start,
                     const Expr stop,
                     const Expr step,
                     Type dtype,
                     std::string name = "tensor",
                     std::string tag = kInjective) {
  // Output length: ceil((stop - start) / step), computed in float32.
  Expr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
      tvm::cast(tvm::Float(32), stop - start) / step));
  return compute({num_elem}, [&](const Array<Var>& indices) {
    return tvm::cast(dtype, start + step * indices[0]);
  }, name, tag);
}
} // namespace topi
#endif // TOPI_TRANSFORM_H_
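
The compute above emits element i as cast(dtype, start + step * i) over the derived length. A plain-Python reference of the whole kernel, useful for checking the test expectations below (illustrative only, not the TVM implementation):

    import math

    def arange_ref(start, stop, step=1.0):
        # Same rules as topi::arange: length ceil((stop-start)/step), out[i] = start + step*i.
        n = int(math.ceil((stop - start) / step))
        return [start + step * i for i in range(max(n, 0))]

    assert arange_ref(1, 5, 1.5) == [1.0, 2.5, 4.0]
    assert arange_ref(20, 1, -1.5)[:3] == [20.0, 18.5, 17.0]
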
@@ -289,3 +289,32 @@ def tensordot(a, b, axes):
    if isinstance(axes[0], int):
        return cpp.tensordot(a, b, (axes[0],), (axes[1],))
    return cpp.tensordot(a, b, axes[0], axes[1])


def arange(start, stop=None, step=1, dtype="float32"):
    """Creates a tensor with evenly spaced values within a given interval.

    Parameters
    ----------
    start : tvm.Expr, optional
        Start of interval. The interval includes this value. The default start
        value is 0.

    stop : tvm.Expr
        Stop of interval. The interval does not include this value.

    step : tvm.Expr, optional
        Spacing between values. The default step size is 1.

    dtype : str, optional
        The target data type.

    Returns
    -------
    result : tvm.Tensor
        The resulting tensor.
    """
    if stop is None:
        stop = start
        start = 0
    return cpp.arange(start, stop, step, dtype)
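
Low-level usage of the TOPI wrapper, condensed from verify_arange in the tests below (assumes an LLVM-enabled build):

    import numpy as np
    import tvm
    import topi

    a_np = np.arange(1, 20, 2).astype("float32")
    A = topi.arange(1, 20, 2)                  # dtype defaults to "float32"
    with tvm.target.create("llvm"):
        s = topi.generic.schedule_injective(A)
    f = tvm.build(s, [A], "llvm", name="arange")
    a_nd = tvm.nd.empty(a_np.shape, dtype="float32", ctx=tvm.cpu(0))
    f(a_nd)                                    # arange has no inputs, only the output buffer
    tvm.testing.assert_allclose(a_nd.asnumpy(), a_np)
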
@@ -290,6 +290,11 @@ TVM_REGISTER_GLOBAL("topi.where")
  *rv = where(args[0], args[1], args[2]);
  });

TVM_REGISTER_GLOBAL("topi.arange")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = arange(args[0], args[1], args[2], args[3]);
  });

TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = gather_nd(args[0], args[1]);
......
@@ -304,6 +304,36 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
    for device in get_all_backend():
        check_device(device)


def verify_arange(start, stop, step):
    if start is None and step is None:
        A = topi.arange(stop)
        a_np = np.arange(stop)
    elif start is None:
        A = topi.arange(stop, step=step)
        a_np = np.arange(stop, step=step)
    elif step is None:
        A = topi.arange(start, stop)
        a_np = np.arange(start, stop)
    else:
        A = topi.arange(start, stop, step)
        a_np = np.arange(start, stop, step)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(A)
        f = tvm.build(s, [A], device, name="arange")
        a_nd = tvm.nd.empty(a_np.shape, dtype='float32', ctx=ctx)
        f(a_nd)
        tvm.testing.assert_allclose(a_nd.asnumpy(), a_np)

    for device in get_all_backend():
        check_device(device)


def test_strided_slice():
    verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
@@ -407,6 +437,18 @@ def test_gather_nd():
    verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]],
                     indices_dtype)


def test_arange():
    verify_arange(None, 20, None)
    verify_arange(None, 20, 2)
    verify_arange(1, 20, None)
    verify_arange(1, 20, 2)
    verify_arange(1, 20, 1.5)
    verify_arange(1, 20.5, None)
    verify_arange(1, 20, 3)
    verify_arange(20, 1, -1)
    verify_arange(20, 1, -1.5)


if __name__ == "__main__":
    test_strided_slice()
    test_concatenate()
@@ -419,3 +461,4 @@ if __name__ == "__main__":
    test_expand_like()
    test_take()
    test_gather_nd()
    test_arange()