Unverified Commit e46aa333 by manupa-arm Committed by GitHub

[RELAY] Partition graph codestyle fixes (#5202)

* [RELAY] Codestyle fixes for Graph Partitioner
	*ran through clang-format

* *formatting comments

* *further codestyle changes (after clang-format)
parent e722301a
......@@ -29,21 +29,20 @@
* external functions, and they will use the provided compiler for codegen.
*/
#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../backend/utils.h"
#include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"
namespace tvm {
namespace relay {
......@@ -73,7 +72,7 @@ class AnnotationChecker : public ExprVisitor {
return true;
}
void VisitExpr_(const CallNode *call) final {
void VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return;
......@@ -95,31 +94,33 @@ class AnnotationChecker : public ExprVisitor {
* in the TVM stack.
*
* Input : A Relay module that have functions with disjoint annotated regions
* using compiler_begin and compiler_end. There could be multiple outputs.
* using compiler_begin and compiler_end. There could be multiple
* outputs.
*
* Output : A Relay module with global functions for such disjoint annotated regions
* with calls inserted at the respective location
* Output : A Relay module with global functions for such disjoint annotated
* regions with calls inserted at the respective location
*
* Dependencies : RegionSet Utility class.
* Dependencies : AnnotatedRegionSet Utility class.
*
* Methodology :
* 1) The RegionSet utility class is able to construct a collection of
* nodes that are bound by a given annotation -- here we use compiler_begin
* and compiler_end
* 1) The AnnotatedRegionSet utility class is able to construct a collection
* of nodes that are bound by a given annotation -- here we use
* compiler_begin and compiler_end
* 2) Initially, for each function in the module RegionSets are populated.
* 3) Then, Vistor pass is traversed until a compiler_end node is encountered
* that belongs to a "region".
* 4) When the first compiler_end of a given annotated region is found, a function is
* formed and inserted.
* a) if the region has multiple outputs, a Tuple node (capturing all outputs)
* is returned.
* 5) Thereafter, if we encounter an another output of the same annotated region,
* it is important to note that the function is already formed. Therefore, it will
* lookup the function and add a TupleGetItemNode.
* a) We will use the location index of "rets" of each "Region" of RegionSet
* as TupleGetItemNode index.
* 6) Therefore, functions will be created for all annotated regions. The name for each
* global function is created using "Region" id and the compiler name.
* 4) When the first compiler_end of a given annotated region is found,
* a function is formed and inserted.
* a) if the region has multiple outputs, a Tuple node (capturing
* all outputs) is returned.
* 5) Thereafter, if we encounter an another output of the same annotated
* region, it is important to note that the function is already formed.
* Therefore, it will lookup the function and add a TupleGetItemNode.
* a) We will use the location index of "rets" of each Region" of
* AnnotatedRegionSet as TupleGetItemNode index.
* 6) Therefore, functions will be created for all annotated regions.
* The name for each global function is created using "Region" id and
* the compiler name.
*/
class Partitioner : public ExprMutator {
......@@ -136,12 +137,13 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const CallNode *call) final {
Expr VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return ExprMutator::VisitExpr_(call);
} else if (call->op == compiler_begin_op) {
// The annotation node is inserted on edge so it must have only one argument.
// The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph.
......@@ -153,20 +155,21 @@ class Partitioner : public ExprMutator {
// The type of the created variable is the same as the compiler_begin
// node.
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string varname = target + "_" + std::to_string(sg->GetID())
+ "_i" + std::to_string(index);
std::string varname =
target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
auto var = Var(varname, GetRef<Call>(call)->checked_type_);
auto cand = std::make_pair(var, input_expr);
if (std::find(region_args[sg].begin(),
region_args[sg].end(), cand) == region_args[sg].end()) {
if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
region_args[sg].end()) {
region_args[sg].push_back(cand);
}
return std::move(var);
} else {
CHECK_EQ(call->op, compiler_end_op);
// The annotation node is inserted on edge so it must have only one argument.
// The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U);
AnnotatedRegion region = GetRegion(GetRef<Call>(call));
......@@ -185,9 +188,9 @@ class Partitioner : public ExprMutator {
// (each annotated regions) --> created function
if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the region
// Thus, the function is always created and at the end there would be a tuple node
// Therefore, we insert a tuple get item node.
// This section is executed only if there are multiple outputs in the
// region Thus, the function is always created and at the end there
// would be a tuple node Therefore, we insert a tuple get item node.
// Use the already created tuple node
auto sg_call = region_function_calls[region];
......@@ -226,8 +229,8 @@ class Partitioner : public ExprMutator {
Function global_region_func;
if (region->GetOutputs().size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func = Function(params, fields[0],
call->args[0]->checked_type_, {}, DictAttrs());
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
......@@ -238,12 +241,12 @@ class Partitioner : public ExprMutator {
global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
tir::StringImmNode::make(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive,
tvm::Integer(1));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::tir::StringImmNode::make(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline,
tvm::Integer(1));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
......@@ -255,8 +258,9 @@ class Partitioner : public ExprMutator {
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent relay
// function level passes (i.e. simplify inference and fusion) optimizing it.
// codegen to the module scope and rely on the pass manager to prevent
// relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
......@@ -266,7 +270,8 @@ class Partitioner : public ExprMutator {
region_function_calls[region] = ret;
if (region->GetOutputs().size() == 1) {
// If there is only a single output; no need to add a tuplegetitem node
// If there is only a single output; no need to add a tuplegetitem
// node
return std::move(ret);
} else {
// Add a tuplegetitem node to select this output out of many
......@@ -278,7 +283,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const TupleNode *op) final {
Expr VisitExpr_(const TupleNode* op) final {
auto region = GetRegion(GetRef<Tuple>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -291,7 +296,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const TupleGetItemNode *g) final {
Expr VisitExpr_(const TupleGetItemNode* g) final {
auto region = GetRegion(GetRef<TupleGetItem>(g));
if (!region.defined()) {
return ExprMutator::VisitExpr_(g);
......@@ -301,7 +306,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const FunctionNode *op) final {
Expr VisitExpr_(const FunctionNode* op) final {
auto region = GetRegion(GetRef<Function>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -316,7 +321,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const LetNode *op) final {
Expr VisitExpr_(const LetNode* op) final {
auto region = GetRegion(GetRef<Let>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -328,7 +333,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const IfNode *op) final {
Expr VisitExpr_(const IfNode* op) final {
auto region = GetRegion(GetRef<If>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -340,7 +345,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const RefCreateNode *op) final {
Expr VisitExpr_(const RefCreateNode* op) final {
auto region = GetRegion(GetRef<RefCreate>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -350,7 +355,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const RefReadNode *op) final {
Expr VisitExpr_(const RefReadNode* op) final {
auto region = GetRegion(GetRef<RefRead>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -360,7 +365,7 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const RefWriteNode *op) final {
Expr VisitExpr_(const RefWriteNode* op) final {
auto region = GetRegion(GetRef<RefWrite>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
......@@ -374,12 +379,9 @@ class Partitioner : public ExprMutator {
IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto *fn = pair.second.as<FunctionNode>()) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
......@@ -452,41 +454,40 @@ class Partitioner : public ExprMutator {
/*!
* \brief This map maintains the already created function calls.
* This is required in the multi-output scenario, to link rest of the outputs to call
* This is required in the multi-output scenario, to link rest of the outputs
* to call
*/
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*!
* \brief This map maintains arguments (of region) visits through visitor patterns.
* Those arguement var and expression will be used to when creating the function.
* \brief This map maintains arguments (of region) visits through visitor
* patterns. Those arguement var and expression will be used to when creating
* the function.
*/
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>,
ObjectHash, ObjectEqual> region_args;
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
region_args;
/*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it belongs to
* This map maintains the mapping between regionsets and the function it
* belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
IRModule module_;
};
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) {
return partitioning::Partitioner(m).Partition();
};
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
.set_body_typed(transform::PartitionGraph);
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);
} // namespace transform
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment