Commit f9281241 by Zhen Zhang Committed by Tianqi Chen

Check iter_type in vectorize (#1921)

parent f1a4b94b
...@@ -352,6 +352,13 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) ...@@ -352,6 +352,13 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type)
} }
Stage& Stage::vectorize(IterVar var) { // NOLINT(*) Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
CHECK(var->iter_type == kDataPar ||
var->iter_type == kOpaque ||
var->iter_type == kUnrolled ||
var->iter_type == kVectorized ||
var->iter_type == kTensorized ||
var->iter_type == kParallelized)
<< "Cannot vectorize on " << IterVarType2String(var->iter_type);
SetAttrIterType(operator->(), var, kVectorized); SetAttrIterType(operator->(), var, kVectorized);
return *this; return *this;
} }
......
from nose.tools import raises
import tvm import tvm
import pickle as pkl import pickle as pkl
...@@ -112,6 +113,13 @@ def test_vectorize(): ...@@ -112,6 +113,13 @@ def test_vectorize():
assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
@raises(Exception)
def test_vectorize_commreduce():
V = tvm.placeholder((128,), name='V')
ax = tvm.reduce_axis((0, 128), name='ax')
O = tvm.compute((1,), lambda _: tvm.sum(V[ax], axis=[ax]))
s = tvm.create_schedule(O.op)
s[O].vectorize(ax) # should throw here
def test_pragma(): def test_pragma():
m = 100 m = 100
...@@ -197,3 +205,4 @@ if __name__ == "__main__": ...@@ -197,3 +205,4 @@ if __name__ == "__main__":
test_split() test_split()
test_fuse() test_fuse()
test_vectorize() test_vectorize()
test_vectorize_commreduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment