Commit 2bcf3f2c by Ziheng Jiang Committed by Tianqi Chen

fix Stage.fuse (#33)

parent e42cc112
...@@ -216,7 +216,7 @@ TVM_REGISTER_API(_StageFuse) ...@@ -216,7 +216,7 @@ TVM_REGISTER_API(_StageFuse)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused; IterVar fused;
args[0].operator Stage() args[0].operator Stage()
.split(args[1], args[2], &fused); .fuse(args[1], args[2], &fused);
*ret = fused; *ret = fused;
}); });
......
...@@ -117,6 +117,7 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor ...@@ -117,6 +117,7 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*) Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused"); IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
*p_target = fused;
StageNode* self = operator->(); StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
...@@ -129,7 +130,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT ...@@ -129,7 +130,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
CHECK_EQ(pos_inner, pos_outer + 1) CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other"; << "Can only fuse iterations that are consecutive between each other";
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
leaf_vars->data.begin() + pos_inner); leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
fused.node_); fused.node_);
return *this; return *this;
......
...@@ -63,8 +63,23 @@ def test_tile(): ...@@ -63,8 +63,23 @@ def test_tile():
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi) assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)
def test_fuse():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
s = tvm.Schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
fused = s[T].fuse(yo, xo)
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_create() test_schedule_create()
test_reorder() test_reorder()
test_tile() test_tile()
test_split() test_split()
test_fuse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment