Commit 8e51af2f by Tianqi Chen Committed by GitHub

[PASS] CombineContextCall (#255)

parent 36f20b54
......@@ -302,6 +302,18 @@ constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
*/
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
* \brief See pseudo code
* Mark the content as a thread-local context; it can be optimized
* by calling it only once at thread start.
*
* Do not allow nesting(getting a thread context from another).
*
* Handle tvm_thread_context(Expr call) {
* return call;
* }
*/
constexpr const char* tvm_thread_context = "tvm_thread_context";
/*!
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
*
......
......@@ -61,6 +61,19 @@ bool Equal(const Expr& lhs, const Expr& rhs);
bool Equal(const Stmt& lhs, const Stmt& rhs);
/*!
* \brief Deep compare lhs and rhs.
*
* If you only need equality comparison, use Equal,
* which also ties definitions. Compare gives a
* total order over expressions.
*
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
int Compare(const Expr& lhs, const Expr& rhs);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
......@@ -315,6 +328,13 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
LoweredFunc LowerPackedCall(LoweredFunc f);
/*!
* \brief Combine context function calls.
* \param f The host function to be lowered.
* \return Transformed function.
*/
LoweredFunc CombineContextCall(LoweredFunc f);
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
......
......@@ -321,6 +321,7 @@ def build(sch,
device_type = ndarray.context(device, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
if fdevice:
if not target_host:
......
......@@ -89,7 +89,7 @@ def call_pure_extern(dtype, func_name, *args):
The data type of the result.
func_name: str
The intrinsic function name.
The extern function name.
args : list
Positional arguments.
......@@ -102,6 +102,30 @@ def call_pure_extern(dtype, func_name, *args):
return _make.Call(
dtype, func_name, convert(args), _Call.PureExtern, None, 0)
def call_extern(dtype, func_name, *args):
    """Build an expression that calls an extern function.

    Uses ``Call.Extern`` as the call type (compare ``call_pure_extern``,
    which uses ``Call.PureExtern``).

    Parameters
    ----------
    dtype : str
        The data type of the result.

    func_name : str
        The extern function name.

    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        The call expression.
    """
    call_args = convert(args)
    return _make.Call(dtype, func_name, call_args, _Call.Extern, None, 0)
def exp(x):
"""Take exponetial of input x.
......
......@@ -102,5 +102,6 @@ REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerPackedCall);
REGISTER_PASS1(CombineContextCall);
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* Combine calls to context-related functions into one.
*
* \file combine_context_call.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <map>
namespace tvm {
namespace ir {
// Combine calls wrapped in tvm_thread_context so that each distinct
// context expression is fetched once per scope and reused afterwards.
// Rewrites calls wrapped in the tvm_thread_context intrinsic so that
// each distinct context expression is computed once per scope and
// cached in a Var; the bindings are emitted as LetStmts around the
// scope body. Scopes are: the whole function, each thread_extent
// launch, and each parallel For.
class ContextCallCombiner final : public IRMutator {
 public:
  // Strict weak ordering on expressions (via deep Compare) so that
  // expressions can be used as std::map keys.
  struct CompareExpr {
    bool operator()(const Expr& lhs, const Expr& rhs) const {
      return Compare(lhs, rhs) < 0;
    }
  };

  Expr Mutate_(const Call* op, const Expr& e) final {
    if (!op->is_intrinsic(intrinsic::tvm_thread_context)) {
      return IRMutator::Mutate_(op, e);
    }
    CHECK_EQ(op->args.size(), 1U);
    Expr ctx = op->args[0];
    auto it = ctx_map_.find(ctx);
    if (it != ctx_map_.end()) {
      // Same context already seen in this scope: reuse its cache variable.
      return it->second;
    }
    CHECK(ctx.type().is_handle());
    // Derive a readable cache-variable name from the wrapped call if any.
    std::string name = "ctx_cache_";
    if (const Call* call = ctx.as<Call>()) {
      name = call->name + "_cache";
    }
    Var cache_var(name, ctx.type());
    ctx_map_[ctx] = cache_var;
    return cache_var;
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key != attr::thread_extent) {
      return IRMutator::Mutate_(op, s);
    }
    // A kernel launch opens a fresh caching scope: stash the outer map,
    // collect the contexts used inside, then restore the outer map and
    // bind the inner contexts around the launch body.
    std::map<Expr, Var, CompareExpr> scope_map;
    std::swap(scope_map, ctx_map_);
    Stmt body = IRMutator::Mutate_(op, s);
    std::swap(scope_map, ctx_map_);  // ctx_map_ = outer, scope_map = inner
    return BuildContext(scope_map, body);
  }

  Stmt Mutate_(const For* op, const Stmt& s) final {
    if (op->for_type != ForType::Parallel) {
      return IRMutator::Mutate_(op, s);
    }
    // A parallel loop also opens its own caching scope (same protocol
    // as thread_extent above).
    std::map<Expr, Var, CompareExpr> scope_map;
    std::swap(scope_map, ctx_map_);
    Stmt body = IRMutator::Mutate_(op, s);
    std::swap(scope_map, ctx_map_);  // ctx_map_ = outer, scope_map = inner
    return BuildContext(scope_map, body);
  }

  // Entry point: mutate stmt, then bind the contexts collected at the
  // outermost (function) scope.
  Stmt Combine(Stmt stmt) {
    return BuildContext(ctx_map_, this->Mutate(stmt));
  }

 private:
  // Wrap body in one LetStmt per cached context expression.
  static Stmt BuildContext(const std::map<Expr, Var, CompareExpr>& cmap,
                           Stmt body) {
    for (const auto& kv : cmap) {
      body = LetStmt::make(kv.second, kv.first, body);
    }
    return body;
  }
  // Contexts cached in the current scope, keyed by deep expression compare.
  std::map<Expr, Var, CompareExpr> ctx_map_;
};
// Run the combiner over the body of a lowered host function,
// producing a new LoweredFunc (the input is not mutated).
LoweredFunc CombineContextCall(LoweredFunc f) {
  auto node = std::make_shared<LoweredFuncNode>(*f.operator->());
  ContextCallCombiner combiner;
  node->body = combiner.Combine(node->body);
  return LoweredFunc(node);
}
} // namespace ir
} // namespace tvm
......@@ -35,8 +35,15 @@ class IRDeepCompare :
return order_ == 0;
}
// Total-order comparison of two expressions.
// Unlike Equal(), definitions are not tied (tie_def_ = false), so
// structurally identical expressions with distinct variables compare
// as unequal in a stable order.
// NOTE(review): assumes order_ is still 0 on entry (fresh instance) —
// confirm callers construct a new IRDeepCompare per comparison.
int Compare(const Expr& lhs, const Expr& rhs) {
tie_def_ = false;  // comparison mode: do not equate bound definitions
VisitExpr(lhs, rhs);
return order_;  // 0 if equal; otherwise set by the first differing node
}
void VisitExpr(const Expr& n, const Expr& other) override {
if (order_ != 0) return;
if (n.same_as(other)) return;
if (CompareValue(n->type_index(), other->type_index()) != 0) return;
if (CompareType(n.type(), other.type()) != 0) return;
ExprComparator::VisitExpr(n, other);
......@@ -44,6 +51,7 @@ class IRDeepCompare :
// Compare two statements; short-circuits once an order has been decided.
void VisitStmt(const Stmt& n, const Stmt& other) override {
if (order_ != 0) return;  // a difference was already found earlier
if (n.same_as(other)) return;  // identical node: contributes no ordering
if (CompareValue(n->type_index(), other->type_index()) != 0) return;  // different stmt kinds: order by type index
StmtComparator::VisitStmt(n, other);  // same kind: recurse into fields
}
......@@ -413,5 +421,9 @@ bool Equal(const Expr& lhs, const Expr& rhs) {
return IRDeepCompare().Equal(lhs, rhs);
}
int Compare(const Expr& lhs, const Expr& rhs) {
return IRDeepCompare().Compare(lhs, rhs);
}
} // namespace ir
} // namespace tvm
import tvm
def test_for():
    """CombineContextCall hoists repeated device_context() fetches
    into handle-typed let-bindings at the enclosing scope."""
    dev_type = tvm.var("dev_type")

    def device_context(dev_id):
        # Wrap the extern context getter in tvm_thread_context so the
        # pass is allowed to cache it.
        ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
        return tvm.make.Call(
            "handle", "tvm_thread_context", [ctx],
            tvm.expr.Call.Intrinsic, None, 0)

    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    A = ib.allocate("float32", n, name="A", scope="global")
    with ib.for_range(0, n, name="i") as i:
        ib.emit(tvm.call_extern("int32", "fadd", device_context(0), A))
        with ib.for_range(0, 10, name="j") as j:
            ib.emit(tvm.call_extern("int32", "fadd", device_context(1), A))
            ib.emit(tvm.call_extern("int32", "fadd", device_context(0), A))
    body = ib.get()
    f = tvm.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
    f = tvm.ir_pass.CombineContextCall(f)
    # The combined contexts become nested LetStmts binding handle values.
    assert f.body.value.dtype == "handle"
    assert f.body.body.value.dtype == "handle"
if __name__ == "__main__":
test_for()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment