Commit e9debc9b by ziheng, committed by Tianqi Chen

[PASS] Use likely tag & enable LoopPartition by default (#132)

* [PASS] Use likely tag & enable LoopPartition by default

* [PASS] Support thread_axis partition

* Take IfThenElse branch method

* [PASS] Insert branch at the innermost thread scope

* [PASS] Select candidates before trying to partition & add test for select

* [PASS] Clean code

* Fix

* Remove print & assert vectorize happens
parent c42e0f1e
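
The idea in one example: a bounds check that holds for all but the tail iterations is tagged with likely(), and the LoopPartition pass splits the loop so the main part runs without the check. A minimal sketch using the Python API introduced in this commit (toy loop body; the printed IR shape is illustrative):

import tvm

ib = tvm.ir_builder.create()
n = tvm.var('n')
with ib.for_range(0, (n+3)/4, 'i') as i:
    with ib.for_range(0, 4, 'j') as j:
        # tail guard: true except possibly in the last outer iteration
        with ib.if_scope(ib.likely(i*4 + j < n)):
            ib.emit(tvm.make.Evaluate(i*4 + j))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)  # the main loops carry no if; only the tail iteration keeps the check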
@@ -67,6 +67,7 @@ def lower(sch,
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
+    stmt = ir_pass.LoopPartition(stmt)
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.CanonicalSimplify(stmt)
     stmt = ir_pass.VectorizeLoop(stmt)
@@ -9,6 +9,7 @@ from . import ir_pass as _pass
 from . import collections as _collections
 from ._ffi.base import string_types
 from ._ffi.node import NodeGeneric
+from .expr import Call as _Call

 class WithScope(object):
     """Auxiliary scope with"""

@@ -308,6 +309,19 @@ class IRBuilder(object):
         """
         return BufferVar(self, buf.data, buf.dtype)

+    def likely(self, expr):
+        """Add the likely tag to an expression.
+        Parameters
+        ----------
+        expr : Expr
+            The expression, usually a condition expression.
+        Returns
+        -------
+        expr : Expr
+            The expression with the likely tag.
+        """
+        return _make.Call(expr.dtype, "likely", [expr],
+                          _Call.PureIntrinsic, None, 0)

     def get(self):
         """Return the built IR.
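
A quick usage sketch of the new helper: it simply wraps the condition in a pure-intrinsic Call node named "likely", which later passes can recognize and strip (the assertions below reflect that construction; written against the API in this commit):

import tvm

ib = tvm.ir_builder.create()
n = tvm.var('n')
cond = ib.likely(n < 16)
assert isinstance(cond, tvm.expr.Call)   # a Call node wrapping the condition
assert cond.name == "likely"             # matched via is_intrinsic(Call::likely) on the C++ side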
@@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide(
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-  nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
-      stage, dom_map, false,
-      std::unordered_set<IterVar>(), value_map)));
+  auto preds = op::MakeBoundCheck(stage, dom_map, false,
+      std::unordered_set<IterVar>(), value_map);
+  for (auto& e : preds) e = likely(e);
+  nest.push_back(op::MakeIfNest(preds));
   if (stage->store_predicate.defined()) {
     nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
   }

@@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide(
   auto init_nest = op::MakeLoopNest(
       stage, dom_map, begin_loop, true,
       skip_iter, &init_value_map);
-  init_nest.push_back(
-      op::MakeIfNest(
-          op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
+  auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
+  for (auto& e : preds) e = likely(e);
+  init_nest.push_back(op::MakeIfNest(preds));
   init = Substitute(init, init_value_map);
   init = MergeNest(init_nest, init);
   // common nest
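
Seen from Python, the compute_op.cc change means the bound check emitted for a non-divisible split now carries the likely tag, which is what makes it a partition candidate later. A small sketch (the exact printed guard is illustrative):

import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)   # n may not be divisible by 4
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)  # the guard on the store now reads if (likely(...)) instead of a bare condition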
@@ -10,6 +10,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include "../arithmetic/int_set_internal.h"
+#include "../runtime/thread_storage_scope.h"

 namespace tvm {
 namespace ir {

@@ -37,12 +38,84 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
   return success;
 }

+// Select potential candidate IRs that can be partitioned.
+// Rule:
+//   - the range should not be const
+//   - there exists a condition expression in the scope that uses the var
+class CandidateSelector : public IRVisitor {
+ public:
+  using VarIsUsed = bool;
+  CandidateSelector() {}
+
+  void Visit_(const For* op) {
+    if (!is_const(op->min) || !is_const(op->extent)) {
+      const Variable* var = op->loop_var.get();
+      record_.insert({var, false});
+      IRVisitor::Visit_(op);
+      if (record_.at(var)) {
+        candidates.insert(op);
+      }
+      record_.erase(var);
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const AttrStmt* op) {
+    if (op->attr_key == attr::thread_extent) {
+      const IterVarNode* iv = op->node.as<IterVarNode>();
+      CHECK(iv);
+      Var var = iv->var;
+      runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+      if ((scope.rank == 0) && !is_const(op->value)) {
+        record_.insert({var.get(), false});
+        IRVisitor::Visit_(op);
+        if (record_.at(var.get())) {
+          candidates.insert(op);
+        }
+        record_.erase(var.get());
+        return;
+      }
+    }
+    IRVisitor::Visit_(op);
+  }
+
+  void Visit_(const Call* op) {
+    if (op->is_intrinsic(Call::likely)) {
+      in_likely_ = true;
+      IRVisitor::Visit_(op);
+      in_likely_ = false;
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const Variable* op) {
+    if (in_likely_ && record_.count(op)) {
+      record_.at(op) = true;
+    }
+  }
+
+  std::unordered_set<const Node*> candidates;
+
+ private:
+  bool in_likely_{false};  // initialized to avoid an uninitialized read
+  std::unordered_map<const Variable*, VarIsUsed> record_;
+};
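
The selection rule, from the Python side: a loop qualifies only if its range is not constant and its variable appears inside a likely() condition. In the sketch below (toy body), the outer constant-extent loop is skipped and only the inner symbolic loop is partitioned:

import tvm

ib = tvm.ir_builder.create()
n = tvm.var('n')
m = tvm.var('m')
with ib.for_range(0, 4, 'i') as i:          # constant range: never a candidate
    with ib.for_range(0, n, 'j') as j:      # symbolic range used in likely(): candidate
        with ib.if_scope(ib.likely(j < n - 1)):
            ib.emit(tvm.make.Evaluate(m))
stmt = tvm.ir_pass.LoopPartition(ib.get())
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)  # the j loop is split; the i loop is untouched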
+// Find a valid partition for the specific variable
 class PartitionFinder : public IRVisitor {
  public:
   explicit PartitionFinder(VarExpr current_var,
-      const std::unordered_map<const Variable*, IntSet>& dom_map)
-      : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
-    for (const auto& kv : dom_map) out_vars_.insert(kv.first);
+      const std::unordered_map<const Variable*, IntSet>& hint_map,
+      const std::unordered_map<const Variable*, IntSet>& relax_map)
+      : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
+    for (const auto& kv : hint_map) {
+      out_vars_.insert(kv.first);
+    }
+    for (const auto& kv : relax_map) {
+      out_vars_.insert(kv.first);
+    }
   }

   void Visit_(const For* op) {

@@ -73,10 +146,15 @@ class PartitionFinder : public IRVisitor {
     }
   }

-  void Visit_(const IfThenElse* op) {
-    if (ExprUseVars(op->condition,
-                    std::unordered_set<const Variable*>({current_var_.get()}))) {
-      IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_);
-      partitions[op->condition.get()] = Partition{op->condition, interval};
+  void Visit_(const Call* op) {
+    if (op->is_intrinsic(Call::likely)) {
+      Expr cond = op->args[0];
+      if (ExprUseVars(cond,
+                      std::unordered_set<const Variable*>({current_var_.get()}))) {
+        IntSet interval =
+            DeduceBound(current_var_, cond, hint_map_, relax_map_);
+        partitions[cond.get()] = Partition{cond, interval};
+      }
     } else {
       IRVisitor::Visit_(op);
     }
   }

@@ -91,54 +169,124 @@ class PartitionFinder : public IRVisitor {
   std::unordered_map<const Variable*, IntSet> relax_map_;
 };

-class PartitionReplacer : public IRMutator {
+// Eliminate the condition expressions covered by the partitions
+class ConditionEliminator : public IRMutator {
  public:
-  explicit PartitionReplacer(const std::unordered_map<const Node*, Partition>& ps)
+  explicit ConditionEliminator(const std::unordered_map<const Node*, Partition>& ps)
       : ps_(ps) {}

-  Expr Mutate(Expr e) override {
-    if (ps_.count(e.get())) {
-      return Mutate(const_true());
-    }
+  using IRMutator::Mutate;
+  Expr Mutate(Expr e) final {
+    if (ps_.count(e.get())) return Mutate(const_true());
     return IRMutator::Mutate(e);
   }
-  using IRMutator::Mutate;

  private:
   const std::unordered_map<const Node*, Partition>& ps_;
 };

+// Insert the partition branch at the innermost thread scope
+class ThreadPartitionInserter : public IRMutator {
+ public:
+  explicit ThreadPartitionInserter(const std::unordered_map<const Node*, Partition>& ps,
+                                   Expr cond)
+      : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->attr_key == attr::thread_extent) {
+      innermost_thread_scope_ = true;
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      // add the branch code inside the innermost thread scope
+      if (innermost_thread_scope_) {
+        Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body);
+        Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
+        Expr value = this->Mutate(op->value);
+        stmt = AttrStmt::make(op->node, op->attr_key, value, body);
+      }
+      innermost_thread_scope_ = false;
+      return stmt;
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+ private:
+  const std::unordered_map<const Node*, Partition>& ps_;
+  Expr cond_;
+  bool innermost_thread_scope_;
+};
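
For a thread-bound axis the iterations cannot be peeled into separate loops, so the pass instead duplicates the body under an IfThenElse inside the innermost thread scope: the then-branch is the body with the likely conditions eliminated, the else-branch is the original. A sketch of the GPU case that triggers this (a blockIdx axis with symbolic extent; the name 'ewise' and the printed IR shape are illustrative):

import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
C = tvm.compute((n,), lambda i: A[i] + 1, name='C')
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=32)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))    # extent ceil(n/32) is not const: a candidate
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
stmt = tvm.lower(s, [A], name='ewise', with_api_wrapper=False)
print(stmt)  # an if/else branch appears inside the thread scope instead of split loops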
+// Try to partition the candidate IRs
 class LoopPartitioner : public IRMutator {
  public:
-  LoopPartitioner() {}
+  explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
+      : candidates_(candidates) {}

   Stmt Mutate_(const For* op, const Stmt& stmt) {
-    if (!is_const(op->min) || !is_const(op->extent)) {
-      Stmt s = DoPartition(op, stmt);
+    if (candidates_.count(op)) {
+      Stmt s = TryPartition(op, stmt, op->loop_var,
+                            op->min, op->min + op->extent - 1, op->body, false);
       if (s.defined()) return s;
     }
-    dom_map_.insert({op->loop_var.get(),
-        IntSet::interval(op->min, op->min + op->extent - 1)});
+    // normal path when loop partition fails;
+    // a normal loop variable can be put into the hint map.
+    hint_map_.insert({op->loop_var.get(),
+        IntSet::interval(op->min, op->min + op->extent - 1)});
     Stmt res = IRMutator::Mutate_(op, stmt);
-    dom_map_.erase(op->loop_var.get());
+    hint_map_.erase(op->loop_var.get());
     return res;
   }

+  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
+    if (op->attr_key != attr::thread_extent) {
+      return IRMutator::Mutate_(op, stmt);
+    }
+    const IterVarNode* iv = op->node.as<IterVarNode>();
+    CHECK(iv);
+    Var var = iv->var;
+    if (candidates_.count(op)) {
+      Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
+      if (s.defined()) return s;
+    }
+    // normal path when loop partition fails.
+    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+    Stmt res;
+    if (scope.rank == 1) {
+      // threadIdx should be put into the relax map, in case of divergence.
+      relax_map_.insert({var.get(),
+          IntSet::interval(make_zero(var.type()), op->value - 1)});
+      res = IRMutator::Mutate_(op, stmt);
+      relax_map_.erase(var.get());
+    } else {
+      hint_map_.insert({var.get(),
+          IntSet::interval(make_zero(var.type()), op->value - 1)});
+      res = IRMutator::Mutate_(op, stmt);
+      hint_map_.erase(var.get());
+    }
+    return res;
+  }

  private:
-  Stmt DoPartition(const For* op, const Stmt& stmt);
+  Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var,
+                    Expr min, Expr max, Stmt body, bool partition_thread_scope);
+  inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);

-  std::unordered_map<const Variable*, IntSet> dom_map_;
+  /* Candidate IRs that may potentially be partitioned */
+  std::unordered_set<const Node*> candidates_;
+  std::unordered_map<const Variable*, IntSet> hint_map_;
+  std::unordered_map<const Variable*, IntSet> relax_map_;
 };

-Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
-  PartitionFinder finder(op->loop_var, dom_map_);
-  finder.Visit(op->body);
+Stmt LoopPartitioner::TryPartition(const Node* node, const Stmt& stmt,
+                                   VarExpr var, Expr min, Expr max,
+                                   Stmt body, bool partition_thread_scope) {
+  PartitionFinder finder(var, hint_map_, relax_map_);
+  finder.Visit(body);
   const auto& partitions = finder.partitions;
   if (partitions.empty()) return Stmt();

-  Expr min = op->min;
-  Expr max = op->min + op->extent - 1;
   Array<IntSet> sets;
   // merge partitions (take their intersect)
   for (const auto& kv : partitions) {

@@ -146,64 +294,92 @@ Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
   }
   IntSet true_itrv = Intersect(sets);

+  Stmt pre_stmt;
   Expr body_begin;
-  Stmt pre_stmt;
   if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
     body_begin = true_itrv.min();
     if (!can_prove(body_begin == min)) {
-      if (!can_prove(body_begin - min >= 0)) {
-        LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0)
+      Expr cond = (body_begin - min >= 0);
+      if (!can_prove(cond)) {
+        LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the pre doubt loop";
         body_begin = Max::make(body_begin, min);
       }
       // [min, body_begin)
-      Stmt body = Substitute(op->body,
-          {{Var{op->loop_var}, op->loop_var + min}});
-      pre_stmt = For::make(op->loop_var, 0,
-          body_begin - min, op->for_type, op->device_api, body);
+      if (!partition_thread_scope) {
+        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+        pre_stmt = MakeFor(node, body_begin - min, pre_body);
+      }
     }
   } else {
     body_begin = min;
   }

+  Stmt post_stmt;
   Expr post_doubt_begin;
-  Stmt post_stmt;
   if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
     post_doubt_begin = true_itrv.max() + 1;
     if (!can_prove(true_itrv.max() == max)) {
-      if (!can_prove(max - post_doubt_begin >= 0)) {
-        LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0)
+      Expr cond = (max - post_doubt_begin >= 0);
+      if (!can_prove(cond)) {
+        LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the post doubt loop";
         post_doubt_begin = Min::make(post_doubt_begin, max);
       }
       // [post_doubt_begin, max]
-      Stmt body = Substitute(op->body,
-          {{Var{op->loop_var}, op->loop_var + post_doubt_begin}});
-      post_stmt = For::make(op->loop_var, 0,
-          max - post_doubt_begin + 1, op->for_type, op->device_api, body);
+      if (!partition_thread_scope) {
+        Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
+        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+      }
     }
   } else {
     post_doubt_begin = max + 1;
   }

+  Stmt s;
+  if (!partition_thread_scope) {
     // [body_begin, post_doubt_begin)
-  Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body);
-  Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}});
-  Stmt simplified_stmt = For::make(op->loop_var, 0,
-      post_doubt_begin - body_begin, op->for_type, op->device_api, body);
-  Stmt s = simplified_stmt;
-  if (pre_stmt.defined()) {
-    s = Block::make(pre_stmt, s);
-  }
-  if (post_stmt.defined()) {
-    s = Block::make(s, post_stmt);
-  }
-  return Simplify(ConvertSSA(s));
+    Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
+    Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
+    s = MakeFor(node, post_doubt_begin - body_begin, new_body);
+    if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
+    if (post_stmt.defined()) s = Block::make(s, post_stmt);
+  } else {
+    Expr cond = const_true();
+    if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
+    if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
+    s = ThreadPartitionInserter(partitions, cond).Mutate(stmt);
+  }
+  s = ConvertSSA(s);
+  return s;
+}
+
+inline Stmt LoopPartitioner::MakeFor(const Node* node, Expr extent, Stmt body) {
+  const For* for_node = static_cast<const For*>(node);
+  CHECK(for_node);
+  return For::make(for_node->loop_var, 0, extent,
+                   for_node->for_type, for_node->device_api, body);
 }
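
TryPartition traced on a concrete loop: for i in [0, ((n+3)/4)*4) guarded by likely(i < n), DeduceBound yields the true interval (-inf, n-1], so body_begin stays at the loop min 0, post_doubt_begin becomes n, and the result is a main loop over [0, n) with the condition eliminated plus a post doubt loop over [n, ((n+3)/4)*4). A runnable sketch (toy body):

import tvm

ib = tvm.ir_builder.create()
n = tvm.var('n')
ext = ((n+3)/4)*4                      # extent rounded up to a multiple of 4
with ib.for_range(0, ext, 'i') as i:
    with ib.if_scope(ib.likely(i < n)):
        ib.emit(tvm.make.Evaluate(i))
stmt = tvm.ir_pass.LoopPartition(ib.get())
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)  # Block(main loop without the if, post loop with the body shifted by n)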
+class RemoveLikelyTags : public IRMutator {
+ public:
+  using IRMutator::Mutate;
+
+  Expr Mutate_(const Call* op, const Expr& e) {
+    if (op->is_intrinsic(Call::likely)) {
+      CHECK_EQ(op->args.size(), 1);
+      return IRMutator::Mutate(op->args[0]);
+    } else {
+      return IRMutator::Mutate_(op, e);
+    }
+  }
+};

 Stmt LoopPartition(Stmt stmt) {
-  stmt = LoopPartitioner().Mutate(stmt);
+  CandidateSelector selector;
+  selector.Visit(stmt);
+  stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
+  stmt = RemoveLikelyTags().Mutate(stmt);
   return stmt;
 }
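
The pass driver runs three phases in order: select candidates, try to partition each one, then strip any leftover likely tags so downstream passes never see the intrinsic. The last phase is easy to observe (sketch):

import tvm

ib = tvm.ir_builder.create()
n = tvm.var('n')
with ib.for_range(0, n, 'i') as i:
    with ib.if_scope(ib.likely(i < n - 1)):
        ib.emit(tvm.make.Evaluate(i))
before = ib.get()
assert 'likely' in str(before)                 # the tag is visible in the IR
after = tvm.ir_pass.LoopPartition(before)
assert 'likely' not in str(after)              # RemoveLikelyTags stripped it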
@@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     print(code)
-    ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_35"])
     return ptx

 def test_add():
     # graph
-    n = tvm.convert(1024)
+    n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
     bias = tvm.var("bias", dtype="float32")
@@ -22,6 +22,7 @@ def test_add_pipeline():
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
     Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
+    stmt = tvm.ir_pass.LoopPartition(stmt)
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
@@ -17,8 +17,8 @@ def test_basic():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.body.body.first))
-    print(stmt)

 def test_multi_loop():
     ib = tvm.ir_builder.create()

@@ -27,41 +27,40 @@ def test_multi_loop():
     with ib.for_range(0, 4, "i") as i:
         with ib.for_range(0, n, "j") as j:
             with ib.for_range(0, m, "k") as k:
-                with ib.if_scope(i*m+j+k < n):
+                with ib.if_scope(ib.likely(i*m+j+k < n)):
                     ib.emit(tvm.make.Evaluate(m))
                 with ib.else_scope():
                     ib.emit(tvm.make.Evaluate(n))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt)
-    assert(not any(collect_visit(stmt.body.first,
-                                 lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

 def test_multi_if():
-    i = tvm.var('i')
-    j = tvm.var('j')
-    k = tvm.var('k')
+    ib = tvm.ir_builder.create()
     m = tvm.var('m')
     n = tvm.var('n')
-    stmt = tvm.make.For(
-        i, 0, 4, 0, 0,
-        tvm.make.For(
-            j, 0, n, 0, 0,
-            tvm.make.For(
-                k, 0, m, 0, 0,
-                tvm.make.Block(
-                    tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)),
-                    tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))
-                ))))
+    with ib.for_range(0, 4, 'i') as i:
+        with ib.for_range(0, n, 'j') as j:
+            with ib.for_range(0, m, 'k') as k:
+                with ib.if_scope(ib.likely(i*m+j+k < n)):
+                    ib.emit(tvm.make.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.make.Evaluate(n))
+                with ib.if_scope(ib.likely(i*m+j-k < n)):
+                    ib.emit(tvm.make.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.make.Evaluate(n))
+    stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.first))
-    print(stmt)

 def test_thread_axis():
     m = tvm.var('m')
     l = tvm.var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
     s = tvm.create_schedule(B.op)
     s[B].set_scope("shared")

@@ -72,12 +71,67 @@ def test_thread_axis():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt_ = tvm.ir_pass.LoopPartition(stmt)
-    assert('if' not in str(stmt_.body.body.body.first))
-    print(stmt_)
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert('if' not in str(stmt.body.body.body.first))

+def test_vectorize():
+    n = tvm.var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    bias = tvm.var("bias", dtype="float32")
+    scale = tvm.var("scale", dtype="float32")
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
+    # schedule
+    s = tvm.create_schedule(C.op)
+    # create iter vars and assign them tags
+    num_thread = 32
+    bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
+    tx, x = s[C].split(x, nparts=num_thread)
+    _, x = s[C].split(x, factor=4)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[C].vectorize(x)
+    stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
+    body = stmt.body.body.body.body.body
+    assert(x.var.name not in str(body.condition))
+    assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))

+def test_select():
+    ib = tvm.ir_builder.create()
+    m = tvm.var('m')
+    n = tvm.var('n')
+    with ib.for_range(0, ((n+3)/4), 'i') as i:
+        with ib.for_range(0, 4, 'j') as j:
+            ib.emit(tvm.make.Evaluate(
+                tvm.make.Select(ib.likely(i*4+j < n), m, n)))
+    stmt = ib.get()
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))

+def test_thread_axis2():
+    n = tvm.convert(4096)
+    m = tvm.var('m')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
+    s = tvm.create_schedule(C.op)
+    num_thread = 32
+    bx, x = s[C].split(C.op.axis[0], factor=32)
+    tx, x = s[C].split(x, nparts=num_thread)
+    _, x = s[C].split(x, factor=m)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
+    for_body = stmt.body.body.body.body.body.first
+    assert('threadIdx' not in str(for_body.extent))

 if __name__ == "__main__":
-    test_multi_loop()
     test_basic()
+    test_multi_loop()
     test_multi_if()
     test_thread_axis()
+    test_vectorize()
+    test_select()
+    test_thread_axis2()