Commit 0e99e8aa by ZihengJiang

Add tile operation

parent 0c72ca97
......@@ -100,6 +100,9 @@ class Schedule : public NodeRef {
* \return reference to self.
*/
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); // NOLINT(*)
};
/*!
......
......@@ -107,3 +107,8 @@ class Schedule(NodeBase):
The order to be ordered
"""
_function_internal._ScheduleReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor):
x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
......@@ -151,5 +151,13 @@ TVM_REGISTER_API(_ScheduleReorder)
.reorder(args.at(1));
});
TVM_REGISTER_API(_ScheduleTile)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args.at(0).operator Schedule()
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
&x_inner, &y_inner, args.at(3), args.at(4));
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
} // namespace tvm
......@@ -148,6 +148,16 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
return *this;
}
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
return *this;
}
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
......
......@@ -34,8 +34,18 @@ def test_reorder():
sch_T.reorder(*order)
assert tuple(sch_T.leaf_iter_vars) == order
def test_tile():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
sch_T = tvm.Schedule(T.op, scope="shared")
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)
if __name__ == "__main__":
test_schedule_create()
test_reorder()
test_tile()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment