Commit e3695cad by Tianqi Chen Committed by GitHub

[BUGFIX/PASS] Fix Vectorize with If condition (#135)

parent e9debc9b
......@@ -144,7 +144,7 @@ class IRBuilder(object):
value = _make.StringImm(value)
self.emit(lambda x: _make.AttrStmt(node, attr_key, value, x))
def for_range(self, begin, end, name="i", dtype="int32"):
def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
"""Create a for iteration scope.
Parameters
......@@ -161,6 +161,9 @@ class IRBuilder(object):
dtype : str, optional
The data type of iteration variable.
for_type : str, optional
The special tag on the for loop.
Returns
-------
loop_scope : With.Scope of Var
......@@ -179,8 +182,18 @@ class IRBuilder(object):
loop_var = _api.var(name, dtype=dtype)
extent = end if begin == 0 else _pass.Simplify(end - begin)
def _exit_cb():
if for_type == "serial":
for_type_id = 0
elif for_type == "parallel":
for_type_id = 1
elif for_type == "vectorize":
for_type_id = 2
elif for_type == "unroll":
for_type_id = 3
else:
raise ValueError("Unknown for_type")
self.emit(_make.For(
loop_var, begin, extent, 0, 0, self._pop_seq()))
loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
return WithScope(loop_var, _exit_cb)
def if_scope(self, cond):
......
......@@ -252,7 +252,7 @@ class Vectorizer : public IRMutator {
}
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
if (op->else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
......
......@@ -3,22 +3,38 @@ import tvm
def test_vectorize_loop():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
j = tvm.var('j')
VECTORIZE = 2
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, 4, VECTORIZE, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as i:
with ib.for_range(0, 4, for_type="vectorize") as j:
A[j + 1] = A[i] + 1
stmt = ib.get()
assert isinstance(stmt.body, tvm.stmt.For)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
assert isinstance(stmt, tvm.stmt.For)
assert not isinstance(stmt.body, tvm.stmt.For)
print(stmt)
assert isinstance(stmt.body.index, tvm.expr.Ramp)
assert isinstance(stmt.body.value, tvm.expr.Broadcast)
def test_vectorize_with_if():
n = tvm.var('n')
x = tvm.var('x')
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 4, for_type="vectorize") as i:
with ib.if_scope(x < n):
A[i] = A[i] + 1
with ib.else_scope():
with ib.if_scope(i < n):
A[i] = 2.0
stmt = ib.get()
stmt = tvm.ir_pass.VectorizeLoop(stmt)
assert isinstance(stmt, tvm.stmt.IfThenElse)
assert isinstance(stmt.then_case.index, tvm.expr.Ramp)
assert isinstance(stmt.then_case.value, tvm.expr.Add)
assert stmt.then_case.value.dtype == "float32x4"
assert isinstance(stmt.else_case, tvm.stmt.For)
if __name__ == "__main__":
test_vectorize_with_if()
test_vectorize_loop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment