Commit dd1558af by Animesh Jain, committed by Tianqi Chen

Conditional Loop Partitioning - Extending to remove if conditions (#1797)

parent d5103bbc
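
For orientation, here is a minimal sketch of what the pass operates on, written against the same TVM 0.x Python APIs that the new tests below exercise (tvm.ir_builder, tvm.ir_pass.LoopPartition); the loop extent of 10 and the likely(i < 9) guard are illustrative values, not taken from the commit:

import tvm

# Build a loop whose body is guarded by a likely() hint.
ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 10, "i") as i:
    with ib.if_scope(ib.likely(i < 9)):
        A[i] = A[i] + 1.0

stmt = ib.get()
# Split the iteration space so the condition can be proven and dropped.
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)  # the main loop body no longer contains an IfThenElse

Before this commit, the boundary ranges produced by partitioning could still carry if conditions; the changes below substitute the loop variable directly when such a range has a single iteration and re-run the partitioner on the generated pre/post statements.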
@@ -239,11 +239,16 @@ class ThreadPartitionInserter : public IRMutator {
 // Try to do partition at the candidate IRs
 class LoopPartitioner : public IRMutator {
  public:
-  explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
-      : candidates_(candidates) {}
+  explicit LoopPartitioner(bool split_const_loop)
+      : selector(CandidateSelector(split_const_loop)) {}
+
+  Stmt VisitAndMutate(const Stmt& stmt) {
+    selector.Visit(stmt);
+    return Mutate(stmt);
+  }
 
   Stmt Mutate_(const For* op, const Stmt& stmt) {
-    if (candidates_.count(op)) {
+    if (selector.candidates.count(op)) {
       Stmt s = TryPartition(op, stmt, op->loop_var,
           op->min, op->min + op->extent - 1, op->body, false);
       if (s.defined()) return s;
@@ -266,7 +271,7 @@ class LoopPartitioner : public IRMutator {
     const IterVarNode *iv = op->node.as<IterVarNode>();
     CHECK(iv);
     Var var = iv->var;
-    if (candidates_.count(op)) {
+    if (selector.candidates.count(op)) {
       Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
       if (s.defined()) return s;
     }
@@ -295,9 +300,9 @@ class LoopPartitioner : public IRMutator {
   inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
 
   /* Candidate IRs that may be partitioned potentially */
-  std::unordered_set<const Node*> candidates_;
   std::unordered_map<const Variable*, IntSet> hint_map_;
   std::unordered_map<const Variable*, IntSet> relax_map_;
+  CandidateSelector selector;
 };
 
 Stmt LoopPartitioner::TryPartition(const Node* node,
@@ -322,7 +327,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   Expr body_begin;
   Stmt pre_stmt;
   if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
-    body_begin = true_itrv.min();
+    body_begin = ir::Simplify(true_itrv.min());
     if (!can_prove(body_begin == min)) {
       Expr cond = (body_begin - min >= 0);
       if (!can_prove(cond)) {
@@ -343,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   Expr post_doubt_begin;
   Stmt post_stmt;
   if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
-    post_doubt_begin = true_itrv.max() + 1;
+    post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
     if (!can_prove(true_itrv.max() == max)) {
       // require the extent to be non-negative
       Expr cond = (max - post_doubt_begin + 1 >= 0);
@@ -354,8 +359,17 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
       }
       // [post_doubt_begin, max]
       if (!partition_thread_scope) {
-        Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
-        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+        Stmt post_body;
+        // If the loop is going from 0 to 1, replace the loop var with min value
+        if (as_const_int(max) && as_const_int(post_doubt_begin)) {
+          if (*as_const_int(max) == *as_const_int(post_doubt_begin)) {
+            post_body = Substitute(body, {{Var{var}, post_doubt_begin}});
+            post_stmt = post_body;
+          }
+        } else {
+          post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
+          post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+        }
       }
     }
   } else {
@@ -368,8 +382,15 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
     Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
     Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
     s = MakeFor(node, post_doubt_begin - body_begin, new_body);
+
+    if (!(pre_stmt.defined() && post_stmt.defined())) s = VisitAndMutate(s);
+
     if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
-    if (post_stmt.defined()) s = Block::make(s, post_stmt);
+    if (post_stmt.defined()) {
+      if (as_const_int(max) && as_const_int(post_doubt_begin)) {
+        post_stmt = VisitAndMutate(post_stmt);
+      }
+      s = Block::make(s, post_stmt);
+    }
   } else {
     Expr cond = const_true();
     if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
@@ -402,9 +423,7 @@ class RemoveLikelyTags : public IRMutator {
 };
 
 Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
-  CandidateSelector selector(split_const_loop);
-  selector.Visit(stmt);
-  stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
+  stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt);
   stmt = RemoveLikelyTags().Mutate(stmt);
   return stmt;
 }
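
Taken together, the changes above make LoopPartitioner own its CandidateSelector (via the new VisitAndMutate entry point), let TryPartition re-select and re-partition the pre/post statements it generates, and emit a single-iteration tail range as a plain statement with the loop variable substituted instead of a one-trip For node. A sketch of the nested-likely pattern this targets, modelled on the test_oneD_pool case added below (buffer names and extents are illustrative):

import tvm

ib = tvm.ir_builder.create()
data = ib.pointer("float32", name="data")
out = ib.pointer("float32", name="out")
with ib.for_range(0, 16, "ow") as ow:
    with ib.for_range(0, 3, "kw") as kw:
        # Boundary handling expressed as nested likely() guards.
        with ib.if_scope(ib.likely(ow > 0)):
            with ib.if_scope(ib.likely(ow < 15)):
                out[ow] = tvm.max(out[ow], data[ow + kw - 1])

stmt = tvm.ir_pass.LoopPartition(ib.get(), True)
stmt = tvm.ir_pass.Simplify(stmt)
# With the recursive re-partitioning, the head (ow == 0) and tail (ow == 15)
# pieces are cleaned up as well, so no IfThenElse should remain in the IR.
print(stmt)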
@@ -177,6 +177,157 @@ def test_everything_during_deduction():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))
def test_single_likely():
    n = 60
    A = tvm.placeholder((n, ), name='A')
    B = tvm.placeholder((n, ), name='B')
    T = tvm.compute((n, ), lambda i: A[i]+B[i])
    s = tvm.create_schedule(T.op)
    x = T.op.axis[0]
    xo, xi = s[T].split(x, factor=16)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multi_likely():
    n = 94
    m = 62
    A = tvm.placeholder((n, m), name='A')
    B = tvm.placeholder((n, m), name='B')
    T = tvm.compute((n, m), lambda i, j: A[i, j]+B[i, j])
    s = tvm.create_schedule(T.op)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    x, y = T.op.axis
    xo, xi = s[T].split(x, factor=16)
    yo, yi = s[T].split(y, factor=16)
    s[T].reorder(xo, yo, xi, yi)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_oneD_pool():
    m = tvm.var('m')
    ib = tvm.ir_builder.create()
    #data = tvm.placeholder((16,), name = 'data')
    data = ib.pointer("float32", name="A")
    out = ib.pointer("float32", name="A")
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 0)):
                with ib.if_scope(ib.likely(ow < 15)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow < 1)):
                with ib.if_scope(ib.likely(kw > 0)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])
    with ib.for_range(0, 16, 'ow') as ow:
        with ib.for_range(0, 3, 'kw') as kw:
            with ib.if_scope(ib.likely(ow > 14)):
                with ib.if_scope(ib.likely(kw < 2)):
                    out[ow] = tvm.max(out[ow], data[ow + kw - 1])
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_cce_loop_1():
    ib = tvm.ir_builder.create()
    dtype = 'float16'
    n = 514
    m = 514
    _A = tvm.placeholder((n*m,), name = 'A')
    Ab = tvm.decl_buffer((n*m,), dtype, name="A")
    A = ib.buffer_ptr(Ab)
    _B = tvm.placeholder((n*m,), name = 'B')
    Bb = tvm.decl_buffer((n*m,), dtype, name="B")
    B = ib.buffer_ptr(Bb)
    #for i in 0 to n-1:
    with ib.for_range(0, 11, name="i") as i:
        with ib.for_range(0, 160, name="j") as j:
            with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
                A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_cce_loop_2():
    ib = tvm.ir_builder.create()
    len = 112
    tile = 32
    loop = (len + tile - 1) // tile
    with ib.for_range(0, loop, 'i') as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > len)):
            tail = len
            ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))
        with ib.else_scope():
            tail = head + tile
            ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_cce_loop_3():
    ib = tvm.ir_builder.create()
    loop1 = 4
    loop2 = 9998
    tile = 39991
    with ib.for_range(0, loop2, 'i') as i:
        with ib.for_range(0, loop1, 'j') as j:
            head1 = i
            head2 = j
            with ib.if_scope(ib.likely(head1*loop1 + head2 < tile)):
                ib.emit(tvm.call_extern('float16', "cce_intrisic", head1))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_conv_tiling():
    HSTR = WSTR = 1
    in_channel = 128
    kernel_height = kernel_width = 3
    out_channel = 64
    batch_size = 1
    in_height = in_width = 64
    out_height = out_width = in_height - kernel_height + 1
    data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data')
    kernel = tvm.placeholder((kernel_height, kernel_width, in_channel,
                              out_channel), name='kernel')
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')
    conv = tvm.compute((batch_size, out_channel, out_height, out_width),
                       lambda n, oc, oh, ow: tvm.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
                                                     kernel[kh, kw, ic, oc],
                                                     axis=[ic, kh, kw]),
                       name="conv2d")
    s = tvm.create_schedule(conv.op)
    n, oc, oh, ow = conv.op.axis
    oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
if __name__ == "__main__":
    test_basic()
    test_const_loop()
@@ -187,3 +338,10 @@ if __name__ == "__main__":
    test_select()
    test_thread_axis2()
    test_everything_during_deduction()
    test_single_likely()
    test_multi_likely()
    test_oneD_pool()
    test_cce_loop_1()
    test_cce_loop_2()
    test_cce_loop_3()
    test_conv_tiling()
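
Outside of unit tests the pass is not normally invoked by hand: it runs during lowering, and the split_const_loop=True mode exercised above corresponds to the partition_const_loop option of tvm.build_config (a sketch assuming the TVM 0.x build_config API; the schedule below is an illustrative example, not part of the commit):

import tvm

n = 60
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=16)

# Ask the lowering pipeline to run LoopPartition with split_const_loop=True.
with tvm.build_config(partition_const_loop=True):
    print(tvm.lower(s, [A, B], simple_mode=True))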