Commit e9debc9b by ziheng, committed by Tianqi Chen

[PASS] Use likely tag & enable LoopPartition by default (#132)

* [PASS] Use likely tag & enable LoopPartition by default

* [PASS] Support thread_axis partition

* Take IfThenElse branch method

* [PASS] Insert branch at the innermost thread scope

* [PASS] Select candidates before trying to partition & add test for select

* [PASS] Clean code

* Fix

* Remove print & assert vectorize happens
parent c42e0f1e
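In short: this commit wraps boundary conditions in a new likely(...) intrinsic, and LoopPartition uses those hints to split a loop into a branch-free main loop plus pre/post "doubt" loops. A minimal sketch of the intended usage, modeled on the tests in this commit (names and shapes are illustrative):

    import tvm
    ib = tvm.ir_builder.create()
    n = tvm.var('n')
    with ib.for_range(0, ((n+3)/4), 'i') as i:
        with ib.for_range(0, 4, 'j') as j:
            # mark the boundary check as likely true
            with ib.if_scope(ib.likely(i*4+j < n)):
                ib.emit(tvm.make.Evaluate(n))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)  # the main loop is now branch-free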
......@@ -67,6 +67,7 @@ def lower(sch,
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
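# partition loops first, using likely-tagged boundary conditions as hints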
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
......
......@@ -9,6 +9,7 @@ from . import ir_pass as _pass
from . import collections as _collections
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
from .expr import Call as _Call
class WithScope(object):
"""Auxiliary scope with"""
......@@ -308,6 +309,19 @@ class IRBuilder(object):
"""
return BufferVar(self, buf.data, buf.dtype)
def likely(self, expr):
"""Add likely tag for expression.
Parameters
----------
expr : Expr
The expression. Usually a condition expression.
Returns
-------
expr : Expr
The expression with the likely tag.
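
Examples
--------
.. code-block:: python

    # illustrative: tag a boundary condition inside an IRBuilder scope
    with ib.if_scope(ib.likely(i * 4 + j < n)):
        ...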
"""
return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0)
def get(self):
"""Return the builded IR.
......
......@@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
- nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
- stage, dom_map, false,
- std::unordered_set<IterVar>(), value_map)));
+ auto preds = op::MakeBoundCheck(stage, dom_map, false,
+ std::unordered_set<IterVar>(), value_map);
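+ // tagging each bound-check predicate as likely true lets LoopPartition
+ // eliminate it from the partitioned main loop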
+ for (auto& e : preds) e = likely(e);
+ nest.push_back(op::MakeIfNest(preds));
if (stage->store_predicate.defined()) {
nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
}
......@@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide(
auto init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &init_value_map);
- init_nest.push_back(
- op::MakeIfNest(
- op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
+ auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
+ for (auto& e : preds) e = likely(e);
+ init_nest.push_back(op::MakeIfNest(preds));
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
......
......@@ -10,6 +10,7 @@
#include <unordered_map>
#include <unordered_set>
#include "../arithmetic/int_set_internal.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
......@@ -37,12 +38,84 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
return success;
}
// Select potential candidate IRs that can be partitioned.
// Rules:
// - the loop bounds (min or extent) are not constant
// - a condition expression in the scope, tagged as likely, uses the loop variable
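// e.g. for (i, 0, n) { if (likely(i*4+j < n)) ... } qualifies: the extent n
// is not a constant and i appears inside a likely-tagged condition.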
class CandidateSelector : public IRVisitor {
public:
using VarIsUsed = bool;
CandidateSelector() {}
void Visit_(const For* op) {
if (!is_const(op->min) || !is_const(op->extent)) {
const Variable* var = op->loop_var.get();
record_.insert({var, false});
IRVisitor::Visit_(op);
if (record_.at(var)) {
candidates.insert(op);
}
record_.erase(var);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent) {
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
if ((scope.rank == 0) && !is_const(op->value)) {
record_.insert({var.get(), false});
IRVisitor::Visit_(op);
if (record_.at(var.get())) {
candidates.insert(op);
}
record_.erase(var.get());
return;
}
}
IRVisitor::Visit_(op);
}
void Visit_(const Call* op) {
if (op->is_intrinsic(Call::likely)) {
in_likely_ = true;
IRVisitor::Visit_(op);
in_likely_ = false;
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Variable* op) {
if (in_likely_ && record_.count(op)) {
record_.at(op) = true;
}
}
std::unordered_set<const Node*> candidates;
private:
bool in_likely_{false};
std::unordered_map<const Variable*, VarIsUsed> record_;
};
// Find valid partition for specific variable
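// For each likely(cond) that uses current_var_, DeduceBound computes the
// interval of current_var_ on which cond is guaranteed to hold; e.g. with
// j in [0, 3], i*4+j < n holds for every j exactly when i*4+3 < n.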
class PartitionFinder : public IRVisitor {
public:
explicit PartitionFinder(VarExpr current_var,
- const std::unordered_map<const Variable*, IntSet>& dom_map)
- : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
- for (const auto& kv : dom_map) out_vars_.insert(kv.first);
+ const std::unordered_map<const Variable*, IntSet>& hint_map,
+ const std::unordered_map<const Variable*, IntSet>& relax_map)
+ : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
+ for (const auto& kv : hint_map) {
+ out_vars_.insert(kv.first);
+ }
+ for (const auto& kv : relax_map) {
+ out_vars_.insert(kv.first);
+ }
}
void Visit_(const For* op) {
......@@ -73,10 +146,15 @@ class PartitionFinder : public IRVisitor {
}
}
- void Visit_(const IfThenElse* op) {
- if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) {
- IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_);
- partitions[op->condition.get()] = Partition{op->condition, interval};
+ void Visit_(const Call* op) {
+ if (op->is_intrinsic(Call::likely)) {
+ Expr cond = op->args[0];
+ if (ExprUseVars(cond,
+ std::unordered_set<const Variable*>({current_var_.get()}))) {
+ IntSet interval =
+ DeduceBound(current_var_, cond, hint_map_, relax_map_);
+ partitions[cond.get()] = Partition{cond, interval};
+ }
} else {
IRVisitor::Visit_(op);
}
......@@ -91,54 +169,124 @@ class PartitionFinder : public IRVisitor {
std::unordered_map<const Variable*, IntSet> relax_map_;
};
- class PartitionReplacer : public IRMutator {
+ // Eliminate the condition expressions by partitions
+ class ConditionEliminator : public IRMutator {
public:
- explicit PartitionReplacer(const std::unordered_map<const Node*, Partition>& ps)
+ explicit ConditionEliminator(const std::unordered_map<const Node*, Partition>& ps)
: ps_(ps) {}
- Expr Mutate(Expr e) override {
- if (ps_.count(e.get())) {
- return Mutate(const_true());
- }
+ using IRMutator::Mutate;
+ Expr Mutate(Expr e) final {
+ if (ps_.count(e.get())) return Mutate(const_true());
return IRMutator::Mutate(e);
}
- using IRMutator::Mutate;
private:
const std::unordered_map<const Node*, Partition>& ps_;
};
// Insert the partition branch at the innermost thread scope
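// A thread-bound extent cannot be split into separate For nests, so a single
// guard is emitted instead: if (cond) use the body with the likely conditions
// assumed true, otherwise fall back to the original body.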
class ThreadPartitionInserter : public IRMutator {
public:
explicit ThreadPartitionInserter(const std::unordered_map<const Node*, Partition>& ps,
Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent) {
innermost_thread_scope_ = true;
Stmt stmt = IRMutator::Mutate_(op, s);
// add branch code inside the innermost thread scope
if (innermost_thread_scope_) {
Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body);
Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
Expr value = this->Mutate(op->value);
stmt = AttrStmt::make(op->node, op->attr_key, value, body);
}
innermost_thread_scope_ = false;
return stmt;
} else {
return IRMutator::Mutate_(op, s);
}
}
private:
const std::unordered_map<const Node*, Partition>& ps_;
Expr cond_;
bool innermost_thread_scope_;
};
// Try to do partition at the candidate IRs
class LoopPartitioner : public IRMutator {
public:
- LoopPartitioner() {}
+ explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
+ : candidates_(candidates) {}
Stmt Mutate_(const For* op, const Stmt& stmt) {
- if (!is_const(op->min) || !is_const(op->extent)) {
- Stmt s = DoPartition(op, stmt);
+ if (candidates_.count(op)) {
+ Stmt s = TryPartition(op, stmt, op->loop_var,
+ op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
}
- dom_map_.insert({op->loop_var.get(),
+ // normal path when loop partition fails;
+ // a normal loop variable can be put into the hint map.
+ hint_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
Stmt res = IRMutator::Mutate_(op, stmt);
- dom_map_.erase(op->loop_var.get());
+ hint_map_.erase(op->loop_var.get());
return res;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
if (op->attr_key != attr::thread_extent) {
return IRMutator::Mutate_(op, stmt);
}
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (candidates_.count(op)) {
Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}
// normal path when loop partition fails.
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
Stmt res;
if (scope.rank == 1) {
// threadIdx should be put into relax map, in case of divergence.
relax_map_.insert({var.get(),
IntSet::interval(make_zero(var.type()), op->value - 1)});
res = IRMutator::Mutate_(op, stmt);
relax_map_.erase(var.get());
} else {
hint_map_.insert({var.get(),
IntSet::interval(make_zero(var.type()), op->value - 1)});
res = IRMutator::Mutate_(op, stmt);
hint_map_.erase(var.get());
}
return res;
}
private:
- Stmt DoPartition(const For* op, const Stmt& stmt);
+ Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var,
+ Expr min, Expr max, Stmt body, bool partition_thread_scope);
+ inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
- std::unordered_map<const Variable*, IntSet> dom_map_;
+ /* Candidate IRs that may potentially be partitioned */
+ std::unordered_set<const Node*> candidates_;
+ std::unordered_map<const Variable*, IntSet> hint_map_;
+ std::unordered_map<const Variable*, IntSet> relax_map_;
};
- Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
- PartitionFinder finder(op->loop_var, dom_map_);
- finder.Visit(op->body);
+ Stmt LoopPartitioner::TryPartition(const Node* node, const Stmt& stmt,
+ VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope) {
+ PartitionFinder finder(var, hint_map_, relax_map_);
+ finder.Visit(body);
const auto& partitions = finder.partitions;
if (partitions.empty()) return Stmt();
- Expr min = op->min;
- Expr max = op->min + op->extent - 1;
Array<IntSet> sets;
// merge partitions (take their intersect)
for (const auto& kv : partitions) {
......@@ -146,64 +294,92 @@ Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
}
IntSet true_itrv = Intersect(sets);
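// true_itrv is the interval of var on which every likely(cond) holds, so the
// iteration space splits into [min, body_begin) pre-doubt, [body_begin,
// post_doubt_begin) main, and [post_doubt_begin, max] post-doubt loops.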
- Stmt pre_stmt;
Expr body_begin;
+ Stmt pre_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
body_begin = true_itrv.min();
if (!can_prove(body_begin == min)) {
- if (!can_prove(body_begin - min >= 0)) {
- LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0)
+ Expr cond = (body_begin - min >= 0);
+ if (!can_prove(cond)) {
+ LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
}
// [min, body_begin)
- Stmt body = Substitute(op->body,
- {{Var{op->loop_var}, op->loop_var + min}});
- pre_stmt = For::make(op->loop_var, 0,
- body_begin - min, op->for_type, op->device_api, body);
+ if (!partition_thread_scope) {
+ Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+ pre_stmt = MakeFor(node, body_begin - min, pre_body);
+ }
}
} else {
body_begin = min;
}
- Stmt post_stmt;
Expr post_doubt_begin;
+ Stmt post_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
post_doubt_begin = true_itrv.max() + 1;
if (!can_prove(true_itrv.max() == max)) {
- if (!can_prove(max - post_doubt_begin >= 0)) {
- LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0)
+ Expr cond = (max - post_doubt_begin >= 0);
+ if (!can_prove(cond)) {
+ LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max);
}
// [post_doubt_begin, max]
- Stmt body = Substitute(op->body,
- {{Var{op->loop_var}, op->loop_var + post_doubt_begin}});
- post_stmt = For::make(op->loop_var, 0,
- max - post_doubt_begin + 1, op->for_type, op->device_api, body);
+ if (!partition_thread_scope) {
+ Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
+ post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+ }
}
} else {
post_doubt_begin = max + 1;
}
- // [body_begin, post_doubt_begin)
- Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body);
- Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}});
- Stmt simplified_stmt = For::make(op->loop_var, 0,
- post_doubt_begin - body_begin, op->for_type, op->device_api, body);
- Stmt s = simplified_stmt;
- if (pre_stmt.defined()) {
- s = Block::make(pre_stmt, s);
- }
- if (post_stmt.defined()) {
- s = Block::make(s, post_stmt);
+ Stmt s;
+ if (!partition_thread_scope) {
+ // [body_begin, post_doubt_begin)
+ Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
+ Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
+ s = MakeFor(node, post_doubt_begin - body_begin, new_body);
+ if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
+ if (post_stmt.defined()) s = Block::make(s, post_stmt);
+ } else {
+ Expr cond = const_true();
+ if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
+ if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
+ s = ThreadPartitionInserter(partitions, cond).Mutate(stmt);
}
+ s = ConvertSSA(s);
+ return s;
}
- return Simplify(ConvertSSA(s));
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
const For *for_node = static_cast<const For*>(node);
CHECK(for_node);
return For::make(for_node->loop_var, 0, extent,
for_node->for_type, for_node->device_api, body);
}
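// After partitioning, strip remaining likely(cond) wrappers back to the raw
// condition so that later passes see ordinary boolean expressions.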
class RemoveLikelyTags : public IRMutator {
public:
using IRMutator::Mutate;
Expr Mutate_(const Call *op, const Expr& e) {
if (op->is_intrinsic(Call::likely)) {
CHECK_EQ(op->args.size(), 1);
return IRMutator::Mutate(op->args[0]);
} else {
return IRMutator::Mutate_(op, e);
}
}
};
Stmt LoopPartition(Stmt stmt) {
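+ // three steps: select candidate loops, try to partition each candidate,
+ // then strip the leftover likely tags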
- stmt = LoopPartitioner().Mutate(stmt);
+ CandidateSelector selector;
+ selector.Visit(stmt);
+ stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
+ stmt = RemoveLikelyTags().Mutate(stmt);
return stmt;
}
......
......@@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler
@tvm.register_func
def tvm_callback_cuda_compile(code):
print(code)
- ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
+ ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_35"])
return ptx
def test_add():
# graph
- n = tvm.convert(1024)
+ n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
bias = tvm.var("bias", dtype="float32")
......
......@@ -22,6 +22,7 @@ def test_add_pipeline():
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
......
......@@ -17,8 +17,8 @@ def test_basic():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
+ stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
- print(stmt)
def test_multi_loop():
ib = tvm.ir_builder.create()
......@@ -27,41 +27,40 @@ def test_multi_loop():
with ib.for_range(0, 4, "i") as i:
with ib.for_range(0, n, "j") as j:
with ib.for_range(0, m, "k") as k:
- with ib.if_scope(i*m+j+k < n):
+ with ib.if_scope(ib.likely(i*m+j+k < n)):
ib.emit(tvm.make.Evaluate(m))
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
- assert(not any(collect_visit(stmt.body.first,
- lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+ stmt = tvm.ir_pass.Simplify(stmt)
+ assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multi_if():
- i = tvm.var('i')
- j = tvm.var('j')
- k = tvm.var('k')
+ ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
- stmt = tvm.make.For(
- i, 0, 4, 0, 0,
- tvm.make.For(
- j, 0, n, 0, 0,
- tvm.make.For(
- k, 0, m, 0, 0,
- tvm.make.Block(
- tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)),
- tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))
- ))))
+ with ib.for_range(0, 4, 'i') as i:
+ with ib.for_range(0, n, 'j') as j:
+ with ib.for_range(0, m, 'k') as k:
+ with ib.if_scope(ib.likely(i*m+j+k < n)):
+ ib.emit(tvm.make.Evaluate(m))
+ with ib.else_scope():
+ ib.emit(tvm.make.Evaluate(n))
+ with ib.if_scope(ib.likely(i*m+j-k < n)):
+ ib.emit(tvm.make.Evaluate(m))
+ with ib.else_scope():
+ ib.emit(tvm.make.Evaluate(n))
+ stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
+ stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.first))
- print(stmt)
def test_thread_axis():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
s = tvm.create_schedule(B.op)
s[B].set_scope("shared")
......@@ -72,12 +71,67 @@ def test_thread_axis():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
- stmt_ = tvm.ir_pass.LoopPartition(stmt)
- assert('if' not in str(stmt_.body.body.body.first))
- print(stmt_)
+ stmt = tvm.ir_pass.LoopPartition(stmt)
+ stmt = tvm.ir_pass.Simplify(stmt)
+ assert('if' not in str(stmt.body.body.body.first))
def test_vectorize():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
bias = tvm.var("bias", dtype="float32")
scale = tvm.var("scale", dtype="float32")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
# schedule
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
num_thread = 32
bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=4)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x)
stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
body = stmt.body.body.body.body.body
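# the main partitioned loop must be fully vectorized: its guard no longer
# mentions x, and Ramp nodes appear in the then-branch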
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
def test_select():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
with ib.for_range(0, ((n+3)/4), 'i') as i:
with ib.for_range(0, 4, 'j') as j:
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(i*4+j<n), m, n)))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
def test_thread_axis2():
n = tvm.convert(4096)
m = tvm.var('m')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.create_schedule(C.op)
num_thread = 32
bx, x = s[C].split(C.op.axis[0], factor=32)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=m)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
for_body = stmt.body.body.body.body.body.first
assert('threadIdx' not in str(for_body.extent))
if __name__ == "__main__":
- test_multi_loop()
test_basic()
+ test_multi_loop()
test_multi_if()
test_thread_axis()
+ test_vectorize()
+ test_select()
+ test_thread_axis2()