Commit 4d09fc4e by Leyuan Wang Committed by Haichen Shen

[Relay][Frontend] Add reverse op to relay (#2800)

* start adding reverse

* reverse updated

* reverse uses topi::flip

* typo fixed

* comment addressed

* exp simplified
parent a2b45887
......@@ -99,6 +99,7 @@ This level enables additional math and transform operators.
tvm.relay.stack
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
**Level 4: Broadcast and Reductions**
......@@ -229,6 +230,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.stack
.. autofunction:: tvm.relay.repeat
.. autofunction:: tvm.relay.tile
.. autofunction:: tvm.relay.reverse
Level 4 Definitions
......
......@@ -146,6 +146,15 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
}
}; // struct TileAttrs
/*! \brief Attributes used in reverse operators */
struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
Integer axis;
TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis along which to reverse elements.");
}
}; // struct ReverseAttrs
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible.
......
......@@ -422,6 +422,13 @@ def _mx_tile(inputs, attrs):
return _op.tile(inputs[0], **new_attrs)
def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis")
return _op.reverse(inputs[0], **new_attrs)
def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
......@@ -612,6 +619,7 @@ _convert_map = {
"_arange" : _mx_arange,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"reverse" : _mx_reverse,
"BlockGrad" : _mx_BlockGrad,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
......
......@@ -19,6 +19,7 @@ _reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("reverse", schedule_injective)
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
......
......@@ -385,6 +385,35 @@ def tile(data, reps):
return _make.tile(data, reps)
def reverse(data, axis):
"""Reverses the order of elements along given axis while preserving array shape.
By default, repeat flattens the input array into 1-D and then repeats the elements.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis: int
The axis along which to reverse elements.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
x = [[1., 2.], [3., 4.]]
relay.reverse(x, axis=0) = [[3., 4.], [1., 2.]]
relay.reverse(x, axis=1) = [[2., 1.], [4., 3.]]
"""
return _make.reverse(data, axis)
def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
......
......@@ -1086,8 +1086,8 @@ Array<Tensor> RepeatCompute(const Attrs& attrs,
}
Expr MakeRepeat(Expr data,
int repeats,
int axis) {
int repeats,
int axis) {
auto attrs = make_node<RepeatAttrs>();
attrs->repeats = repeats;
attrs->axis = axis;
......@@ -1204,6 +1204,69 @@ RELAY_REGISTER_OP("tile")
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// reverse operator
TVM_REGISTER_NODE_TYPE(ReverseAttrs);
bool ReverseRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "reverse: expect input type to be TensorType but get "
<< types[0];
return false;
}
const auto* param = attrs.as<ReverseAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim)
<< "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
<< ", but got axis = " << axis
<< ", and data.ndim = " << ndim;
reporter->Assign(types[1], types[0]);
return true;
}
Array<Tensor> ReverseCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ReverseAttrs *param = attrs.as<ReverseAttrs>();
CHECK(param != nullptr);
return { topi::flip(inputs[0], param->axis) };
}
Expr MakeReverse(Expr data,
int axis) {
auto attrs = make_node<ReverseAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("reverse");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.reverse")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
});
RELAY_REGISTER_OP("reverse")
.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Reverse")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reverse", ReverseRel)
.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// where operator
bool WhereRel(const Array<Type>& types,
int num_inputs,
......
......@@ -491,6 +491,25 @@ def test_arange():
verify_arange(20, 1, -1.5)
def test_reverse():
def verify_reverse(dshape, axis):
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.reverse(x, axis=axis)
zz = relay.ir_pass.infer_type(z)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.flip(x_data, axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reverse((2, 3, 4), 1)
verify_reverse((4, 7), 0)
verify_reverse((2, 3, 4), -1)
if __name__ == "__main__":
test_cast()
test_zeros_ones()
......@@ -515,3 +534,4 @@ if __name__ == "__main__":
test_squeeze_bad_axes_infer_type()
test_split_infer_type()
test_arange()
test_reverse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment