Commit 2bcf3f2c by Ziheng Jiang Committed by Tianqi Chen

fix Stage.fuse (#33)

parent e42cc112
......@@ -216,7 +216,7 @@ TVM_REGISTER_API(_StageFuse)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused;
args[0].operator Stage()
.split(args[1], args[2], &fused);
.fuse(args[1], args[2], &fused);
*ret = fused;
});
......
......@@ -117,6 +117,7 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
*p_target = fused;
StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
......@@ -129,7 +130,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other";
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
leaf_vars->data.begin() + pos_inner);
leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
fused.node_);
return *this;
......
......@@ -63,8 +63,23 @@ def test_tile():
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)
def test_fuse():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
s = tvm.Schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
fused = s[T].fuse(yo, xo)
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
if __name__ == "__main__":
test_schedule_create()
test_reorder()
test_tile()
test_split()
test_fuse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment