Commit 7e82eb61 by Tianqi Chen Committed by GitHub

[SCHEDULE][REFACTOR] Default Fuse to outer inner, consistent to split (#289)

* [SCHEDULE] Fix fuse node order

* Make fuse order consistent with split
parent e9744431
...@@ -123,12 +123,12 @@ class Stage : public NodeRef { ...@@ -123,12 +123,12 @@ class Stage : public NodeRef {
Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*! /*!
* \brief Fuse the inner outer domain to the target * \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused
* \param outer The outer domain to be fused. * \param outer The outer domain to be fused.
* \param inner The inner domain to be fused
* \param p_target The result target domain. * \param p_target The result target domain.
* \return reference to self. * \return reference to self.
*/ */
Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*) Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*! /*!
* \brief Reorder the iteration * \brief Reorder the iteration
* \param order The order of iteration variable. * \param order The order of iteration variable.
......
...@@ -281,7 +281,7 @@ class Stage(NodeBase): ...@@ -281,7 +281,7 @@ class Stage(NodeBase):
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor) outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner return outer, inner
def fuse(self, inner, outer): def fuse(self, outer, inner):
"""Fuse inner and outer to a single iteration variable. """Fuse inner and outer to a single iteration variable.
Parameters Parameters
...@@ -294,10 +294,10 @@ class Stage(NodeBase): ...@@ -294,10 +294,10 @@ class Stage(NodeBase):
Returns Returns
------- -------
inner : IterVar fused : IterVar
The fused variable of iteration. The fused variable of iteration.
""" """
return _api_internal._StageFuse(self, inner, outer) return _api_internal._StageFuse(self, outer, inner)
def set_scope(self, scope): def set_scope(self, scope):
"""Set the thread scope of this stage """Set the thread scope of this stage
......
...@@ -218,7 +218,7 @@ Stage& Stage::split_by_nparts( ...@@ -218,7 +218,7 @@ Stage& Stage::split_by_nparts(
return *this; return *this;
} }
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*) Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*)
StageNode* self = operator->(); StageNode* self = operator->();
CHECK(outer->iter_type == kDataPar || CHECK(outer->iter_type == kDataPar ||
outer->iter_type == kCommReduce || outer->iter_type == kCommReduce ||
...@@ -227,7 +227,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT ...@@ -227,7 +227,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
CHECK(inner->iter_type == kDataPar || CHECK(inner->iter_type == kDataPar ||
inner->iter_type == kCommReduce || inner->iter_type == kCommReduce ||
inner->iter_type == kOrdered) inner->iter_type == kOrdered)
<< "Cannot fuse " << IterVarType2String(outer->iter_type); << "Cannot fuse " << IterVarType2String(inner->iter_type);
IterVarType iter_type = outer->iter_type; IterVarType iter_type = outer->iter_type;
if (inner->iter_type > iter_type) iter_type = inner->iter_type; if (inner->iter_type > iter_type) iter_type = inner->iter_type;
...@@ -241,11 +241,14 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT ...@@ -241,11 +241,14 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
self->relations.push_back(FuseNode::make(inner, outer, fused));
all_vars->data.push_back(fused.node_);
size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner); size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer); size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
if (pos_inner + 1 == pos_outer) {
std::swap(outer, inner);
std::swap(pos_inner, pos_outer);
}
self->relations.push_back(FuseNode::make(outer, inner, fused));
all_vars->data.push_back(fused.node_);
CHECK_EQ(pos_inner, pos_outer + 1) CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other"; << "Can only fuse iterations that are consecutive between each other";
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
......
...@@ -72,7 +72,7 @@ def test_fuse(): ...@@ -72,7 +72,7 @@ def test_fuse():
s = tvm.create_schedule(T.op) s = tvm.create_schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
fused = s[T].fuse(yo, xo) fused = s[T].fuse(xo, yo)
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations) assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi) assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
......
...@@ -38,7 +38,7 @@ def _schedule_conv2d_hwcn(op, sch): ...@@ -38,7 +38,7 @@ def _schedule_conv2d_hwcn(op, sch):
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy") thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
hi, wi, fi, ni = sch[Out].op.axis hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(wi, hi) bz = sch[Out].fuse(hi, wi)
by, fi = sch[Out].split(fi, factor=block_factor) by, fi = sch[Out].split(fi, factor=block_factor)
bx, ni = sch[Out].split(ni, factor=block_factor) bx, ni = sch[Out].split(ni, factor=block_factor)
tyz, fi = sch[Out].split(fi, nparts=vthread) tyz, fi = sch[Out].split(fi, nparts=vthread)
...@@ -60,7 +60,7 @@ def _schedule_conv2d_hwcn(op, sch): ...@@ -60,7 +60,7 @@ def _schedule_conv2d_hwcn(op, sch):
ry, rx, rc = sch[BL].op.reduce_axis ry, rx, rc = sch[BL].op.reduce_axis
rco, rci = sch[BL].split(rc, factor=step) rco, rci = sch[BL].split(rc, factor=step)
sch[BL].reorder(rco, ry, rx, rci, fi, ni) sch[BL].reorder(rco, ry, rx, rci, fi, ni)
fuse_index = sch[BL].fuse(rx, ry) fuse_index = sch[BL].fuse(ry, rx)
fuse_index = sch[BL].fuse(fuse_index, rco) fuse_index = sch[BL].fuse(fuse_index, rco)
rx = fuse_index rx = fuse_index
......
...@@ -61,7 +61,7 @@ def schedule_depthwise_conv2d_map(op): ...@@ -61,7 +61,7 @@ def schedule_depthwise_conv2d_map(op):
# split and bind # split and bind
bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier) bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi) s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi)
bx = s[Output].fuse(bx, Output.op.axis[0]) bx = s[Output].fuse(Output.op.axis[0], bx)
s[Output].bind(bx, block_x) s[Output].bind(bx, block_x)
by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h) by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x) tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
...@@ -70,7 +70,7 @@ def schedule_depthwise_conv2d_map(op): ...@@ -70,7 +70,7 @@ def schedule_depthwise_conv2d_map(op):
tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y) tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread) ty, yi = s[Output].split(vyi, nparts=num_thread)
s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi) s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
by = s[Output].fuse(by2, by1) by = s[Output].fuse(by1, by2)
s[Output].bind(tvx, thread_vx) s[Output].bind(tvx, thread_vx)
s[Output].bind(tvy, thread_vy) s[Output].bind(tvy, thread_vy)
s[Output].bind(tx, thread_x) s[Output].bind(tx, thread_x)
......
...@@ -94,7 +94,7 @@ s = tvm.create_schedule(B.op) ...@@ -94,7 +94,7 @@ s = tvm.create_schedule(B.op)
# tile to four axises first: (i.outer, j.outer, i.inner, j.inner) # tile to four axises first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5) xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused) # then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(yi, xi) fused = s[B].fuse(xi, yi)
print(tvm.lower(s, [A, B], simple_mode=True)) print(tvm.lower(s, [A, B], simple_mode=True))
###################################################################### ######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment