Commit 7e82eb61 by Tianqi Chen Committed by GitHub

[SCHEDULE][REFACTOR] Default Fuse to outer inner, consistent to split (#289)

* [SCHEDULE] Fix fuse node order

* Make fuse order consistent with split
parent e9744431
......@@ -123,12 +123,12 @@ class Stage : public NodeRef {
Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
 * \brief Fuse the outer and inner domains into a single target domain
* \param inner The inner domain to be fused
* \param outer The outer domain to be fused.
* \param inner The inner domain to be fused
* \param p_target The result target domain.
* \return reference to self.
*/
Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
......
......@@ -281,7 +281,7 @@ class Stage(NodeBase):
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner
def fuse(self, inner, outer):
def fuse(self, outer, inner):
"""Fuse inner and outer to a single iteration variable.
Parameters
......@@ -294,10 +294,10 @@ class Stage(NodeBase):
Returns
-------
inner : IterVar
fused : IterVar
The fused variable of iteration.
"""
return _api_internal._StageFuse(self, inner, outer)
return _api_internal._StageFuse(self, outer, inner)
def set_scope(self, scope):
"""Set the thread scope of this stage
......
......@@ -218,7 +218,7 @@ Stage& Stage::split_by_nparts(
return *this;
}
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*)
StageNode* self = operator->();
CHECK(outer->iter_type == kDataPar ||
outer->iter_type == kCommReduce ||
......@@ -227,7 +227,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
CHECK(inner->iter_type == kDataPar ||
inner->iter_type == kCommReduce ||
inner->iter_type == kOrdered)
<< "Cannot fuse " << IterVarType2String(outer->iter_type);
<< "Cannot fuse " << IterVarType2String(inner->iter_type);
IterVarType iter_type = outer->iter_type;
if (inner->iter_type > iter_type) iter_type = inner->iter_type;
......@@ -241,11 +241,14 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
self->relations.push_back(FuseNode::make(inner, outer, fused));
all_vars->data.push_back(fused.node_);
size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
if (pos_inner + 1 == pos_outer) {
std::swap(outer, inner);
std::swap(pos_inner, pos_outer);
}
self->relations.push_back(FuseNode::make(outer, inner, fused));
all_vars->data.push_back(fused.node_);
CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other";
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
......
......@@ -72,7 +72,7 @@ def test_fuse():
s = tvm.create_schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
fused = s[T].fuse(yo, xo)
fused = s[T].fuse(xo, yo)
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
......
......@@ -38,7 +38,7 @@ def _schedule_conv2d_hwcn(op, sch):
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(wi, hi)
bz = sch[Out].fuse(hi, wi)
by, fi = sch[Out].split(fi, factor=block_factor)
bx, ni = sch[Out].split(ni, factor=block_factor)
tyz, fi = sch[Out].split(fi, nparts=vthread)
......@@ -60,7 +60,7 @@ def _schedule_conv2d_hwcn(op, sch):
ry, rx, rc = sch[BL].op.reduce_axis
rco, rci = sch[BL].split(rc, factor=step)
sch[BL].reorder(rco, ry, rx, rci, fi, ni)
fuse_index = sch[BL].fuse(rx, ry)
fuse_index = sch[BL].fuse(ry, rx)
fuse_index = sch[BL].fuse(fuse_index, rco)
rx = fuse_index
......
......@@ -61,7 +61,7 @@ def schedule_depthwise_conv2d_map(op):
# split and bind
bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi)
bx = s[Output].fuse(bx, Output.op.axis[0])
bx = s[Output].fuse(Output.op.axis[0], bx)
s[Output].bind(bx, block_x)
by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
......@@ -70,7 +70,7 @@ def schedule_depthwise_conv2d_map(op):
tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread)
s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
by = s[Output].fuse(by2, by1)
by = s[Output].fuse(by1, by2)
s[Output].bind(tvx, thread_vx)
s[Output].bind(tvy, thread_vy)
s[Output].bind(tx, thread_x)
......
......@@ -94,7 +94,7 @@ s = tvm.create_schedule(B.op)
# tile to four axises first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(yi, xi)
fused = s[B].fuse(xi, yi)
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment