Unverified Commit e46aa333 by manupa-arm Committed by GitHub

[RELAY] Partition graph codestyle fixes (#5202)

* [RELAY] Codestyle fixes for Graph Partitioner
	*ran through clang-format

* *formatting comments

* *further codestyle changes (after clang-format)
parent e722301a
...@@ -29,21 +29,20 @@ ...@@ -29,21 +29,20 @@
* external functions, and they will use the provided compiler for codegen. * external functions, and they will use the provided compiler for codegen.
*/ */
#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "../backend/utils.h"
#include "../analysis/annotated_region_set.h" #include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -73,7 +72,7 @@ class AnnotationChecker : public ExprVisitor { ...@@ -73,7 +72,7 @@ class AnnotationChecker : public ExprVisitor {
return true; return true;
} }
void VisitExpr_(const CallNode *call) final { void VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return; return;
...@@ -95,31 +94,33 @@ class AnnotationChecker : public ExprVisitor { ...@@ -95,31 +94,33 @@ class AnnotationChecker : public ExprVisitor {
* in the TVM stack. * in the TVM stack.
* *
* Input : A Relay module that have functions with disjoint annotated regions * Input : A Relay module that have functions with disjoint annotated regions
* using compiler_begin and compiler_end. There could be multiple outputs. * using compiler_begin and compiler_end. There could be multiple
* outputs.
* *
* Output : A Relay module with global functions for such disjoint annotated regions * Output : A Relay module with global functions for such disjoint annotated
* with calls inserted at the respective location * regions with calls inserted at the respective location
* *
* Dependencies : RegionSet Utility class. * Dependencies : AnnotatedRegionSet Utility class.
* *
* Methodology : * Methodology :
* 1) The RegionSet utility class is able to construct a collection of * 1) The AnnotatedRegionSet utility class is able to construct a collection
* nodes that are bound by a given annotation -- here we use compiler_begin * of nodes that are bound by a given annotation -- here we use
* and compiler_end * compiler_begin and compiler_end
* 2) Initially, for each function in the module RegionSets are populated. * 2) Initially, for each function in the module RegionSets are populated.
* 3) Then, Vistor pass is traversed until a compiler_end node is encountered * 3) Then, Vistor pass is traversed until a compiler_end node is encountered
* that belongs to a "region". * that belongs to a "region".
* 4) When the first compiler_end of a given annotated region is found, a function is * 4) When the first compiler_end of a given annotated region is found,
* formed and inserted. * a function is formed and inserted.
* a) if the region has multiple outputs, a Tuple node (capturing all outputs) * a) if the region has multiple outputs, a Tuple node (capturing
* is returned. * all outputs) is returned.
* 5) Thereafter, if we encounter an another output of the same annotated region, * 5) Thereafter, if we encounter an another output of the same annotated
* it is important to note that the function is already formed. Therefore, it will * region, it is important to note that the function is already formed.
* lookup the function and add a TupleGetItemNode. * Therefore, it will lookup the function and add a TupleGetItemNode.
* a) We will use the location index of "rets" of each "Region" of RegionSet * a) We will use the location index of "rets" of each Region" of
* as TupleGetItemNode index. * AnnotatedRegionSet as TupleGetItemNode index.
* 6) Therefore, functions will be created for all annotated regions. The name for each * 6) Therefore, functions will be created for all annotated regions.
* global function is created using "Region" id and the compiler name. * The name for each global function is created using "Region" id and
* the compiler name.
*/ */
class Partitioner : public ExprMutator { class Partitioner : public ExprMutator {
...@@ -136,12 +137,13 @@ class Partitioner : public ExprMutator { ...@@ -136,12 +137,13 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const CallNode *call) final { Expr VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return ExprMutator::VisitExpr_(call); return ExprMutator::VisitExpr_(call);
} else if (call->op == compiler_begin_op) { } else if (call->op == compiler_begin_op) {
// The annotation node is inserted on edge so it must have only one argument. // The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph. // Traverse the rest graph.
...@@ -153,20 +155,21 @@ class Partitioner : public ExprMutator { ...@@ -153,20 +155,21 @@ class Partitioner : public ExprMutator {
// The type of the created variable is the same as the compiler_begin // The type of the created variable is the same as the compiler_begin
// node. // node.
std::string target = call->attrs.as<CompilerAttrs>()->compiler; std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string varname = target + "_" + std::to_string(sg->GetID()) std::string varname =
+ "_i" + std::to_string(index); target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
auto var = Var(varname, GetRef<Call>(call)->checked_type_); auto var = Var(varname, GetRef<Call>(call)->checked_type_);
auto cand = std::make_pair(var, input_expr); auto cand = std::make_pair(var, input_expr);
if (std::find(region_args[sg].begin(), if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
region_args[sg].end(), cand) == region_args[sg].end()) { region_args[sg].end()) {
region_args[sg].push_back(cand); region_args[sg].push_back(cand);
} }
return std::move(var); return std::move(var);
} else { } else {
CHECK_EQ(call->op, compiler_end_op); CHECK_EQ(call->op, compiler_end_op);
// The annotation node is inserted on edge so it must have only one argument. // The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
AnnotatedRegion region = GetRegion(GetRef<Call>(call)); AnnotatedRegion region = GetRegion(GetRef<Call>(call));
...@@ -185,9 +188,9 @@ class Partitioner : public ExprMutator { ...@@ -185,9 +188,9 @@ class Partitioner : public ExprMutator {
// (each annotated regions) --> created function // (each annotated regions) --> created function
if (region_function_calls.find(region) != region_function_calls.end()) { if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the region // This section is executed only if there are multiple outputs in the
// Thus, the function is always created and at the end there would be a tuple node // region Thus, the function is always created and at the end there
// Therefore, we insert a tuple get item node. // would be a tuple node Therefore, we insert a tuple get item node.
// Use the already created tuple node // Use the already created tuple node
auto sg_call = region_function_calls[region]; auto sg_call = region_function_calls[region];
...@@ -226,8 +229,8 @@ class Partitioner : public ExprMutator { ...@@ -226,8 +229,8 @@ class Partitioner : public ExprMutator {
Function global_region_func; Function global_region_func;
if (region->GetOutputs().size() == 1) { if (region->GetOutputs().size() == 1) {
// If there are only a single output; no need to add a tuple // If there are only a single output; no need to add a tuple
global_region_func = Function(params, fields[0], global_region_func =
call->args[0]->checked_type_, {}, DictAttrs()); Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else { } else {
auto tuple = Tuple(fields); auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
...@@ -238,12 +241,12 @@ class Partitioner : public ExprMutator { ...@@ -238,12 +241,12 @@ class Partitioner : public ExprMutator {
global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol, global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
tir::StringImmNode::make(name)); tir::StringImmNode::make(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, global_region_func =
tvm::Integer(1)); WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::tir::StringImmNode::make(target)); tvm::tir::StringImmNode::make(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline, global_region_func =
tvm::Integer(1)); WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation // Constant propagation
if (!params_bind.empty()) { if (!params_bind.empty()) {
...@@ -255,8 +258,9 @@ class Partitioner : public ExprMutator { ...@@ -255,8 +258,9 @@ class Partitioner : public ExprMutator {
<< "Global function " << fname << " already exists"; << "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region. // Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external // This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent relay // codegen to the module scope and rely on the pass manager to prevent
// function level passes (i.e. simplify inference and fusion) optimizing it. // relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
GlobalVar glob_func(fname); GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func); module_->Add(glob_func, global_region_func);
...@@ -266,7 +270,8 @@ class Partitioner : public ExprMutator { ...@@ -266,7 +270,8 @@ class Partitioner : public ExprMutator {
region_function_calls[region] = ret; region_function_calls[region] = ret;
if (region->GetOutputs().size() == 1) { if (region->GetOutputs().size() == 1) {
// If there is only a single output; no need to add a tuplegetitem node // If there is only a single output; no need to add a tuplegetitem
// node
return std::move(ret); return std::move(ret);
} else { } else {
// Add a tuplegetitem node to select this output out of many // Add a tuplegetitem node to select this output out of many
...@@ -278,7 +283,7 @@ class Partitioner : public ExprMutator { ...@@ -278,7 +283,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const TupleNode *op) final { Expr VisitExpr_(const TupleNode* op) final {
auto region = GetRegion(GetRef<Tuple>(op)); auto region = GetRegion(GetRef<Tuple>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -291,7 +296,7 @@ class Partitioner : public ExprMutator { ...@@ -291,7 +296,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const TupleGetItemNode *g) final { Expr VisitExpr_(const TupleGetItemNode* g) final {
auto region = GetRegion(GetRef<TupleGetItem>(g)); auto region = GetRegion(GetRef<TupleGetItem>(g));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(g); return ExprMutator::VisitExpr_(g);
...@@ -301,7 +306,7 @@ class Partitioner : public ExprMutator { ...@@ -301,7 +306,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const FunctionNode *op) final { Expr VisitExpr_(const FunctionNode* op) final {
auto region = GetRegion(GetRef<Function>(op)); auto region = GetRegion(GetRef<Function>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -316,7 +321,7 @@ class Partitioner : public ExprMutator { ...@@ -316,7 +321,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const LetNode *op) final { Expr VisitExpr_(const LetNode* op) final {
auto region = GetRegion(GetRef<Let>(op)); auto region = GetRegion(GetRef<Let>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -328,7 +333,7 @@ class Partitioner : public ExprMutator { ...@@ -328,7 +333,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const IfNode *op) final { Expr VisitExpr_(const IfNode* op) final {
auto region = GetRegion(GetRef<If>(op)); auto region = GetRegion(GetRef<If>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -340,7 +345,7 @@ class Partitioner : public ExprMutator { ...@@ -340,7 +345,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const RefCreateNode *op) final { Expr VisitExpr_(const RefCreateNode* op) final {
auto region = GetRegion(GetRef<RefCreate>(op)); auto region = GetRegion(GetRef<RefCreate>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -350,7 +355,7 @@ class Partitioner : public ExprMutator { ...@@ -350,7 +355,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const RefReadNode *op) final { Expr VisitExpr_(const RefReadNode* op) final {
auto region = GetRegion(GetRef<RefRead>(op)); auto region = GetRegion(GetRef<RefRead>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -360,7 +365,7 @@ class Partitioner : public ExprMutator { ...@@ -360,7 +365,7 @@ class Partitioner : public ExprMutator {
} }
} }
Expr VisitExpr_(const RefWriteNode *op) final { Expr VisitExpr_(const RefWriteNode* op) final {
auto region = GetRegion(GetRef<RefWrite>(op)); auto region = GetRegion(GetRef<RefWrite>(op));
if (!region.defined()) { if (!region.defined()) {
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
...@@ -374,12 +379,9 @@ class Partitioner : public ExprMutator { ...@@ -374,12 +379,9 @@ class Partitioner : public ExprMutator {
IRModule Partition() { IRModule Partition() {
auto glob_funcs = module_->functions; auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) { for (const auto& pair : glob_funcs) {
if (auto *fn = pair.second.as<FunctionNode>()) { if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn); auto func = GetRef<Function>(fn);
func = Function(func->params, func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs); func->attrs);
module_->Update(pair.first, func); module_->Update(pair.first, func);
} }
...@@ -452,41 +454,40 @@ class Partitioner : public ExprMutator { ...@@ -452,41 +454,40 @@ class Partitioner : public ExprMutator {
/*! /*!
* \brief This map maintains the already created function calls. * \brief This map maintains the already created function calls.
* This is required in the multi-output scenario, to link rest of the outputs to call * This is required in the multi-output scenario, to link rest of the outputs
* to call
*/ */
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls; std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*! /*!
* \brief This map maintains arguments (of region) visits through visitor patterns. * \brief This map maintains arguments (of region) visits through visitor
* Those arguement var and expression will be used to when creating the function. * patterns. Those arguement var and expression will be used to when creating
* the function.
*/ */
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
ObjectHash, ObjectEqual> region_args; region_args;
/*! /*!
* \brief Each region set is associated with a function in the module. * \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it belongs to * This map maintains the mapping between regionsets and the function it
* belongs to
*/ */
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_; std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
IRModule module_; IRModule module_;
}; };
} // namespace partitioning } // namespace partitioning
namespace transform { namespace transform {
Pass PartitionGraph() { Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) { [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
return partitioning::Partitioner(m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()}); return Sequential({partitioned, InferType()});
} }
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph") TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);
.set_body_typed(transform::PartitionGraph);
} // namespace transform } // namespace transform
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment