Unverified Commit 9c1e74ce by Zhi Committed by GitHub

[REFACTOR][BOYC] Non recursive partitioning (#5493)

* non recursive partitioning

* refactor maps

* rebase upstream

* refactor shared output

* address comments

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
parent 967d7318
......@@ -54,39 +54,30 @@ namespace partitioning {
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*!
* \brief The checker that verifies if a Relay program is annotated correctly
* for partitioning.
/*! \brief This struct maintains the required metadata for a region to generate a corresponding
* global function and function call. Global function will be passed to the target specific codegen
* and function call will be used in the transform Relay graph to invoke the function in runtime.
*/
class AnnotationChecker : public ExprVisitor {
public:
bool Check() {
if (!found_start_ && !found_end_) {
LOG(WARNING) << "No compiler annotation found";
} else if (!found_start_) {
LOG(ERROR) << "compiler_begin annotation is missing";
return false;
} else if (!found_end_) {
LOG(ERROR) << "compiler_end annotation is missing";
return false;
}
return true;
}
struct RegionFuncMetadata {
/*! \brief The call node of the generated global function for this region. */
Call func_call;
void VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return;
} else if (call->op == compiler_begin_op) {
found_start_ = true;
} else if (call->op == compiler_end_op) {
found_end_ = true;
}
}
/*! \brief A list of argument pairs. Each pair includes (var, expr). var is used
* as a function node argument; input expression is used as a function call parameter.
*/
std::vector<std::pair<Var, Expr>> args;
private:
bool found_start_{false};
bool found_end_{false};
/*! \brief Map from each region output expr (compiler end) node to
* the corresponding function output expr.
*/
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> region_func_out;
/*! \brief Map from each region input expression (compiler begin) to
* the corresponding function input variable. This cache is used to make sure
* a region function will not have duplicated inputs even if it refers to
* the same expr multiple times.
*/
std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> region_func_in;
};
/*! \brief This class partitions the expr labeled with begin and end annotations
......@@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor {
* the compiler name.
*/
class Partitioner : public ExprMutator {
class Partitioner : public MixedModeMutator {
public:
explicit Partitioner(const IRModule& module) : module_(module) {
for (auto f : module->functions) {
GlobalVar f_var = f.first;
BaseFunc f_func = f.second;
// Creating regionset per function in the module
// Creating regionset per function in the module.
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op);
regions_sets_[region_set] = f_func;
}
}
Expr VisitExpr_(const CallNode* call) final {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return ExprMutator::VisitExpr_(call);
return post;
} else if (call->op == compiler_begin_op) {
// The annotation node is inserted on edge so it must have only one
// argument.
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph.
Expr parent = call->args[0];
auto input_expr = VisitExpr(parent);
auto input_expr = Downcast<Call>(post)->args[0];
// Backtrace the parent to find the first ancestor node that is not a begin or end op
while (const auto* parent_call = parent.as<CallNode>()) {
if (parent_call->op == compiler_begin_op ||
parent_call->op == compiler_end_op) {
if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
parent = parent_call->args[0];
} else {
break;
......@@ -165,8 +154,8 @@ class Partitioner : public ExprMutator {
int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1);
if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
return shared_output_[parent][sg];
if (region_func_meta_[sg].region_func_in.count(parent)) {
return region_func_meta_[sg].region_func_in[parent];
} else {
// The type of the created variable is the same as the compiler_begin
// node.
......@@ -177,11 +166,11 @@ class Partitioner : public ExprMutator {
std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
region_args[sg].end()) {
region_args[sg].push_back(cand);
if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) ==
region_func_meta_[sg].args.end()) {
region_func_meta_[sg].args.push_back(cand);
}
shared_output_[parent][sg] = var;
region_func_meta_[sg].region_func_in[parent] = var;
return std::move(var);
}
} else {
......@@ -197,114 +186,21 @@ class Partitioner : public ExprMutator {
BaseFunc f = GetFunc(GetRef<Call>(call));
// Traverse subgraph inputs.
auto input = VisitExpr(call->args[0]);
auto input = Downcast<Call>(post)->args[0];
CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
// functions are created for each annotated regions,
// when their first output is encountered.
// If multiple outputs are there, a tuple node is inserted at the end.
// region_function_calls is map that maintains
// (each annotated regions) --> created function
if (region_function_calls.find(region) == region_function_calls.end()) {
// First time this region is encountered in the traversal.
// Creating the function.
if (!region_func_meta_[region].func_call.defined()) {
// First time this region is encountered in the traversal. Creating the function.
CreateFunction(region, call);
}
// Retrieve this particular output of function.
return GetFunctionOutput(region, GetRef<Call>(call));
}
}
Expr VisitExpr_(const TupleNode* op) final {
auto region = GetRegion(GetRef<Tuple>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Array<Expr> fields;
for (auto field : op->fields) {
fields.push_back(VisitExpr(field));
}
return Tuple(fields);
}
}
Expr VisitExpr_(const TupleGetItemNode* g) final {
auto region = GetRegion(GetRef<TupleGetItem>(g));
if (!region.defined()) {
return ExprMutator::VisitExpr_(g);
} else {
auto t = VisitExpr(g->tuple);
return TupleGetItem(t, g->index);
}
}
Expr VisitExpr_(const FunctionNode* op) final {
auto region = GetRegion(GetRef<Function>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Array<Var> params;
for (auto param : op->params) {
Var new_param = Downcast<Var>(VisitExpr(param));
params.push_back(new_param);
}
auto body = VisitExpr(op->body);
return Function(params, body, op->ret_type, op->type_params, op->attrs);
}
}
Expr VisitExpr_(const LetNode* op) final {
auto region = GetRegion(GetRef<Let>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Var var = Downcast<Var>(VisitExpr(op->var));
auto value = VisitExpr(op->value);
auto body = VisitExpr(op->body);
return Let(var, value, body);
}
}
Expr VisitExpr_(const IfNode* op) final {
auto region = GetRegion(GetRef<If>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
auto guard = VisitExpr(op->cond);
auto true_b = VisitExpr(op->true_branch);
auto false_b = VisitExpr(op->false_branch);
return If(guard, true_b, false_b);
}
}
Expr VisitExpr_(const RefCreateNode* op) final {
auto region = GetRegion(GetRef<RefCreate>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Expr value = VisitExpr(op->value);
return RefCreate(value);
}
}
Expr VisitExpr_(const RefReadNode* op) final {
auto region = GetRegion(GetRef<RefRead>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Expr ref = VisitExpr(op->ref);
return RefRead(ref);
}
}
Expr VisitExpr_(const RefWriteNode* op) final {
auto region = GetRegion(GetRef<RefWrite>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Expr ref = VisitExpr(op->ref);
Expr value = VisitExpr(op->value);
return RefWrite(ref, value);
// Retrieve this particular output of function.
Expr region_out_expr = Downcast<Call>(GetRef<Call>(call))->args[0];
CHECK(region_func_meta_[region].region_func_out.count(region_out_expr));
return region_func_meta_[region].region_func_out[region_out_expr];
}
}
......@@ -370,24 +266,22 @@ class Partitioner : public ExprMutator {
}
/*!
* \brief This function is called first time that we encounter a compiler_end
* node to create the function for the subgraph.
* \brief Create a function and its function call for the given region. If the function has
* multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* call) {
// Create fields which is a unique list of outputs. Also populate
// region_return_indices_ map which maps parent of compiler_end node to
// corresponding index in fields.
void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields which is a unique list of outputs.
Array<Expr> fields;
int i = 0;
for (auto ret : region->GetOutputs()) {
auto ret_node = Downcast<Call>(ret)->args[0];
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> out_expr_to_idx;
int out_idx = 0;
for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
if (!region_return_indices_.count(region) ||
!region_return_indices_[region].count(ret_node)) {
auto ret_expr = VisitExpr(ret_node);
if (!out_expr_to_idx.count(ret_node)) {
auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
region_return_indices_[region][ret_node] = i;
i++;
out_expr_to_idx[ret_node] = out_idx++;
}
}
......@@ -396,20 +290,14 @@ class Partitioner : public ExprMutator {
Map<Var, Expr> params_bind;
auto IsConstant = [](const Expr& expr) {
if (expr->IsInstance<ConstantNode>())
return true;
if (expr->IsInstance<TupleNode>()) {
auto tuple = expr.as<TupleNode>();
for (const auto& field : tuple->fields) {
if (!field->IsInstance<ConstantNode>())
return false;
}
return true;
}
return false;
if (expr->IsInstance<ConstantNode>()) return true;
if (!expr->IsInstance<TupleNode>()) return false;
const auto* tn = expr.as<TupleNode>();
return std::all_of(tn->fields.begin(), tn->fields.end(),
[](const Expr& e) { return e->IsInstance<ConstantNode>(); });
};
for (auto pair : region_args[region]) {
for (auto pair : region_func_meta_[region].args) {
params.push_back(pair.first);
if (IsConstant(pair.second)) {
params_bind.Set(pair.first, pair.second);
......@@ -422,23 +310,21 @@ class Partitioner : public ExprMutator {
if (fields.size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
......@@ -446,8 +332,7 @@ class Partitioner : public ExprMutator {
}
std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
......@@ -456,128 +341,80 @@ class Partitioner : public ExprMutator {
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;
}
// Create a call node for the function.
auto call = Call(glob_func, param_expr);
region_func_meta_[region].func_call = call;
/*!
* \brief Get the return(output) of the function for compiler end node "end_arg".
* This will return either a Call (for a function with a single output) or a
* TupleGetItem (for a function with multiple outputs).
*/
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
Expr arg = Downcast<Call>(end_arg)->args[0];
// Function has one output.
if (region_return_indices_[region].size() == 1) {
return region_function_calls[region];
}
// Function has multiple outputs.
// Use already made TupleGetItem.
if (region_return_tuplegetitem_.count(region) &&
region_return_tuplegetitem_[region].count(arg)) {
return region_return_tuplegetitem_[region][arg];
}
// Create new TupleGetItem.
CHECK(region_return_indices_.count(region) &&
region_return_indices_[region].count(arg));
int index = region_return_indices_[region][arg];
auto func_call = region_function_calls[region];
auto tuple_get_item_ = TupleGetItem(func_call, index);
tuple_get_item_->checked_type_ = arg->checked_type_;
region_return_tuplegetitem_[region][arg] = tuple_get_item_;
return std::move(tuple_get_item_);
// Create output expr(s) for the function call.
if (out_expr_to_idx.size() == 1) {
// Single output direcly uses the call node as the output expr.
region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
} else {
// Multiple outptus need to create TupleGetItem nodes as output exprs.
for (auto pair : out_expr_to_idx) {
Expr region_out_expr = pair.first; // The arg of a compiler end node of this region.
int idx = pair.second; // Corresponding function output tuple index.
auto tuple_get_item = TupleGetItem(call, idx);
tuple_get_item->checked_type_ = region_out_expr->checked_type_;
region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
}
}
}
/*!
* \brief This map maintains the already created function calls.
* This is required in the multi-output scenario, to link rest of the outputs
* to call
*/
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*!
* \brief This map maintains arguments (of region) visits through visitor
* patterns. Those arguement var and expression will be used to when creating
* the function.
*/
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
region_args;
/*!
* \brief This map maintains the index of an output in the subgraph function
* for a given region. If there are multiple entries for a region, then the
* function has a tuple of multiple outputs for its return.
*/
using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
region_return_indices_;
/*!
* \brief This map holds already created TupleGetItem nodes for accessing
* outputs of a function.
*/
using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
region_return_tuplegetitem_;
/*! \brief Map from each region to its metadata of the generated function. */
std::unordered_map<AnnotatedRegion, RegionFuncMetadata, ObjectHash, ObjectEqual>
region_func_meta_;
/*!
* \brief Each region set is associated with a function in the module.
/*! \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it
* belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
/*!\brief Cache the output that is shared by different nodes. */
using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
/*!\brief The IRModule used for partitioning. */
IRModule module_;
};
class DefaultRemover : public ExprMutator {
IRModule RemoveDefaultAnnotations(IRModule module) {
class DefaultRemover : public ExprRewriter {
public:
explicit DefaultRemover(const IRModule& module) : module_(module) {}
IRModule Remove() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
DefaultRemover() = default;
Expr VisitExpr_(const CallNode* call) final {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
return Downcast<Call>(post)->args[0];
}
return ExprMutator::VisitExpr_(call);
return post;
}
};
private:
IRModule module_;
};
auto glob_funcs = module->functions;
// module is mutable, hence, we make a copy of it.
module.CopyOnWrite();
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
DefaultRemover remover;
auto removed = PostOrderRewrite(func->body, &remover);
func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs);
module->Update(pair.first, func);
}
}
return module;
}
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
// by treating them as un-annotated, but we don't have it yet. This workaround pass removes
// all "default" annotations and should be deleted in the future.
auto new_m = partitioning::DefaultRemover(m).Remove();
auto new_m = partitioning::RemoveDefaultAnnotations(m);
return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment