Commit e9debc9b by ziheng Committed by Tianqi Chen

[PASS] Use likely tag & enable LoopPartition by default (#132)

* [PASS] Use likely tag & enable LoopPartition by default

* [PASS] Support thread_axis partition

* Take IfThenElse branch method

* [PASS] Insert branch at the innermost thread scope

* [PASS] Select candidates before trying to partition & add test for select

* [PASS] Clean code

* Fix

* Remove print & assert vectorize happens
parent c42e0f1e
...@@ -67,6 +67,7 @@ def lower(sch, ...@@ -67,6 +67,7 @@ def lower(sch,
sch = sch.normalize() sch = sch.normalize()
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.VectorizeLoop(stmt)
......
...@@ -9,6 +9,7 @@ from . import ir_pass as _pass ...@@ -9,6 +9,7 @@ from . import ir_pass as _pass
from . import collections as _collections from . import collections as _collections
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.node import NodeGeneric from ._ffi.node import NodeGeneric
from .expr import Call as _Call
class WithScope(object): class WithScope(object):
"""Auxiliary scope with""" """Auxiliary scope with"""
...@@ -308,6 +309,19 @@ class IRBuilder(object): ...@@ -308,6 +309,19 @@ class IRBuilder(object):
""" """
return BufferVar(self, buf.data, buf.dtype) return BufferVar(self, buf.data, buf.dtype)
def likely(self, expr):
"""Add likely tag for expression.
Parameters
----------
expr : Expr
The expression. Usually a condition expression.
Returns
-------
expr : Expr
The expression will likely tag.
"""
return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0)
def get(self): def get(self):
"""Return the builded IR. """Return the builded IR.
......
...@@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide( ...@@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> value_map; std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest( auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
nest.push_back(op::MakeIfNest(op::MakeBoundCheck( auto preds = op::MakeBoundCheck(stage, dom_map, false,
stage, dom_map, false, std::unordered_set<IterVar>(), value_map);
std::unordered_set<IterVar>(), value_map))); for (auto& e : preds) e = likely(e);
nest.push_back(op::MakeIfNest(preds));
if (stage->store_predicate.defined()) { if (stage->store_predicate.defined()) {
nest.emplace_back(op::MakeIfNest({stage->store_predicate})); nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
} }
...@@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide( ...@@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide(
auto init_nest = op::MakeLoopNest( auto init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true, stage, dom_map, begin_loop, true,
skip_iter, &init_value_map); skip_iter, &init_value_map);
init_nest.push_back( auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
op::MakeIfNest( for (auto& e : preds) e = likely(e);
op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map))); init_nest.push_back(op::MakeIfNest(preds));
init = Substitute(init, init_value_map); init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init); init = MergeNest(init_nest, init);
// common nest // common nest
......
...@@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler ...@@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
print(code) print(code)
ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_35"])
return ptx return ptx
def test_add(): def test_add():
# graph # graph
n = tvm.convert(1024) n = tvm.var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
bias = tvm.var("bias", dtype="float32") bias = tvm.var("bias", dtype="float32")
......
...@@ -22,6 +22,7 @@ def test_add_pipeline(): ...@@ -22,6 +22,7 @@ def test_add_pipeline():
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
Cb = tvm.decl_buffer(C.shape, C.dtype, name='C') Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
......
...@@ -17,8 +17,8 @@ def test_basic(): ...@@ -17,8 +17,8 @@ def test_basic():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt) stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first)) assert('if' not in str(stmt.body.body.body.first))
print(stmt)
def test_multi_loop(): def test_multi_loop():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
...@@ -27,41 +27,40 @@ def test_multi_loop(): ...@@ -27,41 +27,40 @@ def test_multi_loop():
with ib.for_range(0, 4, "i") as i: with ib.for_range(0, 4, "i") as i:
with ib.for_range(0, n, "j") as j: with ib.for_range(0, n, "j") as j:
with ib.for_range(0, m, "k") as k: with ib.for_range(0, m, "k") as k:
with ib.if_scope(i*m+j+k < n): with ib.if_scope(ib.likely(i*m+j+k < n)):
ib.emit(tvm.make.Evaluate(m)) ib.emit(tvm.make.Evaluate(m))
with ib.else_scope(): with ib.else_scope():
ib.emit(tvm.make.Evaluate(n)) ib.emit(tvm.make.Evaluate(n))
stmt = ib.get() stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt) stmt = tvm.ir_pass.LoopPartition(stmt)
assert(not any(collect_visit(stmt.body.first, stmt = tvm.ir_pass.Simplify(stmt)
lambda x: isinstance(x, tvm.stmt.IfThenElse)))) assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multi_if(): def test_multi_if():
i = tvm.var('i') ib = tvm.ir_builder.create()
j = tvm.var('j')
k = tvm.var('k')
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
stmt = tvm.make.For( with ib.for_range(0, 4, 'i') as i:
i, 0, 4, 0, 0, with ib.for_range(0, n, 'j') as j:
tvm.make.For( with ib.for_range(0, m, 'k') as k:
j, 0, n, 0, 0, with ib.if_scope(ib.likely(i*m+j+k < n)):
tvm.make.For( ib.emit(tvm.make.Evaluate(m))
k, 0, m, 0, 0, with ib.else_scope():
tvm.make.Block( ib.emit(tvm.make.Evaluate(n))
tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)), with ib.if_scope(ib.likely(i*m+j-k < n)):
tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)) ib.emit(tvm.make.Evaluate(m))
)))) with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt) stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.first)) assert('if' not in str(stmt.body.first))
print(stmt)
def test_thread_axis(): def test_thread_axis():
m = tvm.var('m') m = tvm.var('m')
l = tvm.var('l') l = tvm.var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B') B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
s[B].set_scope("shared") s[B].set_scope("shared")
...@@ -72,12 +71,67 @@ def test_thread_axis(): ...@@ -72,12 +71,67 @@ def test_thread_axis():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt_ = tvm.ir_pass.LoopPartition(stmt) stmt = tvm.ir_pass.LoopPartition(stmt)
assert('if' not in str(stmt_.body.body.body.first)) stmt = tvm.ir_pass.Simplify(stmt)
print(stmt_) assert('if' not in str(stmt.body.body.body.first))
def test_vectorize():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
bias = tvm.var("bias", dtype="float32")
scale = tvm.var("scale", dtype="float32")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
# schedule
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
num_thread = 32
bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=4)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x)
stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
body = stmt.body.body.body.body.body
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
def test_select():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
with ib.for_range(0, ((n+3)/4), 'i') as i:
with ib.for_range(0, 4, 'j') as j:
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(i*4+j<n), m, n)))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
def test_thread_axis2():
n = tvm.convert(4096)
m = tvm.var('m')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)
num_thread = 32
bx, x = s[C].split(C.op.axis[0], factor=32)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=m)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
for_body = stmt.body.body.body.body.body.first
assert('threadIdx' not in str(for_body.extent))
if __name__ == "__main__": if __name__ == "__main__":
test_multi_loop()
test_basic() test_basic()
test_multi_loop()
test_multi_if() test_multi_if()
test_thread_axis() test_thread_axis()
test_vectorize()
test_select()
test_thread_axis2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment