Commit 891b4e06 by Pariksheet Pinjari Committed by Tianqi Chen

Flip operator (#505)

parent 31edf3f7
......@@ -156,6 +156,14 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
}
};
struct FlipParam : public dmlc::Parameter<FlipParam> {
int axis;
DMLC_DECLARE_PARAMETER(FlipParam) {
DMLC_DECLARE_FIELD(axis).set_default(0)
.describe("the axis to be reveresed.");
}
};
struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> {
TShape shape;
......
......@@ -41,6 +41,10 @@ reg.register_schedule("reshape_like", _fschedule_injective)
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)
# flip
reg.register_pattern("flip", OpPattern.INJECTIVE)
reg.register_schedule("flip", _fschedule_injective)
# reshape
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)
......
......@@ -830,5 +830,54 @@ Examples::
};
});
// Flip
DMLC_REGISTER_PARAMETER(FlipParam);
NNVM_REGISTER_OP(flip)
.describe(R"code(Reverse the elements of an array.
Examples::
x = [[ 1, 2],
[ 3, 4]]
flip(x) = [[ 3., 4.],
[ 1., 2.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
flip(x) = [[[ 5., 6.],
[ 7., 8.]],
[[ 1., 2.],
[ 3., 4.]]]
flip(x, axis=1) = [[[ 3., 4.],
[ 1., 2.]],
[[ 7., 8.],
[ 5., 6.]]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Source input")
.add_arguments(FlipParam::__FIELDS__())
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FlipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed);
return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
});
} // namespace top
} // namespace nnvm
......@@ -90,6 +90,28 @@ def test_reduce():
verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
def verify_flip(ishape, axis):
x = sym.Variable("x")
y = sym.flip(x, axis=axis) + 1
dtype = "float32"
x_np = np.random.uniform(size=ishape).astype(dtype)
res = np.flip(x_np, axis) + 1
for target, ctx in ctx_list():
# set input
graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
m = graph_runtime.create(graph, lib, ctx)
m.run(x=x_np)
out = m.get_output(0, tvm.nd.empty(res.shape))
np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5)
def test_flip():
verify_flip((3, 4, 3), 1)
verify_flip((3, 4, 3), 0)
verify_flip((3, 4, 3), 2)
verify_flip((3, 4, 3), -1)
verify_flip((3, 4, 3), -3)
verify_flip((3, 4, 3), -2)
def verify_reshape(dshape, oshape):
x = sym.Variable("x")
......@@ -347,4 +369,5 @@ if __name__ == "__main__":
test_elemwise_sum()
test_block_grad()
test_full()
test_flip()
print(nnvm.compiler.engine.dump())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment