Commit e9debc9b by ziheng Committed by Tianqi Chen

[PASS] Use likely tag & enable LoopPartition by default (#132)

* [PASS] Use likely tag & enable LoopPartition by default

* [PASS] Support thread_axis partition

* Take IfThenElse branch method

* [PASS] Insert branch at the innermost thread scope

* [PASS] Select candidates before trying to partition & add test for select

* [PASS] Clean code

* Fix

* Remove print & assert vectorize happens
parent c42e0f1e
......@@ -67,6 +67,7 @@ def lower(sch,
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
......
......@@ -9,6 +9,7 @@ from . import ir_pass as _pass
from . import collections as _collections
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
from .expr import Call as _Call
class WithScope(object):
"""Auxiliary scope with"""
......@@ -308,6 +309,19 @@ class IRBuilder(object):
"""
return BufferVar(self, buf.data, buf.dtype)
def likely(self, expr):
    """Annotate an expression with the "likely" intrinsic tag.

    Wraps *expr* in a pure-intrinsic ``Call`` named ``"likely"`` so that
    downstream passes (e.g. LoopPartition) can recognize the condition
    as the expected common case.

    Parameters
    ----------
    expr : Expr
        The expression, usually a condition expression.

    Returns
    -------
    expr : Expr
        The same expression wrapped with the likely tag.
    """
    tagged_args = [expr]
    # Preserve the original dtype on the wrapping intrinsic call.
    return _make.Call(expr.dtype, "likely", tagged_args,
                      _Call.PureIntrinsic, None, 0)
def get(self):
"""Return the builded IR.
......
......@@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map)));
auto preds = op::MakeBoundCheck(stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
for (auto& e : preds) e = likely(e);
nest.push_back(op::MakeIfNest(preds));
if (stage->store_predicate.defined()) {
nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
}
......@@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide(
auto init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &init_value_map);
init_nest.push_back(
op::MakeIfNest(
op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
for (auto& e : preds) e = likely(e);
init_nest.push_back(op::MakeIfNest(preds));
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
......
......@@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler
@tvm.register_func
def tvm_callback_cuda_compile(code):
print(code)
ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_35"])
return ptx
def test_add():
# graph
n = tvm.convert(1024)
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
bias = tvm.var("bias", dtype="float32")
......
......@@ -22,6 +22,7 @@ def test_add_pipeline():
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
......
......@@ -17,8 +17,8 @@ def test_basic():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
print(stmt)
def test_multi_loop():
ib = tvm.ir_builder.create()
......@@ -27,41 +27,40 @@ def test_multi_loop():
with ib.for_range(0, 4, "i") as i:
with ib.for_range(0, n, "j") as j:
with ib.for_range(0, m, "k") as k:
with ib.if_scope(i*m+j+k < n):
with ib.if_scope(ib.likely(i*m+j+k < n)):
ib.emit(tvm.make.Evaluate(m))
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
assert(not any(collect_visit(stmt.body.first,
lambda x: isinstance(x, tvm.stmt.IfThenElse))))
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multi_if():
i = tvm.var('i')
j = tvm.var('j')
k = tvm.var('k')
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
stmt = tvm.make.For(
i, 0, 4, 0, 0,
tvm.make.For(
j, 0, n, 0, 0,
tvm.make.For(
k, 0, m, 0, 0,
tvm.make.Block(
tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)),
tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))
))))
with ib.for_range(0, 4, 'i') as i:
with ib.for_range(0, n, 'j') as j:
with ib.for_range(0, m, 'k') as k:
with ib.if_scope(ib.likely(i*m+j+k < n)):
ib.emit(tvm.make.Evaluate(m))
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
with ib.if_scope(ib.likely(i*m+j-k < n)):
ib.emit(tvm.make.Evaluate(m))
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.first))
print(stmt)
def test_thread_axis():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
s = tvm.create_schedule(B.op)
s[B].set_scope("shared")
......@@ -72,12 +71,67 @@ def test_thread_axis():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt_ = tvm.ir_pass.LoopPartition(stmt)
assert('if' not in str(stmt_.body.body.body.first))
print(stmt_)
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
def test_vectorize():
    """Check that partitioning keeps the vectorized axis out of the branch
    condition and that vectorization actually produced Ramp nodes."""
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    bias = tvm.var("bias", dtype="float32")
    scale = tvm.var("scale", dtype="float32")
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
    # Schedule: split into blocks / threads / vector lanes of width 4.
    sch = tvm.create_schedule(C.op)
    nthread = 32
    block_x, rest = sch[C].split(C.op.axis[0], factor=nthread * 4)
    thread_x, rest = sch[C].split(rest, nparts=nthread)
    _, lanes = sch[C].split(rest, factor=4)
    sch[C].bind(block_x, tvm.thread_axis("blockIdx.x"))
    sch[C].bind(thread_x, tvm.thread_axis("threadIdx.x"))
    sch[C].vectorize(lanes)
    stmt = tvm.lower(sch, [A, B], name='ewise_add', with_api_wrapper=False)
    branch = stmt.body.body.body.body.body
    # The partitioned if-condition must not mention the vector-lane axis.
    assert lanes.var.name not in str(branch.condition)
    # The taken branch must contain at least one Ramp (i.e. was vectorized).
    ramps = collect_visit(branch.then_case,
                          lambda node: isinstance(node, tvm.expr.Ramp))
    assert any(ramps)
def test_select():
    """Likely-tagged Select nodes in the main partition should be folded
    away by LoopPartition followed by Simplify."""
    builder = tvm.ir_builder.create()
    m = tvm.var('m')
    n = tvm.var('n')
    outer_extent = (n + 3) / 4
    with builder.for_range(0, outer_extent, 'i') as i:
        with builder.for_range(0, 4, 'j') as j:
            cond = builder.likely(i * 4 + j < n)
            builder.emit(tvm.make.Evaluate(tvm.make.Select(cond, m, n)))
    stmt = tvm.ir_pass.Simplify(tvm.ir_pass.LoopPartition(builder.get()))
    # The first (main) part of the partitioned loop must be Select-free.
    selects = collect_visit(stmt.first,
                            lambda node: isinstance(node, tvm.expr.Select))
    assert not any(selects)
def test_thread_axis2():
    """Partitioning with bound thread axes: the inner for-loop extent must
    not depend on threadIdx."""
    n = tvm.convert(4096)
    m = tvm.var('m')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
    sch = tvm.create_schedule(C.op)
    nthread = 32
    block_x, rest = sch[C].split(C.op.axis[0], factor=32)
    thread_x, rest = sch[C].split(rest, nparts=nthread)
    _, rest = sch[C].split(rest, factor=m)
    sch[C].bind(block_x, tvm.thread_axis("blockIdx.x"))
    sch[C].bind(thread_x, tvm.thread_axis("threadIdx.x"))
    stmt = tvm.lower(sch, [A, B], name='ewise_add', with_api_wrapper=False)
    inner_for = stmt.body.body.body.body.body.first
    assert 'threadIdx' not in str(inner_for.extent)
if __name__ == "__main__":
    # Run each partition test exactly once.
    # Fix: test_multi_loop() was listed twice (stray line from the old
    # version of this dispatch list).
    test_basic()
    test_multi_loop()
    test_multi_if()
    test_thread_axis()
    test_vectorize()
    test_select()
    test_thread_axis2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment