Unverified commit 14ae3a6e by manupa-arm (committed by GitHub)

[RELAY] Re-wrote the Graph Partitioner to support multiple outputs (#5143)

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

Input : A Relay module that has functions with disjoint annotated regions
        using compiler_begin and compiler_end. There could be multiple outputs.

Output : A Relay module with global functions for such disjoint annotated regions,
         with calls inserted at the respective locations.

Dependencies : AnnotatedRegionSet utility class.

Methodology :
      1) The AnnotatedRegionSet utility class is able to construct a collection of
         nodes that are bound by a given annotation -- here we use compiler_begin
         and compiler_end.
      2) Initially, for each function in the module, AnnotatedRegionSets are populated.
      3) Then, the visitor pass traverses each function until a compiler_end node
         that belongs to a "region" is encountered.
      4) When the first compiler_end of a given annotated region is found, a function
         is formed and inserted.
         a) If the region has multiple outputs, a Tuple node (capturing all outputs)
            is returned.
      5) Thereafter, if we encounter another output of the same annotated region, the
         function has already been formed, so we look up that function and insert a
         TupleGetItemNode.
         a) We use the location index of "rets" of each "Region" of the
            AnnotatedRegionSet as the TupleGetItemNode index.
      6) In this way, functions are created for all annotated regions. The name of each
         global function is created using the "Region" id and the compiler name.
         (See the sketch below for an end-to-end example.)
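
For illustration, a minimal sketch (not part of this commit) of driving the pass
from Python on a region with two outputs. It mirrors the new tests below;
"test_target" is a placeholder compiler name, and the region id in the generated
function name is assumed to start at 0.

    import tvm
    from tvm import relay
    from tvm.relay import transform
    from tvm.relay.op.annotation import compiler_begin, compiler_end

    data = relay.var('data', shape=(10, 10))
    cb = compiler_begin(data, 'test_target')
    abs_o = relay.abs(cb)           # first output of the region
    relu_o = relay.nn.relu(abs_o)   # second output of the region
    out = relay.Tuple((compiler_end(abs_o, 'test_target'),
                       compiler_end(relu_o, 'test_target')))

    mod = tvm.IRModule()
    mod["main"] = relay.Function([data], out)

    # Lifts the annotated region into a global function (e.g. "test_target_0")
    # returning a 2-tuple; each output is read back via TupleGetItem.
    partitioned = transform.PartitionGraph()(mod)
    print(partitioned)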

Change-Id: I1372f02a845b6d3da03b561763e03a378dca263c

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    * removed the expected use-case, as we are taking a broken-down PR approach
    * code style fixes
    * some trivial one-liners

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    * changed an implicit copy to a move

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    * code style changes for comments
    * renamed test case: multiple outputs --> mixed single multiple outputs,
      since the existing test case checks for both single and multiple
      output scenarios
    * added a new test case with conv2d + batch_norm
    * some variable name changes in the test

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    * rebased
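
    * As an aside, a sketch (hypothetical, for illustration only) of the structure
      the partitioner produces for a two-output region; compare the expected()
      builders in the new tests:

        import tvm
        from tvm import relay

        mod = tvm.IRModule()

        # The lifted region becomes a global function returning a Tuple of outputs.
        x = relay.var('x', shape=(10, 10))
        a = relay.abs(x)
        gv = relay.GlobalVar("test_target_0")
        mod[gv] = relay.Function([x], relay.Tuple((relay.nn.relu(a), a)))

        # The caller makes one call per region, then one TupleGetItem per output.
        data = relay.var('data', shape=(10, 10))
        call = gv(data)
        out0 = relay.TupleGetItem(call, 0)
        out1 = relay.TupleGetItem(call, 1)
        mod["main"] = relay.Function([data], relay.Tuple((out0, out1)))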
parent 9cb9a51f
@@ -25,7 +25,7 @@
  * These nodes are used as boundaries to partition the Relay function into
  * multiple regions that can be offloaded to different accelerators/backends.
  *
- * Each of these paritioned functions, a.k.a subgraphs, will be viewed as
+ * Each of these paritioned functions, a.k.a regions, will be viewed as
  * external functions, and they will use the provided compiler for codegen.
  */
@@ -36,13 +36,14 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
-#include <string>
+#include <utility>
 #include <unordered_map>
 #include <unordered_set>
-#include <utility>
 #include <vector>
 #include "../backend/utils.h"
+#include "../analysis/annotated_region_set.h"

 namespace tvm {
 namespace relay {
@@ -54,20 +55,6 @@ static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
 static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

 /*!
- * \brief The subgraph properties for partitioning.
- */
-struct Subgraph {
-  /*! \brief The subgraph ID. */
-  int id;
-
-  /*! \brief The input arguments of this subgraph. */
-  std::vector<std::pair<Var, Expr>> args;
-
-  /*! \brief Nodes in this subgraph. */
-  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
-};
-
-/*!
  * \brief The checker that verifies if a Relay program is annotated correctly
  * for partitioning.
  */
@@ -86,7 +73,7 @@ class AnnotationChecker : public ExprVisitor {
     return true;
   }

-  void VisitExpr_(const CallNode* call) final {
+  void VisitExpr_(const CallNode *call) final {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
       return;
@@ -102,61 +89,56 @@ class AnnotationChecker : public ExprVisitor {
   bool found_end_{false};
 };

-/*! \brief This class partitions the expr labeled with begin and end annoations
+/*! \brief This class partitions the expr labeled with begin and end annotations
  * into function containing multiple regions. Each region is labeled with
  * a compiler attribute so that it will be handled by any compilers that are not
  * in the TVM stack.
  *
- * TODO(@zhiics) This following algorithm is not adequate to handle all cases,
- * i.e. multiple `compiler_end` nodes.
+ * Input : A Relay module that has functions with disjoint annotated regions
+ *         using compiler_begin and compiler_end. There could be multiple outputs.
+ *
+ * Output : A Relay module with global functions for such disjoint annotated regions,
+ *          with calls inserted at the respective locations.
+ *
+ * Dependencies : RegionSet utility class.
+ *
+ * Methodology :
+ *  1) The RegionSet utility class is able to construct a collection of
+ *     nodes that are bound by a given annotation -- here we use compiler_begin
+ *     and compiler_end.
+ *  2) Initially, for each function in the module, RegionSets are populated.
+ *  3) Then, the visitor pass traverses each function until a compiler_end node
+ *     that belongs to a "region" is encountered.
+ *  4) When the first compiler_end of a given annotated region is found, a function
+ *     is formed and inserted.
+ *     a) If the region has multiple outputs, a Tuple node (capturing all outputs)
+ *        is returned.
+ *  5) Thereafter, if we encounter another output of the same annotated region, the
+ *     function has already been formed, so we look up that function and insert a
+ *     TupleGetItemNode.
+ *     a) We use the location index of "rets" of each "Region" of the RegionSet
+ *        as the TupleGetItemNode index.
+ *  6) In this way, functions are created for all annotated regions. The name of
+ *     each global function is created using the "Region" id and the compiler name.
  */
 class Partitioner : public ExprMutator {
  public:
-  explicit Partitioner(const IRModule& module) : module_(module) {}
-
-  std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
-    for (auto candidate : this->subgraphs_) {
-      if (candidate->nodes.find(node) != candidate->nodes.end()) {
-        return candidate;
-      }
-    }
-    return nullptr;
-  }
-
-  void MergeSubgraph(std::shared_ptr<Subgraph> subgraph1,
-                     std::shared_ptr<Subgraph> subgraph2) {
-    if (subgraph1 == subgraph2) {
-      return;
-    }
-
-    // Merge subgraph 2 to subgraph 1 and erase subgraph 2.
-    subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end());
-    for (auto arg : subgraph2->args) {
-      subgraph1->args.push_back(arg);
-    }
-    this->subgraphs_.erase(subgraph2);
-  }
-
-  void AddToSubgraph(std::shared_ptr<Subgraph> subgraph, const Expr expr) {
-    auto subgraph2 = GetSubgraph(expr);
-    if (subgraph2) {
-      MergeSubgraph(subgraph, subgraph2);
-    } else {
-      subgraph->nodes.insert(expr);
-    }
-  }
+  explicit Partitioner(const IRModule& module) : module_(module) {
+    for (auto f : module->functions) {
+      GlobalVar f_var = f.first;
+      BaseFunc f_func = f.second;
+
+      // Creating regionset per function in the module
+      auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
+                                                   partitioning::compiler_end_op);
+      regions_sets_[region_set] = f_func;
+    }
+  }

-  Expr VisitExpr_(const CallNode* call) final {
+  Expr VisitExpr_(const CallNode *call) final {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      // Propogate subgraph to arguments
-      auto subgraph = GetSubgraph(GetRef<Call>(call));
-      if (subgraph) {
-        for (auto arg : call->args) {
-          AddToSubgraph(subgraph, arg);
-        }
-      }
       return ExprMutator::VisitExpr_(call);
     } else if (call->op == compiler_begin_op) {
       // The annotation node is inserted on edge so it must have only one argument.
@@ -165,101 +147,142 @@ class Partitioner : public ExprMutator {
       // Traverse the rest graph.
       auto input_expr = VisitExpr(call->args[0]);

-      // Replace the begin annotation with an external call input variable.
-      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
+      int index = GetArgIdx(sg, GetRef<Call>(call));
+      CHECK_NE(index, -1);
       // The type of the created variable is the same as the compiler_begin
       // node.
-      auto var = Var(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
-                     call->checked_type_);
+      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+      std::string varname = target + "_" + std::to_string(sg->GetID())
+                            + "_i" + std::to_string(index);
+      auto var = Var(varname, GetRef<Call>(call)->checked_type_);

-      // Find the corresponding subgraph and add the argument.
-      auto subgraph = GetSubgraph(GetRef<Call>(call));
-      if (!subgraph) {
-        throw Error(ErrorBuilder()
-                    << "Cannot find the corresponding subgraph for start annotation:\n"
-                    << AsText(GetRef<Call>(call), false));
+      auto cand = std::make_pair(var, input_expr);
+      if (std::find(region_args[sg].begin(),
+                    region_args[sg].end(), cand) == region_args[sg].end()) {
+        region_args[sg].push_back(cand);
       }
-      subgraph->args.push_back({var, input_expr});
       return std::move(var);
     } else {
       CHECK_EQ(call->op, compiler_end_op);
       // The annotation node is inserted on edge so it must have only one argument.
       CHECK_EQ(call->args.size(), 1U);

-      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      AnnotatedRegion region = GetRegion(GetRef<Call>(call));

-      // Check if the argument already belongs to an existing subgraph
-      auto subgraph = GetSubgraph(call->args[0]);
-      if (!subgraph) {
-        auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
-        subgraph = *ret.first;
-        subgraph->nodes.insert(call->args[0]);
-        subgraph->id = this->subgraph_id_++;
-      }
-      subgraph->nodes.insert(GetRef<Call>(call));
+      // TODO(@manupa-arm) : need to use the parent function (to which the region
+      // belongs) name/key for the functions that are created
+      BaseFunc f = GetFunc(GetRef<Call>(call));

       // Traverse subgraph inputs.
       auto input = VisitExpr(call->args[0]);
-      Array<Var> params;
-      Array<Expr> args;
-      std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-      // The subgraph may be merged so we need to update it again.
-      subgraph = GetSubgraph(GetRef<Call>(call));
-      CHECK(subgraph);
-
-      // Record the constants for propagation.
-      for (auto pair : subgraph->args) {
-        params.push_back(pair.first);
-        if (const auto* cn = pair.second.as<ConstantNode>()) {
-          params_bind[pair.first->name_hint()] = cn->data;
-        } else {
-          args.push_back(pair.second);
-        }
-      }
+      CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);

-      auto subgraph_func =
-          Function(params, input, call->checked_type_, {});
+      // Functions are created for each annotated region
+      // when their first output is encountered.
+      // If there are multiple outputs, a tuple node is inserted at the end.
+      // region_function_calls is a map that maintains
+      // (each annotated region) --> (created function call)
+      if (region_function_calls.find(region) != region_function_calls.end()) {
+        // This section is executed only if there are multiple outputs in the region.
+        // Thus, the function has already been created and ends in a tuple node,
+        // so we insert a tuple get item node.

-      std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name));
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1));
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kCompiler,
-                   tvm::tir::StringImmNode::make(compiler_attrs->compiler));
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
+        // Use the already created tuple node
+        auto sg_call = region_function_calls[region];
+        int index = GetRetIdx(region, GetRef<Call>(call));
+        CHECK_NE(index, -1);
+        auto tuple_get_item_ = TupleGetItem(sg_call, index);
+        tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
+        return std::move(tuple_get_item_);
+      } else {
+        // First time this region is encountered in the traversal;
+        // create the function.
+        Array<Expr> fields;
+        for (auto ret : region->GetOutputs()) {
+          auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
+          fields.push_back(ret_expr);
+        }
+        int index = GetRetIdx(region, GetRef<Call>(call));
+        CHECK_NE(index, -1);

-      // Constant propagation
-      if (!params_bind.empty()) {
-        subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
-      }
+        Array<Var> params;
+        Array<Expr> param_expr;
+        std::unordered_map<std::string, runtime::NDArray> params_bind;
+        for (auto pair : region_args[region]) {
+          params.push_back(pair.first);
+          if (const auto* cn = pair.second.as<ConstantNode>()) {
+            params_bind[pair.first->name_hint()] = cn->data;
+          } else {
+            param_expr.push_back(pair.second);
+          }
+        }

-      CHECK(!module_->ContainGlobalVar(name))
-          << "Global function " << name << " already exists";
-      // Create a global function and add it to the IRModule for the subgraph.
-      // This way we lift the functions that should be handled by external
-      // codegen to the module scope and rely on the pass manager to prevent relay
-      // function level passes (i.e. simplify inference and fusion) optimizing it.
-      GlobalVar glob_func(name);
-      module_->Add(glob_func, subgraph_func);
+        Function global_region_func;
+        if (region->GetOutputs().size() == 1) {
+          // If there is only a single output, there is no need to add a tuple.
+          global_region_func = Function(params, fields[0],
+                                        call->args[0]->checked_type_, {}, DictAttrs());
+        } else {
+          auto tuple = Tuple(fields);
+          global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+        }

-      // The return type of callnode is the same as the type of the
-      // compiler_end node.
-      auto ret = Call(glob_func, args);
-      ret->checked_type_ = call->checked_type_;
-      return std::move(ret);
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        std::string name = target + "_" + std::to_string(region->GetID());
+        global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
+                                      tir::StringImmNode::make(name));
+        global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive,
+                                      tvm::Integer(1));
+        global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                      tvm::tir::StringImmNode::make(target));
+        global_region_func = WithAttr(std::move(global_region_func), attr::kInline,
+                                      tvm::Integer(1));
+
+        // Constant propagation
+        if (!params_bind.empty()) {
+          global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+        }
+
+        std::string fname = name;
+        CHECK(!module_->ContainGlobalVar(fname))
+            << "Global function " << fname << " already exists";
+        // Create a global function and add it to the IRModule for the region.
+        // This way we lift the functions that should be handled by external
+        // codegen to the module scope and rely on the pass manager to prevent relay
+        // function level passes (i.e. simplify inference and fusion) optimizing it.
+        GlobalVar glob_func(fname);
+        module_->Add(glob_func, global_region_func);
+
+        // The return type of the call node is the same as the type of the
+        // compiler_end node.
+        auto ret = Call(glob_func, param_expr);
+        region_function_calls[region] = ret;
+
+        if (region->GetOutputs().size() == 1) {
+          // If there is only a single output, no tuplegetitem node is needed.
+          return std::move(ret);
+        } else {
+          // Add a tuplegetitem node to select this output out of many.
+          auto tuple_get_item_ = TupleGetItem(ret, index);
+          tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
+          return std::move(tuple_get_item_);
+        }
+      }
     }
   }

-  Expr VisitExpr_(const TupleNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<Tuple>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const TupleNode *op) final {
+    auto region = GetRegion(GetRef<Tuple>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
-      for (auto field : op->fields) {
-        AddToSubgraph(subgraph, field);
-      }
       Array<Expr> fields;
       for (auto field : op->fields) {
         fields.push_back(VisitExpr(field));
@@ -268,27 +291,23 @@ class Partitioner : public ExprMutator {
     }
   }

-  Expr VisitExpr_(const TupleGetItemNode* g) final {
-    auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g));
-    if (!subgraph) {
+  Expr VisitExpr_(const TupleGetItemNode *g) final {
+    auto region = GetRegion(GetRef<TupleGetItem>(g));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(g);
     } else {
-      AddToSubgraph(subgraph, g->tuple);
       auto t = VisitExpr(g->tuple);
       return TupleGetItem(t, g->index);
     }
   }

-  Expr VisitExpr_(const FunctionNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<Function>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const FunctionNode *op) final {
+    auto region = GetRegion(GetRef<Function>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
       Array<Var> params;
-      for (auto param : op->params) {
-        AddToSubgraph(subgraph, param);
-      }
       for (auto param : op->params) {
         Var new_param = Downcast<Var>(VisitExpr(param));
         params.push_back(new_param);
       }
@@ -297,30 +316,23 @@ class Partitioner : public ExprMutator {
     }
   }

-  Expr VisitExpr_(const LetNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<Let>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const LetNode *op) final {
+    auto region = GetRegion(GetRef<Let>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
-      AddToSubgraph(subgraph, op->var);
-      AddToSubgraph(subgraph, op->value);
-      AddToSubgraph(subgraph, op->body);
       Var var = Downcast<Var>(VisitExpr(op->var));
       auto value = VisitExpr(op->value);
       auto body = VisitExpr(op->body);
       return Let(var, value, body);
     }
   }

-  Expr VisitExpr_(const IfNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<If>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const IfNode *op) final {
+    auto region = GetRegion(GetRef<If>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
-      AddToSubgraph(subgraph, op->cond);
-      AddToSubgraph(subgraph, op->true_branch);
-      AddToSubgraph(subgraph, op->false_branch);
       auto guard = VisitExpr(op->cond);
       auto true_b = VisitExpr(op->true_branch);
       auto false_b = VisitExpr(op->false_branch);
@@ -328,34 +340,31 @@ class Partitioner : public ExprMutator {
     }
   }

-  Expr VisitExpr_(const RefCreateNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<RefCreate>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const RefCreateNode *op) final {
+    auto region = GetRegion(GetRef<RefCreate>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
-      AddToSubgraph(subgraph, op->value);
       Expr value = VisitExpr(op->value);
       return RefCreate(value);
     }
   }

-  Expr VisitExpr_(const RefReadNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<RefRead>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const RefReadNode *op) final {
+    auto region = GetRegion(GetRef<RefRead>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
-      AddToSubgraph(subgraph, op->ref);
       Expr ref = VisitExpr(op->ref);
       return RefRead(ref);
     }
   }

-  Expr VisitExpr_(const RefWriteNode* op) final {
-    auto subgraph = GetSubgraph(GetRef<RefWrite>(op));
-    if (!subgraph) {
+  Expr VisitExpr_(const RefWriteNode *op) final {
+    auto region = GetRegion(GetRef<RefWrite>(op));
+    if (!region.defined()) {
       return ExprMutator::VisitExpr_(op);
     } else {
-      AddToSubgraph(subgraph, op->ref);
       Expr ref = VisitExpr(op->ref);
       Expr value = VisitExpr(op->value);
       return RefWrite(ref, value);
@@ -365,13 +374,13 @@ class Partitioner : public ExprMutator {
   IRModule Partition() {
     auto glob_funcs = module_->functions;
     for (const auto& pair : glob_funcs) {
-      if (auto* fn = pair.second.as<FunctionNode>()) {
+      if (auto *fn = pair.second.as<FunctionNode>()) {
         auto func = GetRef<Function>(fn);
         func = Function(func->params,
                         VisitExpr(func->body),
                         func->ret_type,
                         func->type_params,
                         func->attrs);
         module_->Update(pair.first, func);
       }
     }
@@ -379,21 +388,99 @@ class Partitioner : public ExprMutator {
   }

  private:
-  int var_id_{0};
-  int subgraph_id_{0};
-  std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
+  /*!
+   * \brief Get the region an expression belongs to,
+   * if it is in a region.
+   */
+  AnnotatedRegion GetRegion(const Expr& e) {
+    for (auto sg_set_it : regions_sets_) {
+      auto sg_set = sg_set_it.first;
+      AnnotatedRegion sg = sg_set->GetRegion(e);
+      if (sg.defined()) {
+        return sg;
+      }
+    }
+    return AnnotatedRegion(nullptr);
+  }
+
+  /*!
+   * \brief Get the function an expression belongs to,
+   * if it is in a region.
+   */
+  BaseFunc GetFunc(const Expr& e) {
+    for (auto sg_set_it : regions_sets_) {
+      auto sg_set = sg_set_it.first;
+      auto func = sg_set_it.second;
+      AnnotatedRegion sg = sg_set->GetRegion(e);
+      if (sg.defined()) {
+        return func;
+      }
+    }
+    return BaseFunc(nullptr);
+  }
+
+  /*!
+   * \brief Get the index of the argument;
+   * this is to be used as the tuplegetitem index.
+   */
+  int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
+    int idx = 0;
+    for (auto arg_ : sg->GetInputs()) {
+      if (arg == arg_) {
+        return idx;
+      }
+      idx++;
+    }
+    return -1;
+  }
+
+  /*!
+   * \brief Get the index of the return (output);
+   * this is to be used as the tuplegetitem index.
+   */
+  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
+    int idx = 0;
+    for (auto arg_ : sg->GetOutputs()) {
+      if (arg == arg_) {
+        return idx;
+      }
+      idx++;
+    }
+    return -1;
+  }
+
+  /*!
+   * \brief This map maintains the already created function calls.
+   * This is required in the multi-output scenario, to link the rest of the
+   * outputs to the call.
+   */
+  std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
+
+  /*!
+   * \brief This map maintains the arguments of a region collected during the
+   * visit. Those argument vars and expressions are used when creating the
+   * function.
+   */
+  std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>,
+                     ObjectHash, ObjectEqual> region_args;
+
+  /*!
+   * \brief Each region set is associated with a function in the module.
+   * This map maintains the mapping between region sets and the function
+   * they belong to.
+   */
+  std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
+
   IRModule module_;
 };

 }  // namespace partitioning

 namespace transform {

 Pass PartitionGraph() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
       [=](IRModule m, PassContext pc) {
         return partitioning::Partitioner(m).Partition();
       };
   auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
   return Sequential({partitioned, InferType()});
 }
@@ -678,6 +678,187 @@ def test_constant_propagation():
     check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))


+def test_multiple_outputs():
+    def create_graph():
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        data_cb = compiler_begin(data, 'test_target')
+        weight_cb = compiler_begin(weight, 'test_target')
+        bn_gamma_cb = compiler_begin(bn_gamma, 'test_target')
+        bn_beta_cb = compiler_begin(bn_beta, 'test_target')
+        bn_mean_cb = compiler_begin(bn_mean, 'test_target')
+        bn_var_cb = compiler_begin(bn_var, 'test_target')
+
+        conv_o = relay.nn.conv2d(
+            data=data_cb,
+            weight=weight_cb,
+            kernel_size=(3, 3),
+            channels=16,
+            padding=(1, 1))
+
+        bn_o = relay.nn.batch_norm(conv_o, bn_gamma_cb, bn_beta_cb, bn_mean_cb,
+                                   bn_var_cb)
+
+        relu_o = relay.nn.relu(bn_o[0])
+        relu_o_ce = compiler_end(relu_o, 'test_target')
+
+        bn_omean = bn_o[1]
+        rebn_omean_ce = compiler_end(bn_omean, 'test_target')
+        bn_ovar = bn_o[2]
+        bn_ovar_ce = compiler_end(bn_ovar, 'test_target')
+
+        dummy_mean_abs = relay.abs(rebn_omean_ce)
+        dummy_ovar_abs = relay.abs(bn_ovar_ce)
+        dummy_tuple = relay.Tuple((relu_o_ce, dummy_mean_abs, dummy_ovar_abs))
+
+        func = relay.Function([data, weight, bn_gamma, bn_beta,
+                               bn_mean, bn_var], dummy_tuple)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
+
+        conv_o = relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=(3, 3),
+            channels=16,
+            padding=(1, 1))
+
+        bn_o = relay.nn.batch_norm(conv_o, bn_gamma, bn_beta, bn_mean,
+                                   bn_var)
+
+        relu_o = relay.nn.relu(bn_o[0])
+        tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
+
+        func0 = relay.Function([data, weight, bn_gamma, bn_beta,
+                                bn_mean, bn_var], tuple_o)
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler",
+                                tvm.tir.StringImm("test_target"))
+        func0 = func0.with_attr("ExternalSymbol",
+                                tvm.tir.StringImm("test_target_0"))
+        gv0 = relay.GlobalVar("test_target_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
+        f0_relu_o = relay.TupleGetItem(f0_o, 0)
+        f0_mean_o = relay.TupleGetItem(f0_o, 1)
+        f0_var_o = relay.TupleGetItem(f0_o, 2)
+
+        f0_mean_abs = relay.abs(f0_mean_o)
+        f0_var_abs = relay.abs(f0_var_o)
+        main_tuple = relay.Tuple((f0_relu_o, f0_mean_abs, f0_var_abs))
+
+        func = relay.Function([data, weight, bn_gamma,
+                               bn_beta, bn_mean, bn_var], main_tuple)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+    ref_mod = expected()
+    partitioned = transform.PartitionGraph()(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
+
+def test_mixed_single_multiple_outputs():
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+
+        cb_1 = compiler_begin(data, 'test_target')
+        O_1 = relay.abs(cb_1)
+        ce_2 = compiler_end(O_1, 'test_target')
+        O_2 = relay.nn.relu(O_1)
+        ce_3 = compiler_end(O_2, 'test_target')
+
+        X = relay.tanh(ce_2)
+
+        cb_3 = compiler_begin(ce_3, 'test_target')
+        cb_4 = compiler_begin(X, 'test_target')
+        O_3 = relay.add(cb_3, cb_4)
+        ce_4 = compiler_end(O_3, 'test_target')
+
+        func = relay.Function([data], ce_4)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 1
+        f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10))
+        f1_O_1 = relay.abs(f1_cb1)
+        f1_O_2 = relay.nn.relu(f1_O_1)
+        f1_out = relay.Tuple((f1_O_2, f1_O_1))
+        func1 = relay.Function([f1_cb1], f1_out)
+
+        func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func1 = func1.with_attr("Compiler",
+                                tvm.tir.StringImm("test_target"))
+        func1 = func1.with_attr("ExternalSymbol",
+                                tvm.tir.StringImm("test_target_1"))
+        gv1 = relay.GlobalVar("test_target_1")
+        mod[gv1] = func1
+
+        # function 0
+        f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10))
+        f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10))
+        f2_O_3 = relay.add(f2_cb3, f2_cb4)
+        func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler",
+                                tvm.tir.StringImm("test_target"))
+        func0 = func0.with_attr("ExternalSymbol",
+                                tvm.tir.StringImm("test_target_0"))
+        gv0 = relay.GlobalVar("test_target_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        tuple_out = gv1(data)
+        ce_2 = relay.TupleGetItem(tuple_out, 1)
+        ce_3 = relay.TupleGetItem(tuple_out, 0)
+
+        X = relay.tanh(ce_2)
+        ce_4 = gv0(ce_3, X)
+        func = relay.Function([data], ce_4)
+
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+    ref_mod = expected()
+    partitioned = transform.PartitionGraph()(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -688,3 +869,5 @@ if __name__ == "__main__":
     test_function_lifting()
     test_function_lifting_inline()
     test_constant_propagation()
+    test_multiple_outputs()
+    test_mixed_single_multiple_outputs()