Commit 1791b121 by Wei Chen Committed by Tianqi Chen

[SCHEDULE] Detect duplicate IterVar in reorder (#575)

parent 326edd76
......@@ -259,6 +259,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
}
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
std::unordered_set<IterVar> seen_var;
StageNode* self = operator->();
for (IterVar iv : order) {
CHECK(iv->iter_type == kDataPar ||
......@@ -266,6 +267,10 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
iv->iter_type == kThreadIndex)
<< "Cannot reorder IterVar("
<< IterVarType2String(iv->iter_type) << ")";
CHECK_EQ(seen_var.count(iv), 0)
<< "Same axis can not appear more than once " << iv;
seen_var.insert(iv);
}
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
......
......@@ -43,6 +43,13 @@ def test_reorder():
assert tuple(s[T].leaf_iter_vars) != order
s[T].reorder(*order)
assert tuple(s[T].leaf_iter_vars) == order
try:
# pass duplicate IterVar
# must raise an error
s[T].reorder(xi2, xi1, xi2)
assert False
except tvm.TVMError:
pass
def test_split():
m = tvm.var('m')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment