Unverified Commit 9c1e74ce by Zhi Committed by GitHub

[REFACTOR][BOYC] Non recursive partitioning (#5493)

* non recursive partitioning

* refactor maps

* rebase upstream

* refactor shared output

* address comments

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
parent 967d7318
...@@ -54,39 +54,30 @@ namespace partitioning { ...@@ -54,39 +54,30 @@ namespace partitioning {
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*! /*! \brief This struct maintains the required metadata for a region to generate a corresponding
* \brief The checker that verifies if a Relay program is annotated correctly * global function and function call. Global function will be passed to the target specific codegen
* for partitioning. * and function call will be used in the transform Relay graph to invoke the function in runtime.
*/ */
class AnnotationChecker : public ExprVisitor { struct RegionFuncMetadata {
public: /*! \brief The call node of the generated global function for this region. */
bool Check() { Call func_call;
if (!found_start_ && !found_end_) {
LOG(WARNING) << "No compiler annotation found";
} else if (!found_start_) {
LOG(ERROR) << "compiler_begin annotation is missing";
return false;
} else if (!found_end_) {
LOG(ERROR) << "compiler_end annotation is missing";
return false;
}
return true;
}
void VisitExpr_(const CallNode* call) final { /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used
auto op_node = call->op.as<OpNode>(); * as a function node argument; input expression is used as a function call parameter.
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { */
return; std::vector<std::pair<Var, Expr>> args;
} else if (call->op == compiler_begin_op) {
found_start_ = true;
} else if (call->op == compiler_end_op) {
found_end_ = true;
}
}
private: /*! \brief Map from each region output expr (compiler end) node to
bool found_start_{false}; * the corresponding function output expr.
bool found_end_{false}; */
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> region_func_out;
/*! \brief Map from each region input expression (compiler begin) to
* the corresponding function input variable. This cache is used to make sure
* a region function will not have duplicated inputs even if it refers to
* the same expr multiple times.
*/
std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> region_func_in;
}; };
/*! \brief This class partitions the expr labeled with begin and end annotations /*! \brief This class partitions the expr labeled with begin and end annotations
...@@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor { ...@@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor {
* the compiler name. * the compiler name.
*/ */
class Partitioner : public ExprMutator { class Partitioner : public MixedModeMutator {
public: public:
explicit Partitioner(const IRModule& module) : module_(module) { explicit Partitioner(const IRModule& module) : module_(module) {
for (auto f : module->functions) { for (auto f : module->functions) {
GlobalVar f_var = f.first; GlobalVar f_var = f.first;
BaseFunc f_func = f.second; BaseFunc f_func = f.second;
// Creating regionset per function in the module // Creating regionset per function in the module.
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op); partitioning::compiler_end_op);
regions_sets_[region_set] = f_func; regions_sets_[region_set] = f_func;
} }
} }
Expr VisitExpr_(const CallNode* call) final { Expr Rewrite_(const CallNode* call, const Expr& post) final {
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return ExprMutator::VisitExpr_(call); return post;
} else if (call->op == compiler_begin_op) { } else if (call->op == compiler_begin_op) {
// The annotation node is inserted on edge so it must have only one // The annotation node is inserted on edge so it must have only one argument.
// argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph. // Traverse the rest graph.
Expr parent = call->args[0]; Expr parent = call->args[0];
auto input_expr = VisitExpr(parent); auto input_expr = Downcast<Call>(post)->args[0];
// Backtrace the parent to find the first ancestor node that is not a begin or end op // Backtrace the parent to find the first ancestor node that is not a begin or end op
while (const auto* parent_call = parent.as<CallNode>()) { while (const auto* parent_call = parent.as<CallNode>()) {
if (parent_call->op == compiler_begin_op || if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
parent_call->op == compiler_end_op) {
parent = parent_call->args[0]; parent = parent_call->args[0];
} else { } else {
break; break;
...@@ -165,8 +154,8 @@ class Partitioner : public ExprMutator { ...@@ -165,8 +154,8 @@ class Partitioner : public ExprMutator {
int index = GetArgIdx(sg, GetRef<Call>(call)); int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1); CHECK_NE(index, -1);
if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { if (region_func_meta_[sg].region_func_in.count(parent)) {
return shared_output_[parent][sg]; return region_func_meta_[sg].region_func_in[parent];
} else { } else {
// The type of the created variable is the same as the compiler_begin // The type of the created variable is the same as the compiler_begin
// node. // node.
...@@ -177,11 +166,11 @@ class Partitioner : public ExprMutator { ...@@ -177,11 +166,11 @@ class Partitioner : public ExprMutator {
std::pair<Var, Expr> cand = std::make_pair(var, input_expr); std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) ==
region_args[sg].end()) { region_func_meta_[sg].args.end()) {
region_args[sg].push_back(cand); region_func_meta_[sg].args.push_back(cand);
} }
shared_output_[parent][sg] = var; region_func_meta_[sg].region_func_in[parent] = var;
return std::move(var); return std::move(var);
} }
} else { } else {
...@@ -197,114 +186,21 @@ class Partitioner : public ExprMutator { ...@@ -197,114 +186,21 @@ class Partitioner : public ExprMutator {
BaseFunc f = GetFunc(GetRef<Call>(call)); BaseFunc f = GetFunc(GetRef<Call>(call));
// Traverse subgraph inputs. // Traverse subgraph inputs.
auto input = VisitExpr(call->args[0]); auto input = Downcast<Call>(post)->args[0];
CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call); CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
// functions are created for each annotated regions, // functions are created for each annotated regions,
// when their first output is encountered. // when their first output is encountered.
// If multiple outputs are there, a tuple node is inserted at the end. // If multiple outputs are there, a tuple node is inserted at the end.
// region_function_calls is map that maintains
// (each annotated regions) --> created function
if (region_function_calls.find(region) == region_function_calls.end()) { if (!region_func_meta_[region].func_call.defined()) {
// First time this region is encountered in the traversal. // First time this region is encountered in the traversal. Creating the function.
// Creating the function.
CreateFunction(region, call); CreateFunction(region, call);
} }
// Retrieve this particular output of function.
return GetFunctionOutput(region, GetRef<Call>(call));
}
}
Expr VisitExpr_(const TupleNode* op) final {
auto region = GetRegion(GetRef<Tuple>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Array<Expr> fields;
for (auto field : op->fields) {
fields.push_back(VisitExpr(field));
}
return Tuple(fields);
}
}
Expr VisitExpr_(const TupleGetItemNode* g) final {
auto region = GetRegion(GetRef<TupleGetItem>(g));
if (!region.defined()) {
return ExprMutator::VisitExpr_(g);
} else {
auto t = VisitExpr(g->tuple);
return TupleGetItem(t, g->index);
}
}
Expr VisitExpr_(const FunctionNode* op) final {
auto region = GetRegion(GetRef<Function>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Array<Var> params;
for (auto param : op->params) {
Var new_param = Downcast<Var>(VisitExpr(param));
params.push_back(new_param);
}
auto body = VisitExpr(op->body);
return Function(params, body, op->ret_type, op->type_params, op->attrs);
}
}
Expr VisitExpr_(const LetNode* op) final {
auto region = GetRegion(GetRef<Let>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Var var = Downcast<Var>(VisitExpr(op->var));
auto value = VisitExpr(op->value);
auto body = VisitExpr(op->body);
return Let(var, value, body);
}
}
Expr VisitExpr_(const IfNode* op) final {
auto region = GetRegion(GetRef<If>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
auto guard = VisitExpr(op->cond);
auto true_b = VisitExpr(op->true_branch);
auto false_b = VisitExpr(op->false_branch);
return If(guard, true_b, false_b);
}
}
Expr VisitExpr_(const RefCreateNode* op) final {
auto region = GetRegion(GetRef<RefCreate>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Expr value = VisitExpr(op->value);
return RefCreate(value);
}
}
Expr VisitExpr_(const RefReadNode* op) final {
auto region = GetRegion(GetRef<RefRead>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Expr ref = VisitExpr(op->ref);
return RefRead(ref);
}
}
Expr VisitExpr_(const RefWriteNode* op) final { // Retrieve this particular output of function.
auto region = GetRegion(GetRef<RefWrite>(op)); Expr region_out_expr = Downcast<Call>(GetRef<Call>(call))->args[0];
if (!region.defined()) { CHECK(region_func_meta_[region].region_func_out.count(region_out_expr));
return ExprMutator::VisitExpr_(op); return region_func_meta_[region].region_func_out[region_out_expr];
} else {
Expr ref = VisitExpr(op->ref);
Expr value = VisitExpr(op->value);
return RefWrite(ref, value);
} }
} }
...@@ -370,24 +266,22 @@ class Partitioner : public ExprMutator { ...@@ -370,24 +266,22 @@ class Partitioner : public ExprMutator {
} }
/*! /*!
* \brief This function is called first time that we encounter a compiler_end * \brief Create a function and its function call for the given region. If the function has
* node to create the function for the subgraph. * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
*/ */
void CreateFunction(AnnotatedRegion region, const CallNode* call) { void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields which is a unique list of outputs. Also populate // Create fields which is a unique list of outputs.
// region_return_indices_ map which maps parent of compiler_end node to
// corresponding index in fields.
Array<Expr> fields; Array<Expr> fields;
int i = 0; std::unordered_map<Expr, int, ObjectHash, ObjectEqual> out_expr_to_idx;
for (auto ret : region->GetOutputs()) { int out_idx = 0;
auto ret_node = Downcast<Call>(ret)->args[0]; for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs. // Don't duplicate outputs.
if (!region_return_indices_.count(region) || if (!out_expr_to_idx.count(ret_node)) {
!region_return_indices_[region].count(ret_node)) { auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
auto ret_expr = VisitExpr(ret_node);
fields.push_back(ret_expr); fields.push_back(ret_expr);
region_return_indices_[region][ret_node] = i; out_expr_to_idx[ret_node] = out_idx++;
i++;
} }
} }
...@@ -396,20 +290,14 @@ class Partitioner : public ExprMutator { ...@@ -396,20 +290,14 @@ class Partitioner : public ExprMutator {
Map<Var, Expr> params_bind; Map<Var, Expr> params_bind;
auto IsConstant = [](const Expr& expr) { auto IsConstant = [](const Expr& expr) {
if (expr->IsInstance<ConstantNode>()) if (expr->IsInstance<ConstantNode>()) return true;
return true; if (!expr->IsInstance<TupleNode>()) return false;
if (expr->IsInstance<TupleNode>()) { const auto* tn = expr.as<TupleNode>();
auto tuple = expr.as<TupleNode>(); return std::all_of(tn->fields.begin(), tn->fields.end(),
for (const auto& field : tuple->fields) { [](const Expr& e) { return e->IsInstance<ConstantNode>(); });
if (!field->IsInstance<ConstantNode>())
return false;
}
return true;
}
return false;
}; };
for (auto pair : region_args[region]) { for (auto pair : region_func_meta_[region].args) {
params.push_back(pair.first); params.push_back(pair.first);
if (IsConstant(pair.second)) { if (IsConstant(pair.second)) {
params_bind.Set(pair.first, pair.second); params_bind.Set(pair.first, pair.second);
...@@ -422,23 +310,21 @@ class Partitioner : public ExprMutator { ...@@ -422,23 +310,21 @@ class Partitioner : public ExprMutator {
if (fields.size() == 1) { if (fields.size() == 1) {
// If there are only a single output; no need to add a tuple // If there are only a single output; no need to add a tuple
global_region_func = global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
} else { } else {
auto tuple = Tuple(fields); auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
} }
std::string target = call->attrs.as<CompilerAttrs>()->compiler; std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID()); std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func = global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
tvm::runtime::String(target));
global_region_func = global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation // Constant propagation
if (!params_bind.empty()) { if (!params_bind.empty()) {
...@@ -446,8 +332,7 @@ class Partitioner : public ExprMutator { ...@@ -446,8 +332,7 @@ class Partitioner : public ExprMutator {
} }
std::string fname = name; std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname)) CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region. // Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external // This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent // codegen to the module scope and rely on the pass manager to prevent
...@@ -456,128 +341,80 @@ class Partitioner : public ExprMutator { ...@@ -456,128 +341,80 @@ class Partitioner : public ExprMutator {
GlobalVar glob_func(fname); GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func); module_->Add(glob_func, global_region_func);
// The return type of callnode is the same as the type of the // Create a call node for the function.
// compiler_end node. auto call = Call(glob_func, param_expr);
auto ret = Call(glob_func, param_expr); region_func_meta_[region].func_call = call;
region_function_calls[region] = ret;
}
/*! // Create output expr(s) for the function call.
* \brief Get the return(output) of the function for compiler end node "end_arg". if (out_expr_to_idx.size() == 1) {
* This will return either a Call (for a function with a single output) or a // Single output direcly uses the call node as the output expr.
* TupleGetItem (for a function with multiple outputs). region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
*/ } else {
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) { // Multiple outptus need to create TupleGetItem nodes as output exprs.
Expr arg = Downcast<Call>(end_arg)->args[0]; for (auto pair : out_expr_to_idx) {
// Function has one output. Expr region_out_expr = pair.first; // The arg of a compiler end node of this region.
if (region_return_indices_[region].size() == 1) { int idx = pair.second; // Corresponding function output tuple index.
return region_function_calls[region]; auto tuple_get_item = TupleGetItem(call, idx);
} tuple_get_item->checked_type_ = region_out_expr->checked_type_;
// Function has multiple outputs. region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
// Use already made TupleGetItem. }
if (region_return_tuplegetitem_.count(region) && }
region_return_tuplegetitem_[region].count(arg)) {
return region_return_tuplegetitem_[region][arg];
}
// Create new TupleGetItem.
CHECK(region_return_indices_.count(region) &&
region_return_indices_[region].count(arg));
int index = region_return_indices_[region][arg];
auto func_call = region_function_calls[region];
auto tuple_get_item_ = TupleGetItem(func_call, index);
tuple_get_item_->checked_type_ = arg->checked_type_;
region_return_tuplegetitem_[region][arg] = tuple_get_item_;
return std::move(tuple_get_item_);
} }
/*! /*! \brief Map from each region to its metadata of the generated function. */
* \brief This map maintains the already created function calls. std::unordered_map<AnnotatedRegion, RegionFuncMetadata, ObjectHash, ObjectEqual>
* This is required in the multi-output scenario, to link rest of the outputs region_func_meta_;
* to call
*/
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*!
* \brief This map maintains arguments (of region) visits through visitor
* patterns. Those arguement var and expression will be used to when creating
* the function.
*/
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
region_args;
/*!
* \brief This map maintains the index of an output in the subgraph function
* for a given region. If there are multiple entries for a region, then the
* function has a tuple of multiple outputs for its return.
*/
using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
region_return_indices_;
/*!
* \brief This map holds already created TupleGetItem nodes for accessing
* outputs of a function.
*/
using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
region_return_tuplegetitem_;
/*! /*! \brief Each region set is associated with a function in the module.
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it * This map maintains the mapping between regionsets and the function it
* belongs to * belongs to
*/ */
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_; std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
/*!\brief Cache the output that is shared by different nodes. */
using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
/*!\brief The IRModule used for partitioning. */ /*!\brief The IRModule used for partitioning. */
IRModule module_; IRModule module_;
}; };
class DefaultRemover : public ExprMutator { IRModule RemoveDefaultAnnotations(IRModule module) {
class DefaultRemover : public ExprRewriter {
public: public:
explicit DefaultRemover(const IRModule& module) : module_(module) {} DefaultRemover() = default;
IRModule Remove() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
Expr VisitExpr_(const CallNode* call) final { Expr Rewrite_(const CallNode* call, const Expr& post) final {
auto attrs = call->attrs.as<CompilerAttrs>(); auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") { if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]); return Downcast<Call>(post)->args[0];
} }
return ExprMutator::VisitExpr_(call); return post;
} }
};
private: auto glob_funcs = module->functions;
IRModule module_; // module is mutable, hence, we make a copy of it.
}; module.CopyOnWrite();
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
DefaultRemover remover;
auto removed = PostOrderRewrite(func->body, &remover);
func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs);
module->Update(pair.first, func);
}
}
return module;
}
} // namespace partitioning } // namespace partitioning
namespace transform { namespace transform {
Pass PartitionGraph() { Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
[=](IRModule m, PassContext pc) { PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
// by treating them as un-annotated, but we don't have it yet. This workaround pass removes // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
// all "default" annotations and should be deleted in the future. // all "default" annotations and should be deleted in the future.
auto new_m = partitioning::DefaultRemover(m).Remove(); auto new_m = partitioning::RemoveDefaultAnnotations(m);
return partitioning::Partitioner(new_m).Partition(); return partitioning::Partitioner(new_m).Partition();
}; };
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment