Commit 54593ca1 by Tianqi Chen, committed by GitHub

[LANG/GPU] Cross Thread Reduction (#79)

* [LANG/GPU] Cross Thread Reduction.

* Fix doxygen error

* Upgrade verilog testcase to new one
parent 6d798778
......@@ -299,7 +299,7 @@ inline const char* IterVarType2String(IterVarType t) {
switch (t) {
case kDataPar: return "DataPar";
case kThreadIndex: return "ThreadIndex";
case kCommReduce: return "CommRedude";
case kCommReduce: return "CommReduce";
case kOrdered: return "Ordered";
case kOpaque: return "Opaque";
case kUnrolled: return "Unrolled";
......
......@@ -42,6 +42,21 @@ struct Reduce : public ExprNode<Reduce> {
static Expr make(std::string op, Expr src,
Array<IterVar> rdom,
Expr condition = const_true());
/*!
* \brief Get initial value for reduction.
* \param op The operator
* \param type The data type.
* \return The initial value that can be assigned to reduction.
*/
static Expr InitValue(const std::string& op, Type type);
/*!
* \brief Combine two values with given reduction.
* \param op The operator
* \param a The left operand.
* \param b The right operand.
* \return The combined reduction result.
*/
static Expr Combine(const std::string& op, Expr a, Expr b);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
......@@ -87,6 +102,10 @@ constexpr const char* thread_extent = "thread_extent";
*/
constexpr const char* virtual_thread = "virtual_thread";
/*!
* \brief Mark the scope as volatile access for a certain handle.
*/
constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark storage scope of buffers
*/
constexpr const char* storage_scope = "storage_scope";
......@@ -164,6 +183,17 @@ constexpr const char* tvm_call_packed = "tvm_call_packed";
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
* \brief See pseudo code
*
* Expr tvm_thread_allreduce(std::string op, Expr value, Expr cond,
* Var thread_idx1, thread_idx2...) {
* // constraint: the other thread indices remain the same.
* return reduce(op, value, cond,
* over [thread_idx1, thread_idx2] passed by any caller)
* }
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
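For illustration, a hedged sketch of how a lowering step can build a call to this intrinsic, mirroring MakeCrossThreadReduction later in this patch (value, cond and tx are assumed to be defined already):
// Illustrative only: reduce `value` over threadIdx.x under predicate `cond`.
Array<Expr> freduce_args;
freduce_args.push_back(StringImm::make("Add"));  // reduction operator
freduce_args.push_back(value);                   // value to be reduced
freduce_args.push_back(cond);                    // predicate
freduce_args.push_back(tx->var);                 // thread index reduced over
Expr red = Call::make(value.type(), ir::intrinsic::tvm_thread_allreduce,
                      freduce_args, Call::Intrinsic);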
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
......
......@@ -234,6 +234,13 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
/*!
* \brief Lower cross thread allreduce in the stmt.
* \param f The device function to be lowered.
* \param warp_size The size of a warp, within which no sync is needed.
* \return Transformed function.
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
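A hedged usage sketch, matching how tvm.build invokes this pass in the Python changes below (f is a device LoweredFunc; warp size 32 for CUDA, 1 when warp-synchronous execution cannot be assumed):
// Illustrative only: lower cross-thread allreduce for a CUDA-like target.
LoweredFunc lowered = LowerThreadAllreduce(f, /*warp_size=*/32);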
} // namespace ir
} // namespace tvm
......
......@@ -74,6 +74,14 @@ class Stage : public NodeRef {
*/
Stage& compute_root(); // NOLINT(*)
/*!
* \brief Rebase the parent iter var onto the rebased variable.
*
* \param parent The parent iteration domain.
* \param rebased The variable to be used in rebase.
* \return reference to self.
*/
Stage& rebase(IterVar parent, IterVar rebased);
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
* \param p_outer The result outer domain
......
......@@ -71,7 +71,6 @@ def lower(sch,
return fapi
def build(sch,
args=None,
target="llvm",
......@@ -128,6 +127,8 @@ def build(sch,
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
warp_size = 32 if target == "cuda" else 1
fsplits[i] = ir_pass.LowerThreadAllreduce(fsplits[i], warp_size)
if len(fsplits) > 1:
mhost = codegen.build(fsplits[0], target_host)
......
......@@ -112,6 +112,24 @@ class Schedule(NodeBase):
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
def rebase(self, parent, rebased):
"""Rebase parent by an existing thread axis.
Parameters
----------
parent : IterVar
The parent iter var.
rebased : IterVar
The rebased iter var.
Returns
-------
rebased : IterVar
The rebased itervar.
"""
_api_internal._StageRebase(self, parent, rebased)
return rebased
def split(self, parent, factor=None, outer=None):
"""Split the stage either by factor providing outer scope, or both
......
......@@ -219,6 +219,13 @@ TVM_REGISTER_API(_StageSetScope)
.set_scope(args[1]);
});
TVM_REGISTER_API(_StageRebase)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
.rebase(args[1], args[2]);
});
TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
......
......@@ -73,6 +73,7 @@ REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS1(NarrowChannelAccess);
} // namespace ir
} // namespace tvm
......@@ -88,14 +88,26 @@ void CodeGenC::PrintSSAAssign(
}
// Print a reference expression to a buffer.
void CodeGenC::PrintBufferRef(
std::string CodeGenC::GetBufferRef(
const Variable* buffer,
Type t, Expr index,
std::ostream& os) { // NOLINT(*)
Type t, Expr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
if (alloc_storage_scope_.count(buffer)) {
scope = alloc_storage_scope_.at(buffer);
}
bool is_vol = volatile_buf_.count(buffer);
if (t.lanes() == 1) {
if (!HandleTypeMatch(buffer, t)) {
if (!HandleTypeMatch(buffer, t) || is_vol) {
os << "((";
if (is_vol) {
os << "volatile ";
}
if (scope.length() != 0) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)" << vid << ')';
} else {
......@@ -107,17 +119,24 @@ void CodeGenC::PrintBufferRef(
} else {
// Buffer declared as vector type.
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t)) {
if (HandleTypeMatch(buffer, t) && !is_vol) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
return;
return os.str();
}
}
os << "((";
if (is_vol) {
os << "volatile ";
}
if (scope.length() != 0) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
......@@ -129,6 +148,7 @@ void CodeGenC::PrintBufferRef(
PrintExpr(index, os);
os << "))[0]";
}
return os.str();
}
......@@ -162,18 +182,17 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
<< " = " << value << ";\n";
}
void CodeGenC::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
PrintBufferRef(buffer, t, base, os);
std::string CodeGenC::GetVecLoad(const Variable* buffer,
Type t, Expr base) {
return GetBufferRef(buffer, t, base);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
std::string ref = GetBufferRef(buffer, t, base);
this->PrintIndent();
PrintBufferRef(buffer, t, base, stream);
stream << " = " << value << ";\n";
stream << ref << " = " << value << ";\n";
}
void CodeGenC::PrintThreadIndexExpr(
......@@ -483,24 +502,21 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes();
std::string svalue = GetUniqueName("_");
// delcare type.
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue;
if (op->type.lanes() == 1) {
stream << " = ";
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, stream);
stream << ";\n";
std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index);
os << ref;
} else {
Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
stream << " = ";
this->PrintVecLoad(op->buffer_var.get(), op->type, base, stream);
stream << ";\n";
std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base);
os << ref;
} else {
// Load elements seperately
stream << ";\n";
// load separately.
std::string svalue = GetUniqueName("_");
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue << ";\n";
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
std::string vid = GetVarID(op->buffer_var.get());
Type elem_type = op->type.element_of();
......@@ -518,18 +534,18 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
value_temp << ']';
PrintVecElemStore(svalue, op->type, i, value_temp.str());
}
os << svalue;
}
}
os << svalue;
}
void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(op->buffer_var.get(), t, op->index);
this->PrintIndent();
this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
stream << " = " << value << ";\n";
stream << ref << " = " << value << ";\n";
} else {
Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
......@@ -577,7 +593,13 @@ void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*
}
void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Select: not supported ";
os << "(";
PrintExpr(op->condition, os);
os << " ? ";
PrintExpr(op->true_value, os);
os << " : ";
PrintExpr(op->false_value, os);
os << ")";
}
void CodeGenC::VisitStmt_(const LetStmt* op) {
......@@ -649,6 +671,10 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
} else if (op->type_key == ir::attr::volatile_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
volatile_buf_.insert(v);
}
this->PrintStmt(op->body);
}
......
......@@ -13,6 +13,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "./codegen_source_base.h"
namespace tvm {
......@@ -132,9 +133,8 @@ class CodeGenC :
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os); // NOLINT(*)
virtual std::string GetVecLoad(const Variable* buffer,
Type t, Expr base);
// print vector store
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
......@@ -149,9 +149,8 @@ class CodeGenC :
protected:
// print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer,
Type t, Expr index,
std::ostream& os); // NOLINT(*)
std::string GetBufferRef(const Variable* buffer,
Type t, Expr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
......@@ -172,9 +171,11 @@ class CodeGenC :
private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
bool print_ssa_form_{false};
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief set of volatile buf access */
std::unordered_set<const Variable*> volatile_buf_;
};
} // namespace codegen
......
......@@ -95,12 +95,13 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
os << GetVarID(buffer) << " + ";
PrintExpr(base, os);
}
void CodeGenOpenCL::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer,
Type t, Expr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
os << ")";
return os.str();
}
void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
......@@ -121,7 +122,8 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
}
}
void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
void CodeGenOpenCL::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "__global";
} else if (scope == "shared") {
......
......@@ -24,9 +24,8 @@ class CodeGenOpenCL : public CodeGenC {
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const std::string& scope) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) final; // NOLINT(*)
std::string GetVecLoad(const Variable* buffer,
Type t, Expr base) final;
void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) final; // NOLINT(*)
......
......@@ -35,6 +35,7 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
}
std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
LOG(INFO) << "ssa get id";
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
......
......@@ -61,6 +61,32 @@ Expr Reduce::make(std::string op, Expr source,
return Expr(n);
}
Expr Reduce::InitValue(const std::string& op, Type type) {
if (op == "Add") {
return make_zero(type);
} else if (op == "Max") {
return type.min();
} else if (op == "Min") {
return type.max();
} else {
LOG(FATAL) << "Unsupported reduction " << op;
return Expr();
}
}
Expr Reduce::Combine(const std::string& op, Expr a, Expr b) {
if (op == "Add") {
return Add::make(a, b);
} else if (op == "Max") {
return Max::make(a, b);
} else if (op == "Min") {
return Min::make(a, b);
} else {
LOG(FATAL) << "Unsupported reduction " << op;
return Expr();
}
}
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);
......
......@@ -20,7 +20,6 @@ Expr Tensor::operator()(Array<Expr> indices) const {
return n;
}
Tensor TensorNode::make(Array<Expr> shape,
Type dtype,
Operation op,
......
......@@ -174,19 +174,8 @@ void MakeReduction(const ComputeOpNode* op,
}
const Reduce* reduce = op->body.as<Reduce>();
CHECK(reduce);
Expr init_value, update_value;
if (reduce->op == "Add") {
init_value = make_zero(reduce->type);
update_value = Add::make(t(args), reduce->source);
} else if (reduce->op == "Max") {
init_value = reduce->type.min();
update_value = Max::make(t(args), reduce->source);
} else if (reduce->op == "Min") {
init_value = reduce->type.max();
update_value = Min::make(t(args), reduce->source);
} else {
LOG(FATAL) << "Unsupported reduction " << reduce->op;
}
Expr init_value = Reduce::InitValue(reduce->op, reduce->type);
Expr update_value = Reduce::Combine(reduce->op, t(args), reduce->source);
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
if (!is_one(reduce->condition)) {
......@@ -194,15 +183,6 @@ void MakeReduction(const ComputeOpNode* op,
}
}
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return Provide::make(t->op, t->value_index, op->body, args);
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
......@@ -212,11 +192,107 @@ Stmt Substitute(Stmt s,
return ir::Substitute(s, temp);
}
// Cross Thread reduction marker.
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
std::unordered_set<IterVar> rebase_thread;
for (IterVarRelation rel : stage->relations) {
if (const RebaseNode* s = rel.as<RebaseNode>()) {
if (s->parent->iter_type == kCommReduce &&
s->rebased->iter_type == kThreadIndex) {
rebase_thread.insert(s->rebased);
}
}
}
if (rebase_thread.size() == 0) return false;
// Verify correctness of leaf nest.
bool reduce_start = false;
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
LOG(FATAL) << "Cannot mix cross thread reduce with normal reduce";
} else if (rebase_thread.count(iv)) {
reduce_start = true;
} else {
CHECK(!reduce_start)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
return true;
}
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
Array<Expr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
const Reduce* reduce = self->body.as<Reduce>();
CHECK(reduce);
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
auto conds = op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
Expr cond = reduce->condition;
for (Expr v : conds) {
cond = cond && v;
}
Var res_handle("reduce_temp", Handle());
Array<Expr> freduce_args;
freduce_args.push_back(StringImm::make(reduce->op));
freduce_args.push_back(reduce->source);
freduce_args.push_back(cond);
std::vector<Expr> thread_head_check;
for (IterVarRelation rel : stage->relations) {
if (const RebaseNode* s = rel.as<RebaseNode>()) {
if (s->parent->iter_type == kCommReduce &&
s->rebased->iter_type == kThreadIndex) {
freduce_args.push_back(s->rebased->var);
thread_head_check.push_back(s->rebased->var == 0);
}
}
}
Stmt reduce_body = Store::make(
res_handle, Call::make(
reduce->type,
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic),
0);
Stmt assign_body = Provide::make(
stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Allocate::make(
res_handle, reduce->type, {1}, const_true(),
Block::make(reduce_body, assign_body));
body = AttrStmt::make(
res_handle, attr::storage_scope, StringImm::make("local"), body);
body = Substitute(body, value_map);
return MergeNest(nest, body);
}
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return Provide::make(t->op, t->value_index, op->body, args);
}
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
}
Stmt init, provide;
if (this->reduce_axis.size() == 0) {
provide = MakeProvide(this, stage->op.output(0));
......@@ -227,9 +303,9 @@ Stmt ComputeOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
nest.push_back(op::MakeBoundCheck(
nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map));
std::unordered_set<IterVar>(), value_map)));
provide = Substitute(provide, value_map);
if (init.defined()) {
......@@ -266,7 +342,8 @@ Stmt ComputeOpNode::BuildProvide(
stage, dom_map, begin_loop, true,
skip_iter, &init_value_map);
init_nest.push_back(
op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map));
op::MakeIfNest(
op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
......
......@@ -160,37 +160,45 @@ void PassUpBoundCheck(const Stage& s,
}
}
std::vector<Stmt> MakeBoundCheck(
std::vector<Expr> MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map) {
Stmt no_op = Evaluate::make(0);
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(stage, dom_map, &bound_state);
// insert conditions
std::vector<Stmt> nest;
std::vector<Expr> preds;
for (IterVar iv : stage->op->root_iter_vars()) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
Range dom = dom_map.at(iv);
if (bound_state.at(iv)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent;
nest.emplace_back(IfThenElse::make(condition, no_op));
preds.emplace_back(
ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent);
}
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
nest.emplace_back(IfThenElse::make(condition, no_op));
preds.emplace_back(
ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent);
}
}
return preds;
}
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
Stmt no_op = Evaluate::make(0);
std::vector<Stmt> nest;
for (const Expr& cond : predicates) {
nest.emplace_back(IfThenElse::make(cond, no_op));
}
return nest;
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
public:
......
......@@ -43,13 +43,21 @@ MakeLoopNest(const Stage& stage,
* \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
* \param skip_iter Whether skip certain iteration.
* \param value_map The result value of each IterVar.
* \return List of predicates that we need to check.
*/
std::vector<Stmt>
std::vector<Expr>
MakeBoundCheck(const Stage& stage,
const Map<IterVar, Range>& dom_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map);
/*!
* \brief Create a nest of if statements checking the predicates.
*
* \param predicates The predicates to be checked.
* \return List of If statements that check the predicates.
*/
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
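Splitting bound checks into predicates plus a separate if-nest builder lets callers either wrap the body in ifs or AND the predicates into a reduction condition, as the cross-thread path does. A hedged sketch of the common composition, mirroring ComputeOpNode::BuildProvide (stage, dom_map and value_map assumed in scope):
// Illustrative only: turn the bound-check predicates into an if-nest.
std::vector<Expr> preds = op::MakeBoundCheck(
    stage, dom_map, false, std::unordered_set<IterVar>(), value_map);
std::vector<Stmt> checks = op::MakeIfNest(preds);  // one IfThenElse per predicate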
/*!
* \brief Replace the tensor reference in stmt by the replace map.
......
......@@ -269,7 +269,8 @@ Stmt ScanOpNode::BuildProvide(
stage, dom_map, 0, false, empty, &vmap);
nest[begin_scan].push_back(init);
nest.push_back(
op::MakeBoundCheck(stage, dom_map, false, empty, vmap));
op::MakeIfNest(
op::MakeBoundCheck(stage, dom_map, false, empty, vmap)));
return MergeNest(nest, provide);
}
} // namespace tvm
......@@ -70,6 +70,21 @@ inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
return body;
}
/*!
* \brief Combine a sequence of statements into one.
* \param seq The sequence.
* \return The combined Stmt.
*/
inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
if (seq.size() == 0) return Evaluate::make(0);
Stmt body = seq[0];
for (size_t i = 1; i < seq.size(); ++i) {
body = Block::make(body, seq[i]);
}
return body;
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
/*!
* Copyright (c) 2017 by Contributors
* Lower allreduce to device-implementable IR.
* \file lower_thread_allreduce.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
class ThreadAllreduceBuilder : public IRMutator {
public:
explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {}
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s);
thread_extents_.pop_back();
return ret;
} else if (op->type_key == attr::storage_scope) {
Stmt ret = IRMutator::Mutate_(op, s);
op = ret.as<AttrStmt>();
const Variable* v = op->node.as<Variable>();
if (alloc_remap_.count(v)) {
return op->body;
} else {
return ret;
}
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
return MakeAllreduce(op, call);
} else {
return stmt;
}
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
const Allocate* repl = it->second.as<Allocate>();
// use volatile access to shared buffer.
stmt = AttrStmt::make(
repl->buffer_var, attr::volatile_scope, 1, op->body);
stmt = Allocate::make(
repl->buffer_var, repl->type,
repl->extents, repl->condition, stmt);
stmt = AttrStmt::make(
repl->buffer_var, attr::storage_scope,
StringImm::make("shared"), stmt);
return stmt;
} else {
return stmt;
}
}
Expr Mutate_(const Load* op, const Expr& e) final {
auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) {
CHECK(is_zero(op->index));
return it->second;
} else {
return IRMutator::Mutate_(op, e);
}
}
private:
// Thread entry
struct ThreadEntry {
runtime::ThreadScope scope;
IterVar iv;
int extent;
// comparator
bool operator<(const ThreadEntry& other) const {
return scope.dim_index < other.scope.dim_index;
}
};
// make allreduce.
Stmt MakeAllreduce(const Store* op, const Call* call) {
const std::string& op_code = call->args[0].as<StringImm>()->value;
Expr value = call->args[1];
Expr cond = call->args[2];
if (!is_one(cond)) {
value = Select::make(
cond, value, Reduce::InitValue(op_code, value.type()));
}
std::unordered_set<const Variable*> reduce_index_;
for (size_t i = 3; i < call->args.size(); ++i) {
const Variable* v = call->args[i].as<Variable>();
CHECK(v);
reduce_index_.insert(v);
}
size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
for (const AttrStmt* attr : thread_extents_) {
ThreadEntry e;
IterVar iv(attr->node.node_);
e.scope = runtime::ThreadScope::make(iv->thread_tag);
e.iv = iv;
CHECK(arith::GetConstInt(attr->value, &(e.extent)))
<< "Need constant extent for thread group";
CHECK_LE(e.scope.rank, 1);
CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
if (reduce_index_.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
} else {
vpar.push_back(e);
}
}
}
CHECK_EQ(nmatch, reduce_index_.size())
<< "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
// the size of each index.
int reduce_extent, group_extent;
int threadx_extent = 1;
Expr reduce_index = FlattenThread(vred, &reduce_extent);
Expr group_index = FlattenThread(vpar, &group_extent);
if (reduce_extent == 1) {
// special case, no reduction is needed.
return Store::make(op->buffer_var, value, 0);
}
// Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) {
threadx_extent = vred[0].extent;
}
Var shared_buf("red_buf", Handle());
std::vector<Stmt> seq;
seq.emplace_back(Store::make(
shared_buf, value,
BufIndex(reduce_index, group_index, reduce_extent)));
seq.emplace_back(SyncThread());
seq.emplace_back(MakeBufAllreduce(
op_code, value.type(), shared_buf,
reduce_index, group_index, reduce_extent, threadx_extent));
CHECK(!load_remap_.count(op->buffer_var.get()));
load_remap_[op->buffer_var.get()] =
Load::make(
value.type(), shared_buf,
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent));
alloc_remap_[op->buffer_var.get()] =
Allocate::make(shared_buf, value.type(),
{Expr(group_extent), Expr(reduce_extent)},
const_true(), Evaluate::make(0));
return MergeSeq(seq);
}
// make allreduce.
Stmt MakeBufAllreduce(const std::string& op,
Type type,
Var shared_buf,
Expr reduce_index,
Expr group_index,
int reduce_extent,
int threadx_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
reduce_align = reduce_align << 1;
}
CHECK_GT(reduce_align, 1);
std::vector<Stmt> seq;
Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto freduce = [&](int offset) {
Expr b = Load::make(
type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent));
Expr a = Load::make(type, shared_buf, buf_index);
return Store::make(shared_buf, Reduce::Combine(op, a, b), buf_index);
};
// Step one, handle the boundary when the extent is not a power of two.
if (reduce_align > reduce_extent) {
// reduction with the boundary condition
reduce_align = reduce_align >> 1;
Expr cond = reduce_index < (reduce_extent - reduce_align);
seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread());
}
CHECK(threadx_extent >= 1 && warp_size_ >= 1);
// normal synchronization
while (reduce_align > threadx_extent ||
reduce_align > warp_size_) {
reduce_align = reduce_align >> 1;
Expr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread());
}
// in warp synchronization.
std::vector<Stmt> in_warp_seq;
Expr in_warp_cond = reduce_index < (reduce_align >> 1);
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
in_warp_seq.emplace_back(freduce(reduce_align));
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = MergeSeq(in_warp_seq);
seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
}
return MergeSeq(seq);
}
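To make the alignment logic concrete, consider an illustrative case (the numbers are not from the patch): reduce_extent = 64, threadx_extent = 64, warp_size = 32. reduce_align is first rounded up to 64; since that equals the extent, no boundary step is emitted. The synchronized loop then halves reduce_align to 32, emitting one guarded freduce(32) followed by a shared-memory sync. The remaining offsets 16, 8, 4, 2 and 1 are emitted inside a single if (reduce_index < 16) block with no syncs, relying on warp-synchronous execution and the volatile access marked on the shared buffer.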
// Flatten the thread index.
// Also return the total extent via out_total_extent.
Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
int* out_total_extent) {
int& total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
return make_zero(Int(32));
}
Expr ret;
for (const ThreadEntry& e : tvec) {
if (ret.defined()) {
ret = ret + e.iv->var * total_extent;
} else {
CHECK_EQ(total_extent, 1);
ret = e.iv->var;
}
total_extent *= e.extent;
}
return ret;
}
// sync thread op.
static Stmt SyncThread() {
return Evaluate::make(
Call::make(Int(32), intrinsic::tvm_storage_sync,
{StringImm::make("shared")},
Call::Intrinsic));
}
// The local buffer index.
static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
if (!is_zero(group_index)) {
return ir::Simplify(group_index * reduce_extent + reduce_index);
} else {
return reduce_index;
}
}
// The warp size of the device.
int warp_size_{1};
// surrounding scope of thread extent.
std::vector<const AttrStmt*> thread_extents_;
// The load remap
std::unordered_map<const Variable *, Expr> load_remap_;
// Allocate remap
std::unordered_map<const Variable *, Stmt> alloc_remap_;
};
LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
......@@ -112,18 +112,10 @@ class ThreadAxisConfig {
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
}
work_dim_ = 3;
work_dim_ = 1;
for (int i = 0; i < 3; ++i) {
if (!filled[i]) {
for (int j = i; j < 3; ++j) {
CHECK(!filled[j] && !filled[j + 3])
<< "Invalid thread group configuration";
}
work_dim_ = i;
break;
} else {
CHECK(filled[i])
<< "Must have both threadIdx and blockIdx";
if (filled[i] || filled[i + 3]) {
work_dim_ = i + 1;
}
}
}
......
......@@ -75,8 +75,16 @@ void PassDownDomain(const Stage& stage,
CHECK(allow_missing);
continue;
}
state[r->rebased] = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
Range res = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
if (r->rebased->dom.defined()) {
Range rebase_rng = r->rebased->dom;
bool match = is_zero(rebase_rng->min);
if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
CHECK(match) << r->rebased
<< " does not match parent scope's range";
}
state[r->rebased] = res;
} else {
LOG(FATAL) << "unknown relation type";
}
......
......@@ -305,8 +305,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
}
}
// predicate generation, copy not touched axis.
const Reduce* reduce = compute_op->body.as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
Expr predicate = reduce->condition;
std::unordered_map<const Variable*, Expr> vsub;
Expr predicate;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
n->reduce_axis.push_back(iv);
......@@ -316,10 +318,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
vsub[iv->var.get()] = index;
if (!index.same_as(iv->var)) {
Expr cond = (index < dom_map.at(iv)->extent);
if (predicate.defined()) {
predicate = predicate && cond;
} else {
if (is_one(predicate)) {
predicate = cond;
} else {
predicate = predicate && cond;
}
}
}
......@@ -333,8 +335,6 @@ Tensor Schedule::rfactor(const Tensor& tensor,
n->reduce_axis.push_back(IterVar(ncpy));
}
}
const Reduce* reduce = compute_op->body.as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
n->body = Reduce::make(reduce->op,
VarReplacer(vsub).Mutate(reduce->source),
n->reduce_axis,
......
......@@ -136,6 +136,25 @@ Stage& Stage::compute_root() { // NOLINT(*)
return *this;
}
Stage& Stage::rebase(IterVar parent, IterVar rebased) { // NOLINT(*)
CHECK(parent->iter_type == kDataPar ||
parent->iter_type == kCommReduce)
<< "Cannot rebase " << IterVarType2String(parent->iter_type);
CHECK(rebased->iter_type == kThreadIndex)
<< "Cannot rebase by " << IterVarType2String(rebased->iter_type)
<< ", only thread axis is allowed so far";
ArrayNode* all_vars = (*this)->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = (*this)->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
(*this)->relations.push_back(RebaseNode::make(parent, rebased));
// add vars to all vars
all_vars->data.push_back(rebased.node_);
// replace the position.
leaf_vars->data.erase(leaf_vars->data.begin() + pos);
leaf_vars->data.insert(leaf_vars->data.begin() + pos, rebased.node_);
return *this;
}
Stage& Stage::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
CheckSplit(operator->(), parent, IterVar());
......
......@@ -51,7 +51,7 @@ def test_rfactor():
n = tvm.convert(1027)
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule
s = tvm.Schedule(B.op)
......@@ -78,6 +78,56 @@ def test_rfactor():
check_target()
def test_rfactor_threads():
nn = 1027
mm = 10
n = tvm.convert(nn)
m = tvm.convert(mm)
A = tvm.placeholder((m, n), name='A')
k = tvm.reduce_axis((0, n))
nthread = 16
B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
tx = tvm.thread_axis((0, nthread), "threadIdx.x")
ty = tvm.thread_axis((0, nthread), "threadIdx.y")
bx = tvm.thread_axis(None, "blockIdx.x")
# schedule
s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
s[B].rebase(xi, ty)
s[B].rebase(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], tx)
# one line to build the function.
def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
fapi = tvm.lower(s, args=[A, B])
fapi2 = tvm.ir_pass.LowerThreadAllreduce(fapi, 32)
fsum = tvm.build(fapi,
target=device,
name="mysum")
print(fsum.imported_modules[0].get_source())
# launch the kernel.
n = nn
m = mm
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=1)
res[:2] = 0
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_target("cuda")
check_target("opencl")
if __name__ == "__main__":
test_rfactor_threads()
test_rfactor()
test_sum()
......@@ -35,11 +35,11 @@ def test_buffer_doublebuff():
write_data.put_int(0)
# De-assert reset
sess.yield_until_posedge()
sess.yield_until_next_cycle()
rst.put_int(0)
# Leave the following signals set to true
sess.yield_until_posedge()
sess.yield_until_next_cycle()
write_valid.put_int(1)
# Main simulation loop
......@@ -50,15 +50,15 @@ def test_buffer_doublebuff():
if (write_idx < len(test_data)):
write_advance.put_int(0)
if (write_ready.get_int()):
write_data.put_int(test_data[write_idx])
write_addr.put_int(write_idx%window_width)
write_data.put_int(int(test_data[write_idx]))
write_addr.put_int(write_idx % window_width)
if (write_idx%window_width==window_width-1):
write_advance.put_int(1)
write_idx += 1
else:
write_advance.put_int(0)
write_valid.put_int(0)
# correctness checks
if (read_data_valid.get_int()):
assert(read_data.get_int()==test_data[read_idx])
......@@ -66,7 +66,7 @@ def test_buffer_doublebuff():
read_idx += 1
# step
sess.yield_until_posedge()
sess.yield_until_next_cycle()
if __name__ == "__main__":
......
......@@ -27,7 +27,7 @@ def test_buffer_fifo():
write_data.put_int(0)
# De-assert reset
sess.yield_until_posedge()
sess.yield_until_next_cycle()
rst.put_int(0)
# Main simulation loop
......@@ -46,7 +46,7 @@ def test_buffer_fifo():
assert(read_data.get_int()==test_data[read_idx])
read_idx += 1
# step
sess.yield_until_posedge()
sess.yield_until_next_cycle()
if __name__ == "__main__":
......
......@@ -33,11 +33,11 @@ def test_buffer_linebuff():
write_data.put_int(0)
# De-assert reset
sess.yield_until_posedge()
sess.yield_until_next_cycle()
rst.put_int(0)
# Leave the following signals set to true
sess.yield_until_posedge()
sess.yield_until_next_cycle()
write_advance.put_int(1)
write_valid.put_int(1)
......@@ -48,12 +48,12 @@ def test_buffer_linebuff():
# write logic
if (write_idx < len(test_data)):
if (write_ready.get_int()):
write_data.put_int(test_data[write_idx])
write_data.put_int(int(test_data[write_idx]))
write_idx += 1
else:
write_advance.put_int(0)
write_valid.put_int(0)
# correctness checks
if (read_data_valid.get_int()):
# Derive convolution window indices
......@@ -67,7 +67,7 @@ def test_buffer_linebuff():
read_idx += 1
# step
sess.yield_until_posedge()
sess.yield_until_next_cycle()
if __name__ == "__main__":
......