Unverified Commit c41e0402 by ANSHUMAN TRIPATHY Committed by GitHub

Early checking added and new test cases added for schedule fuse (#5010)

* [1] New test case added for fuse

* [2] New test case added for fuse

* [3] New test case added for fuse

* [4] New test case added for fuse

* [5] Early check added
parent fd39c5c0
......@@ -263,10 +263,10 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
std::swap(outer, inner);
std::swap(pos_inner, pos_outer);
}
self->relations.push_back(FuseNode::make(outer, inner, fused));
all_vars->data.push_back(fused);
CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other";
self->relations.push_back(FuseNode::make(outer, inner, fused));
all_vars->data.push_back(fused);
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
......
......@@ -102,6 +102,49 @@ def test_fuse():
assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
def test_fuse_with_split():
m = te.size_var('m')
n = te.size_var('n')
A = te.placeholder((m, n), name='A')
T = te.compute((m, n), lambda i, j: A[i, j])
s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
fused = s[T].fuse(xi, y)
assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (xo, fused)
@pytest.mark.xfail
def test_fuse_with_out_of_order_axis():
m = te.size_var('m')
n = te.size_var('n')
A = te.placeholder((m, n), name='A')
T = te.compute((m, n), lambda i, j: A[i, j])
s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
fused = s[T].fuse(xo, y) # should throw here
@pytest.mark.xfail
def test_fuse_with_out_of_order_axis_with_reorder():
m = te.size_var('m')
n = te.size_var('n')
A = te.placeholder((m, n), name='A')
T = te.compute((m, n), lambda i, j: A[i, j])
s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].reorder(y, xo, xi)
fused = s[T].fuse(y, xo) # should be ok
s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].reorder(y, xo, xi)
fused = s[T].fuse(y, xi) # should throw here
def test_singleton():
print("test singleton")
......@@ -257,5 +300,8 @@ if __name__ == "__main__":
test_tile()
test_split()
test_fuse()
test_fuse_with_split()
test_fuse_with_out_of_order_axis()
test_fuse_with_out_of_order_axis_with_reorder()
test_vectorize()
test_vectorize_commreduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment