Unverified Commit 14ae3a6e by manupa-arm Committed by GitHub

[RELAY] Re-wrote the Graph Partitioner to support multiple outputs (#5143)

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

Input : A Relay module that have functions with disjoint annotated regions
        using compiler_begin and compiler_end. There could be multiple outputs.

Output : A Relay module with global functions for such disjoint annotated regions
         with calls inserted at the respective location

Dependencies : AnnotatedRegionSet Utility class.

Methodology :
      1) The AnnotatedRegionSet utility class is able to construct a collection of
         nodes that are bound by a give annotation -- here we use compiler_begin
         and compiler_end
      2) Initially, for each function in the module AnnotatedRegionSets are populated.
      3) Then, Vistor pass is traversed until a compiler_end node is encountered
         that belongs to a "region".
      4) When the first compiler_end of a given annotated region is found, a function is
         formed and inserted.
         a) if the region has multiple outputs, a Tuple node (capturing all outputs)
            is returned.
      5) Thereafter, if we encounter an another output of the same annotated region,
         it is important to note that the function is already formed. Therefore, it will
         lookup the function and add a TupleGetItemNode is inserted.
          a) We will use the location index of "rets" of each "Region" of AnnotatedRegionSet
             as TupleGetItemNode index.
      6) Therefore, functions will be created for all annotated regions. The name for each
         global function is created using "Region" id and the compiler name.

Change-Id: I1372f02a845b6d3da03b561763e03a378dca263c

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    *removed the expected use-case as we are taking broken-down PR approach
    *code style fixes
    *some trivial one liners

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    *fixed an implicit copy to a move

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

    *code style changes for comments
    *renamed test case multiple outputs --> mixed single multiple outputs
        Since the existing test case checks for both single and multiple
        output scenarios
    *added a new test case with conv2d + batch_norm
    *some var name changes in the test

* [RELAY] Re-wrote the Graph Partitioner to support multiple outputs

	*rebased
parent 9cb9a51f
......@@ -25,7 +25,7 @@
* These nodes are used as boundaries to partition the Relay function into
* multiple regions that can be offloaded to different accelerators/backends.
*
* Each of these paritioned functions, a.k.a subgraphs, will be viewed as
* Each of these paritioned functions, a.k.a regions, will be viewed as
* external functions, and they will use the provided compiler for codegen.
*/
......@@ -36,13 +36,14 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <string>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../backend/utils.h"
#include "../analysis/annotated_region_set.h"
namespace tvm {
namespace relay {
......@@ -54,20 +55,6 @@ static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*!
* \brief The subgraph properties for partitioning.
*/
struct Subgraph {
/*! \brief The subgraph ID. */
int id;
/*! \brief The input arguments of this subgraph. */
std::vector<std::pair<Var, Expr>> args;
/*! \brief Nodes in this subgraph. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
};
/*!
* \brief The checker that verifies if a Relay program is annotated correctly
* for partitioning.
*/
......@@ -86,7 +73,7 @@ class AnnotationChecker : public ExprVisitor {
return true;
}
void VisitExpr_(const CallNode* call) final {
void VisitExpr_(const CallNode *call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return;
......@@ -102,61 +89,56 @@ class AnnotationChecker : public ExprVisitor {
bool found_end_{false};
};
/*! \brief This class partitions the expr labeled with begin and end annoations
/*! \brief This class partitions the expr labeled with begin and end annotations
* into function containing multiple regions. Each region is labeled with
* a compiler attribute so that it will be handled by any compilers that are not
* in the TVM stack.
*
* TODO(@zhiics) This following algorithm is not adequate to handle all cases,
* i.e. multiple `compiler_end` nodes.
* Input : A Relay module that have functions with disjoint annotated regions
* using compiler_begin and compiler_end. There could be multiple outputs.
*
* Output : A Relay module with global functions for such disjoint annotated regions
* with calls inserted at the respective location
*
* Dependencies : RegionSet Utility class.
*
* Methodology :
* 1) The RegionSet utility class is able to construct a collection of
* nodes that are bound by a given annotation -- here we use compiler_begin
* and compiler_end
* 2) Initially, for each function in the module RegionSets are populated.
* 3) Then, Vistor pass is traversed until a compiler_end node is encountered
* that belongs to a "region".
* 4) When the first compiler_end of a given annotated region is found, a function is
* formed and inserted.
* a) if the region has multiple outputs, a Tuple node (capturing all outputs)
* is returned.
* 5) Thereafter, if we encounter an another output of the same annotated region,
* it is important to note that the function is already formed. Therefore, it will
* lookup the function and add a TupleGetItemNode.
* a) We will use the location index of "rets" of each "Region" of RegionSet
* as TupleGetItemNode index.
* 6) Therefore, functions will be created for all annotated regions. The name for each
* global function is created using "Region" id and the compiler name.
*/
class Partitioner : public ExprMutator {
public:
explicit Partitioner(const IRModule& module) : module_(module) {}
std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
for (auto candidate : this->subgraphs_) {
if (candidate->nodes.find(node) != candidate->nodes.end()) {
return candidate;
}
}
return nullptr;
}
void MergeSubgraph(std::shared_ptr<Subgraph> subgraph1,
std::shared_ptr<Subgraph> subgraph2) {
if (subgraph1 == subgraph2) {
return;
}
// Merge subgraph 2 to subgraph 1 and erase subgraph 2.
subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end());
for (auto arg : subgraph2->args) {
subgraph1->args.push_back(arg);
}
this->subgraphs_.erase(subgraph2);
}
void AddToSubgraph(std::shared_ptr<Subgraph> subgraph, const Expr expr) {
auto subgraph2 = GetSubgraph(expr);
if (subgraph2) {
MergeSubgraph(subgraph, subgraph2);
} else {
subgraph->nodes.insert(expr);
explicit Partitioner(const IRModule& module) : module_(module) {
for (auto f : module->functions) {
GlobalVar f_var = f.first;
BaseFunc f_func = f.second;
// Creating regionset per function in the module
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op);
regions_sets_[region_set] = f_func;
}
}
Expr VisitExpr_(const CallNode* call) final {
Expr VisitExpr_(const CallNode *call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Propogate subgraph to arguments
auto subgraph = GetSubgraph(GetRef<Call>(call));
if (subgraph) {
for (auto arg : call->args) {
AddToSubgraph(subgraph, arg);
}
}
return ExprMutator::VisitExpr_(call);
} else if (call->op == compiler_begin_op) {
// The annotation node is inserted on edge so it must have only one argument.
......@@ -165,101 +147,142 @@ class Partitioner : public ExprMutator {
// Traverse the rest graph.
auto input_expr = VisitExpr(call->args[0]);
// Replace the begin annotation with an external call input variable.
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1);
// The type of the created variable is the same as the compiler_begin
// node.
auto var = Var(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
call->checked_type_);
// Find the corresponding subgraph and add the argument.
auto subgraph = GetSubgraph(GetRef<Call>(call));
if (!subgraph) {
throw Error(ErrorBuilder()
<< "Cannot find the corresponding subgraph for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
subgraph->args.push_back({var, input_expr});
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string varname = target + "_" + std::to_string(sg->GetID())
+ "_i" + std::to_string(index);
auto var = Var(varname, GetRef<Call>(call)->checked_type_);
auto cand = std::make_pair(var, input_expr);
if (std::find(region_args[sg].begin(),
region_args[sg].end(), cand) == region_args[sg].end()) {
region_args[sg].push_back(cand);
}
return std::move(var);
} else {
CHECK_EQ(call->op, compiler_end_op);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
AnnotatedRegion region = GetRegion(GetRef<Call>(call));
// Check if the argument already belongs to an existing subgraph
auto subgraph = GetSubgraph(call->args[0]);
if (!subgraph) {
auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
subgraph = *ret.first;
subgraph->nodes.insert(call->args[0]);
subgraph->id = this->subgraph_id_++;
}
subgraph->nodes.insert(GetRef<Call>(call));
// TODO(@manupa-arm) : need to use the parent function (to which region
// belongs to) name/key for the funtions that are created
BaseFunc f = GetFunc(GetRef<Call>(call));
// Traverse subgraph inputs.
auto input = VisitExpr(call->args[0]);
Array<Var> params;
Array<Expr> args;
std::unordered_map<std::string, runtime::NDArray> params_bind;
// The subgraph may be merged so we need to update it again.
subgraph = GetSubgraph(GetRef<Call>(call));
CHECK(subgraph);
// Record the constants for propagation.
for (auto pair : subgraph->args) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
// functions are created for each annotated regions,
// when their first output is encountered.
// If multiple outputs are there, a tuple node is inserted at the end.
// region_function_calls is map that maintains
// (each annotated regions) --> created function
if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the region
// Thus, the function is always created and at the end there would be a tuple node
// Therefore, we insert a tuple get item node.
// Use the already created tuple node
auto sg_call = region_function_calls[region];
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);
auto tuple_get_item_ = TupleGetItem(sg_call, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
} else {
// First time this region is encountered in the traversal
// Creating the function
Array<Expr> fields;
for (auto ret : region->GetOutputs()) {
auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
fields.push_back(ret_expr);
}
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);
Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;
for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}
Function global_region_func;
if (region->GetOutputs().size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func = Function(params, fields[0],
call->args[0]->checked_type_, {}, DictAttrs());
} else {
args.push_back(pair.second);
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
tir::StringImmNode::make(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive,
tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::tir::StringImmNode::make(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline,
tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}
}
auto subgraph_func =
Function(params, input, call->checked_type_, {});
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent relay
// function level passes (i.e. simplify inference and fusion) optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;
if (region->GetOutputs().size() == 1) {
// If there is only a single output; no need to add a tuplegetitem node
return std::move(ret);
} else {
// Add a tuplegetitem node to select this output out of many
auto tuple_get_item_ = TupleGetItem(ret, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
}
}
CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists";
// Create a global function and add it to the IRModule for the subgraph.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent relay
// function level passes (i.e. simplify inference and fusion) optimizing it.
GlobalVar glob_func(name);
module_->Add(glob_func, subgraph_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, args);
ret->checked_type_ = call->checked_type_;
return std::move(ret);
}
}
Expr VisitExpr_(const TupleNode* op) final {
auto subgraph = GetSubgraph(GetRef<Tuple>(op));
if (!subgraph) {
Expr VisitExpr_(const TupleNode *op) final {
auto region = GetRegion(GetRef<Tuple>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
for (auto field : op->fields) {
AddToSubgraph(subgraph, field);
}
Array<Expr> fields;
for (auto field : op->fields) {
fields.push_back(VisitExpr(field));
......@@ -268,27 +291,23 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const TupleGetItemNode* g) final {
auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g));
if (!subgraph) {
Expr VisitExpr_(const TupleGetItemNode *g) final {
auto region = GetRegion(GetRef<TupleGetItem>(g));
if (!region.defined()) {
return ExprMutator::VisitExpr_(g);
} else {
AddToSubgraph(subgraph, g->tuple);
auto t = VisitExpr(g->tuple);
return TupleGetItem(t, g->index);
}
}
Expr VisitExpr_(const FunctionNode* op) final {
auto subgraph = GetSubgraph(GetRef<Function>(op));
if (!subgraph) {
Expr VisitExpr_(const FunctionNode *op) final {
auto region = GetRegion(GetRef<Function>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
Array<Var> params;
for (auto param : op->params) {
AddToSubgraph(subgraph, param);
}
for (auto param : op->params) {
Var new_param = Downcast<Var>(VisitExpr(param));
params.push_back(new_param);
}
......@@ -297,30 +316,23 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const LetNode* op) final {
auto subgraph = GetSubgraph(GetRef<Let>(op));
if (!subgraph) {
Expr VisitExpr_(const LetNode *op) final {
auto region = GetRegion(GetRef<Let>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->var);
AddToSubgraph(subgraph, op->value);
AddToSubgraph(subgraph, op->body);
Var var = Downcast<Var>(VisitExpr(op->var));
auto value = VisitExpr(op->value);
auto body = VisitExpr(op->body);
return Let(var, value, body);
}
}
Expr VisitExpr_(const IfNode* op) final {
auto subgraph = GetSubgraph(GetRef<If>(op));
if (!subgraph) {
Expr VisitExpr_(const IfNode *op) final {
auto region = GetRegion(GetRef<If>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->cond);
AddToSubgraph(subgraph, op->true_branch);
AddToSubgraph(subgraph, op->false_branch);
auto guard = VisitExpr(op->cond);
auto true_b = VisitExpr(op->true_branch);
auto false_b = VisitExpr(op->false_branch);
......@@ -328,34 +340,31 @@ class Partitioner : public ExprMutator {
}
}
Expr VisitExpr_(const RefCreateNode* op) final {
auto subgraph = GetSubgraph(GetRef<RefCreate>(op));
if (!subgraph) {
Expr VisitExpr_(const RefCreateNode *op) final {
auto region = GetRegion(GetRef<RefCreate>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->value);
Expr value = VisitExpr(op->value);
return RefCreate(value);
}
}
Expr VisitExpr_(const RefReadNode* op) final {
auto subgraph = GetSubgraph(GetRef<RefRead>(op));
if (!subgraph) {
Expr VisitExpr_(const RefReadNode *op) final {
auto region = GetRegion(GetRef<RefRead>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->ref);
Expr ref = VisitExpr(op->ref);
return RefRead(ref);
}
}
Expr VisitExpr_(const RefWriteNode* op) final {
auto subgraph = GetSubgraph(GetRef<RefWrite>(op));
if (!subgraph) {
Expr VisitExpr_(const RefWriteNode *op) final {
auto region = GetRegion(GetRef<RefWrite>(op));
if (!region.defined()) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->ref);
Expr ref = VisitExpr(op->ref);
Expr value = VisitExpr(op->value);
return RefWrite(ref, value);
......@@ -365,13 +374,13 @@ class Partitioner : public ExprMutator {
IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
if (auto *fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
......@@ -379,21 +388,99 @@ class Partitioner : public ExprMutator {
}
private:
int var_id_{0};
int subgraph_id_{0};
std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
/*!
* \brief Get the region an expression belongs to
* if its in a region.
*/
AnnotatedRegion GetRegion(const Expr& e) {
for (auto sg_set_it : regions_sets_) {
auto sg_set = sg_set_it.first;
AnnotatedRegion sg = sg_set->GetRegion(e);
if (sg.defined()) {
return sg;
}
}
return AnnotatedRegion(nullptr);
}
/*!
* \brief Get the function an expression belongs to
* if its in a region.
*/
BaseFunc GetFunc(const Expr& e) {
for (auto sg_set_it : regions_sets_) {
auto sg_set = sg_set_it.first;
auto func = sg_set_it.second;
AnnotatedRegion sg = sg_set->GetRegion(e);
if (sg.defined()) {
return func;
}
}
return BaseFunc(nullptr);
}
/*!
* \brief Get the index of the argument;
* this is to be used as tuplegetitem idx
*/
int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetInputs()) {
if (arg == arg_) {
return idx;
}
idx++;
}
return -1;
}
/*!
* \brief Get the index of the return(output);
* this is to be used as tuplegetitem idx
*/
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetOutputs()) {
if (arg == arg_) {
return idx;
}
idx++;
}
return -1;
}
/*!
* \brief This map maintains the already created function calls.
* This is required in the multi-output scenario, to link rest of the outputs to call
*/
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*!
* \brief This map maintains arguments (of region) visits through visitor patterns.
* Those arguement var and expression will be used to when creating the function.
*/
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>,
ObjectHash, ObjectEqual> region_args;
/*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
IRModule module_;
};
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) {
return partitioning::Partitioner(m).Partition();
};
[=](IRModule m, PassContext pc) {
return partitioning::Partitioner(m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
......
......@@ -678,6 +678,187 @@ def test_constant_propagation():
check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))
def test_multiple_outputs():
def create_graph():
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
data_cb = compiler_begin(data, 'test_target')
weight_cb = compiler_begin(weight, 'test_target')
bn_gamma_cb = compiler_begin(bn_gamma, 'test_target')
bn_beta_cb = compiler_begin(bn_beta, 'test_target')
bn_mean_cb = compiler_begin(bn_mean, 'test_target')
bn_var_cb = compiler_begin(bn_var, 'test_target')
conv_o = relay.nn.conv2d(
data=data_cb,
weight=weight_cb,
kernel_size=(3, 3),
channels=16,
padding=(1, 1))
bn_o = relay.nn.batch_norm(conv_o, bn_gamma_cb, bn_beta_cb, bn_mean_cb,
bn_var_cb)
relu_o = relay.nn.relu(bn_o[0])
relu_o_ce = compiler_end(relu_o, 'test_target')
bn_omean = bn_o[1]
rebn_omean_ce = compiler_end(bn_omean, 'test_target')
bn_ovar = bn_o[2]
bn_ovar_ce = compiler_end(bn_ovar, 'test_target')
dummy_mean_abs = relay.abs(rebn_omean_ce)
dummy_ovar_abs = relay.abs(bn_ovar_ce)
dummy_tuple = relay.Tuple((relu_o_ce, dummy_mean_abs,dummy_ovar_abs))
func = relay.Function([data, weight, bn_gamma, bn_beta,
bn_mean, bn_var], dummy_tuple)
return func
def expected():
mod = tvm.IRModule()
# function 0
data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
conv_o = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(3, 3),
channels=16,
padding=(1, 1))
bn_o = relay.nn.batch_norm(conv_o, bn_gamma, bn_beta, bn_mean,
bn_var)
relu_o = relay.nn.relu(bn_o[0])
tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
func0 = relay.Function([data, weight, bn_gamma, bn_beta,
bn_mean, bn_var], tuple_o)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_0"))
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0
# body
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
f0_relu_o = relay.TupleGetItem(f0_o, 0)
f0_mean_o = relay.TupleGetItem(f0_o, 1)
f0_var_o = relay.TupleGetItem(f0_o, 2)
f0_mean_abs = relay.abs(f0_mean_o)
f0_var_abs = relay.abs(f0_var_o)
main_tuple = relay.Tuple((f0_relu_o, f0_mean_abs, f0_var_abs))
func = relay.Function([data, weight, bn_gamma,
bn_beta, bn_mean, bn_var], main_tuple)
mod["main"] = func
return mod
mod = tvm.IRModule()
mod["main"] = create_graph()
ref_mod = expected()
partitioned = transform.PartitionGraph()(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_mixed_single_multiple_outputs():
def create_graph():
data = relay.var('data', shape=(10, 10))
cb_1 = compiler_begin(data, 'test_target')
O_1 = relay.abs(cb_1)
ce_2 = compiler_end(O_1, 'test_target')
O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, 'test_target')
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(X, 'test_target')
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target')
func = relay.Function([data], ce_4)
return func
def expected():
mod = tvm.IRModule()
# function 1
f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10))
f1_O_1 = relay.abs(f1_cb1)
f1_O_2 = relay.nn.relu(f1_O_1)
f1_out = relay.Tuple((f1_O_2, f1_O_1))
func1 = relay.Function([f1_cb1], f1_out)
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_1"))
gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1
# function 0
f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10))
f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10))
f2_O_3 = relay.add(f2_cb3, f2_cb4)
func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_0"))
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0
# body
data = relay.var('data', shape=(10, 10))
tuple_out = gv1(data)
ce_2 = relay.TupleGetItem(tuple_out, 1)
ce_3 = relay.TupleGetItem(tuple_out, 0)
X = relay.tanh(ce_2)
ce_4 = gv0(ce_3, X)
func = relay.Function([data], ce_4)
mod["main"] = func
return mod
mod = tvm.IRModule()
mod["main"] = create_graph()
ref_mod = expected()
partitioned = transform.PartitionGraph()(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
......@@ -688,3 +869,5 @@ if __name__ == "__main__":
test_function_lifting()
test_function_lifting_inline()
test_constant_propagation()
test_multiple_outputs()
test_mixed_single_multiple_outputs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment