Unverified Commit 02121383 by mbaret Committed by GitHub

[RELAY] Add MergeCompilerRegions pass (#5134)

* [RELAY] Add MergeCompilerRegions pass

This pass is part of the flow to support creating compiler
regions with multiple outputs. It should be called after
AnnotateTarget and will merge together regions that share
the same target to create larger compiler regions that can
be off-loaded to external codegens.

This pass implements an algorithm to ensure that during the
merging, no data dependency issues are created. See the tests
for an example of this case.

Co-authored-by: Ramana Radhakrishnan  <ramana.radhakrishnan@arm.com>
Co-authored-by: Manupa Karunaratne    <manupa.karunaratne@arm.com>

Change-Id: Ibd99083564608d888482f57c5080109f3eefec88

* [RELAY] Annotate compiler_ends on each edge

This alters the behaviour of the AnnotateTarget
pass to enforce the property that all compiler
annotations exist along a single data flow edge.
Specifically, this means they should have exactly
one parent and one child.

Change-Id: I0e74803a77767f4f377d17755a13a74a30909797

* Fix comment

* Rebase *Node::make

* Moved block outside for loop

* Code style

* Update make API

* Remove comment

* Remove redundant 'else's

* Make one line

* Fix comment

* RefWrite

* Fix merge ordering

* Add the RFC example as a test

* [FIX] Fixed merging behaviour in AnnotateRegionSet

Deleting items from a list while iterating it seems to
result in undefined behaviour which sometimes segfaults.
This makes sure all the item deletion happens separately.

* Added checks

* Move comment

* Update comments
parent 33260318
...@@ -397,6 +397,17 @@ def MergeComposite(pattern_table): ...@@ -397,6 +397,17 @@ def MergeComposite(pattern_table):
return _ffi_api.MergeComposite(pattern_names, patterns) return _ffi_api.MergeComposite(pattern_names, patterns)
def MergeCompilerRegions():
"""Merge together compiler regions.
Returns
-------
ret : tvm.relay.Pass
The registered pass that merges compiler regions.
"""
return _ffi_api.MergeCompilerRegions()
def RewriteAnnotatedOps(fallback_device): def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g. """Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to. `on_deivce`, mark which device an expression should be scheduled to.
......
...@@ -55,13 +55,17 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, ...@@ -55,13 +55,17 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
} }
// if any of the outputs of src are inputs of dest, they become internal nodes // if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs // so remove them from outs
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) { for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input); auto call = Downcast<Call>(input);
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]); auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
if (it != src->outs.end()) { if (it != src->outs.end()) {
dest->outs.remove(*it); dest->outs.remove(*it);
dest->ins.remove(input); ins_to_remove.push_back(input);
}
} }
for (const auto& input : ins_to_remove) {
dest->ins.remove(input);
} }
regions_.erase(src); regions_.erase(src);
} }
......
...@@ -38,38 +38,136 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -38,38 +38,136 @@ class AnnotateTargetWrapper : public ExprMutator {
public: public:
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {} explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
Expr Annotate(const Expr& expr) {
return InsertEnd(Mutate(expr));
}
bool IsSupported(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
}
return false;
}
Expr InsertEnd(const Expr& arg) {
if (IsSupported(arg)) {
const auto *end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(arg, target_);
return end;
}
return arg;
}
Expr VisitExpr_(const CallNode* cn) { Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions. // TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn); auto new_e = ExprMutator::VisitExpr_(cn);
Call call = Downcast<Call>(new_e); Call call = Downcast<Call>(new_e);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) { // add end annotations if the args are supported
bool external = fannotate[op](call->attrs, call->args); Array<Expr> compiler_ends;
if (external) {
tvm::Array<tvm::relay::Expr> compiler_begins;
for (const auto& it : call->args) { for (const auto& it : call->args) {
compiler_ends.push_back(InsertEnd(it));
}
call = Call(call->op, compiler_ends, call->attrs);
// add begin annotations if the call node is supported
if (IsSupported(call)) {
tvm::Array<tvm::relay::Expr> compiler_begins;
const auto* begin_op = const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
for (const auto& it : call->args) {
CHECK(begin_op); CHECK(begin_op);
Expr begin = (*begin_op)(it, target_); Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin); compiler_begins.push_back(begin);
} }
Expr update_call = Call(call->op, compiler_begins, call->attrs); call = Call(call->op, compiler_begins, call->attrs);
const auto* end_op = }
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op); return std::move(call);
Expr end = (*end_op)(update_call, target_); }
return end;
Expr VisitExpr_(const TupleNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto tup = Downcast<Tuple>(new_e);
Array<Expr> new_fields;
for (auto field : tup->fields) {
new_fields.push_back(InsertEnd(field));
}
return Tuple(new_fields);
}
Expr VisitExpr_(const TupleGetItemNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItem(
InsertEnd(get->tuple),
get->index);
}
Expr VisitExpr_(const FunctionNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto func = Downcast<Function>(new_e);
return Function(
func->params,
InsertEnd(func->body),
func->ret_type,
func->type_params,
func->attrs);
} }
} else {
LOG(WARNING) << op->name << " in " << target_ Expr VisitExpr_(const LetNode* op) {
<< " is not registered. It will be executed on CPU."; auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e);
return Let(
let->var,
InsertEnd(let->value),
InsertEnd(let->body));
} }
return new_e;
Expr VisitExpr_(const IfNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto iff = Downcast<If>(new_e);
return If(
InsertEnd(iff->cond),
InsertEnd(iff->true_branch),
InsertEnd(iff->false_branch));
}
Expr VisitExpr_(const RefCreateNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto create = Downcast<RefCreate>(new_e);
return RefCreate(InsertEnd(create->value));
}
Expr VisitExpr_(const RefReadNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto read = Downcast<RefRead>(new_e);
return RefRead(InsertEnd(read->ref));
}
Expr VisitExpr_(const RefWriteNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto write = Downcast<RefWrite>(new_e);
return RefWrite(
InsertEnd(write->ref),
InsertEnd(write->value));
} }
private: private:
...@@ -77,7 +175,7 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -77,7 +175,7 @@ class AnnotateTargetWrapper : public ExprMutator {
}; };
Expr AnnotateTarget(const Expr& expr, const std::string& target) { Expr AnnotateTarget(const Expr& expr, const std::string& target) {
return AnnotateTargetWrapper(target).Mutate(expr); return AnnotateTargetWrapper(target).Annotate(expr);
} }
} // namespace annotate_target } // namespace annotate_target
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/relay/transforms/merge_compiler_regions.cc
*
* \brief After operators have been annotated with the targets that support
* them, this pass creates regions of the operators for each target. It
* is guaranteed that the regions will have a topological ordering so that
* no data dependency issues exist.
*
* This pass only introduces annotations to indicate the regions.
* partition_graph must subsequently be called to lift these regions out
* as external functions.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../analysis/annotated_region_set.h"
namespace tvm {
namespace relay {
namespace partitioning {
// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*! \brief This is a pre-requisite pass to merge-supported pass.
* The AnnotateRestDefault pass will put "default" Compiler Annotations to
* nodes that are not annotated already. This is there to ensure that the
* user will not leave un-annotated nodes MergeCompilerRegions pass is run.
* Why? Because, MergeCompilerRegions pass assumes every node to be annotated.
*/
class AnnotateRestDefault : public ExprMutator {
public:
explicit AnnotateRestDefault(const Expr& expr) {
regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
}
Expr Annotate(const Expr& expr) {
// Its a function that is being passed on to annotate
func_ = Downcast<Function>(expr);
// Corner Case CC1 : If the last node does not belong
// to a region nede to add a compiler_end
auto region = regions_->GetRegion(func_->body);
auto mutated_expr = this->VisitExpr(expr);
if (!region.defined()) {
func_ = Downcast<Function>(mutated_expr);
// CC1 : add that compiler end after mutation
auto body = AddCompilerEnd_(func_->body);
func_ = Function(func_->params, body,
body->checked_type_, {}, DictAttrs());
return Downcast<Expr>(func_);
}
return mutated_expr;
}
/*! \brief This function adds compiler ends to nodes that
* have a region AND they should not be arguments of the
* original function
* \param expr The expression to add a compiler end to.
* \return expr The expression with or without a compiler end added.
*/
Expr AddCompilerEnd(const Expr& expr) {
auto region = regions_->GetRegion(expr);
auto visited_expr = VisitExpr(expr);
// The compiler ends are added to nodes that does have a region
// AND they should not be arguments of the original function
if (!region.defined() &&
std::find(func_->params.begin(),
func_->params.end(), visited_expr)
== func_->params.end()) {
return AddCompilerEnd_(visited_expr);
}
return visited_expr;
}
Expr AddCompilerEnd_(const Expr& expr) {
const auto* end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(expr, target_);
return end;
}
Expr VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
auto ret = GetRef<Call>(call);
Array<Expr> args;
// Add compiler ends if the parent is supported
for (auto arg : call->args) {
args.push_back(AddCompilerEnd(arg));
}
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Skip annotatation ops, only add default compiler to actual compute nodes
auto region = regions_->GetRegion(ret);
if (!region.defined()) {
// if the current node does not belong to annotated region
// annotate the all incoming edges (args)
// with "default" compile_begin and compiler_end annotations.
tvm::Array<tvm::relay::Expr> compiler_begins;
for (auto arg : args) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(arg, target_);
compiler_begins.push_back(begin);
}
Expr update_call = Call(call->op, compiler_begins, call->attrs);
return update_call;
}
}
return Call(call->op, args, call->attrs);
};
Expr VisitExpr_(const TupleNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto tup = Downcast<Tuple>(new_e);
Array<Expr> new_fields;
for (auto field : tup->fields) {
new_fields.push_back(AddCompilerEnd(field));
}
return Tuple(new_fields);
}
Expr VisitExpr_(const TupleGetItemNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItem(AddCompilerEnd(get->tuple), get->index);
}
Expr VisitExpr_(const LetNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e);
return Let(
let->var,
AddCompilerEnd(let->value),
AddCompilerEnd(let->body));
}
Expr VisitExpr_(const IfNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto iff = Downcast<If>(new_e);
return If(
AddCompilerEnd(iff->cond),
AddCompilerEnd(iff->true_branch),
AddCompilerEnd(iff->false_branch));
}
Expr VisitExpr_(const RefCreateNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto create = Downcast<RefCreate>(new_e);
return RefCreate(AddCompilerEnd(create->value));
}
Expr VisitExpr_(const RefReadNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto read = Downcast<RefRead>(new_e);
return RefRead(AddCompilerEnd(read->ref));
}
Expr VisitExpr_(const RefWriteNode *op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto write = Downcast<RefWrite>(new_e);
return RefWrite(
AddCompilerEnd(write->ref),
AddCompilerEnd(write->value));
}
private:
AnnotatedRegionSet regions_;
const std::string target_ = "default";
Function func_;
};
class MergeAnnotations : public ExprMutator {
public:
explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
Expr VisitExpr_(const CallNode* call) final {
if (call->op == compiler_begin_op) {
if (call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]);
if (arg->op == compiler_end_op) {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
return ExprMutator::VisitExpr(arg->args[0]);
}
}
}
}
return ExprMutator::VisitExpr_(call);
}
private:
AnnotatedRegionSet regions_;
};
class RegionMerger : public ExprVisitor {
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call));
// set the region target
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler;
std::vector<AnnotatedRegion> mergeable_regions;
// first look at the region args to determine the parent regions
for (const auto& arg : region->GetInputs()) {
// all args should be begin annotations
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
// the arguments of the begin annotations will be in the parent regions
auto parent_region = regions_->GetRegion(begin->args[0]);
// if there is no parent region, move on
if (!parent_region.defined()) continue;
// merge the parent region if it hasn't been done already
if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
VisitExpr(begin->args[0]);
}
mergeable_regions.push_back(parent_region);
}
auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) {
// add all the parent restrictions to the current region
auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(),
parent_restrictions.end());
}
for (const auto& parent_region : mergeable_regions) {
bool merged = false;
// check the parent region has the same target
if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) {
// check the parent region isn't in the restrictions
if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) {
// merge the parent region into the current region
regions_->MergeRegions(parent_region, region);
// update the restrictions of all other regions to reflect the change in id
for (const auto& r : regions_) {
auto& restrictions = region_restrictions_[r->GetID()];
if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
restrictions.erase(parent_region->GetID());
restrictions.insert(region->GetID());
}
}
merged = true;
}
}
// if the parent wasn't merged, add it as a restriction to the current region
if (!merged)
region_restrictions.insert(parent_region->GetID());
}
merged_regions_.insert(region->GetID());
}
ExprVisitor::VisitExpr_(call);
}
private:
AnnotatedRegionSet regions_;
std::unordered_set<int> merged_regions_;
std::map<int, std::unordered_set<int>> region_restrictions_;
std::map<int, std::string> region_targets_;
};
Expr MergeCompilerRegions(const Expr& expr) {
// Annotate all the nodes that aren't annotated as 'default'.
AnnotateRestDefault anno_default(expr);
auto expr_all_annotated = anno_default.Annotate(expr);
// Create regions using the annotations.
AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr_all_annotated,
compiler_begin_op, compiler_end_op);
// By now, all the nodes have some sort of annotation.
// Region merger is an ExprVisitor that will update the
// AnnotatedRegionSet, merging all the regions that can be merged.
RegionMerger merger(regions);
merger.VisitExpr(expr_all_annotated);
// This updates the expression to remove annotations that are now
// 'internal' to a merged region.
MergeAnnotations merge_anno(regions);
return merge_anno.Mutate(expr_all_annotated);
}
} // namespace partitioning
namespace transform {
Pass MergeCompilerRegions() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::MergeCompilerRegions(f));
};
auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
return Sequential({partitioned, InferType()});
}
TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
.set_body_typed(transform::MergeCompilerRegions);
} // namespace transform
} // namespace relay
} // namespace tvm
...@@ -22,6 +22,7 @@ import pytest ...@@ -22,6 +22,7 @@ import pytest
import tvm import tvm
import tvm.relay.testing import tvm.relay.testing
import tvm.relay.op as reg
import tvm.relay.transform as transform import tvm.relay.transform as transform
from tvm import relay from tvm import relay
from tvm import runtime from tvm import runtime
...@@ -183,6 +184,41 @@ def test_extern_dnnl_mobilenet(): ...@@ -183,6 +184,41 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
@reg.register("nn.relu", "target.test")
def relu(attrs, args):
return True
def test_multiple_ends():
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
a_1 = relay.abs(r)
a_2 = relay.abs(r)
out = relay.add(a_1, a_2)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
def after():
x = relay.var("x", shape=(10, 10))
cb_1 = relay.annotation.compiler_begin(x, "test")
r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test")
a_1 = relay.abs(ce_1)
a_2 = relay.abs(ce_2)
out = relay.add(a_1, a_2)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert relay.analysis.alpha_equal(expected, result)
if __name__ == "__main__": if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl() test_extern_dnnl()
test_extern_dnnl_mobilenet() test_extern_dnnl_mobilenet()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge compiler regions."""
import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass
def test_diamond_graph_fanouts():
"""
This tests that the data dependencies present in a diamond-shaped
graph are correctly resolved by the merging pass.
O = supported by target
X = not supported by target
O O
/ \ / \
O X --> O + + X
\ / \ /
O O
Note that we can't just merge the three supported operators together,
otherwise both subgraphs would depend on the other.
"""
def diamond_graph_fanouts():
data = relay.var('data', shape=(10, 10))
cb_1 = compiler_begin(data, "test")
O_1 = relay.abs(cb_1)
ce_1 = compiler_end(O_1, "test")
ce_2 = compiler_end(O_1, "test")
cb_2 = compiler_begin(ce_1, "test")
O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(X, "test")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_4)
return diamond
def expected():
data = relay.var('data', shape=(10, 10))
cb_1 = compiler_begin(data, "test")
O_1 = relay.abs(cb_1)
ce_2 = compiler_end(O_1, "test")
O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test")
cb_x = compiler_begin(ce_2, "default")
X = relay.tanh(cb_x)
ce_x1 = compiler_end(X, "default")
ce_x2 = compiler_end(X, "default")
cb_3 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(ce_x1, "test")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
func = relay.Function([data], ce_4)
return func
result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
golden = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, golden)
def test_example_graph():
"""This tests the merging algorithm on the example used in the RFC.
See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
Blue nodes are adds, red nodes are subtracts.
"""
def annotated():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')
begin0 = compiler_begin(in_1, "test")
begin1 = compiler_begin(in_2, "test")
begin2 = compiler_begin(in_3, "test")
begin3 = compiler_begin(in_4, "test")
node0 = relay.add(begin0, begin1)
node1 = relay.add(begin2, begin3)
end0 = compiler_end(node0, "test")
end1 = compiler_end(node1, "test")
begin4 = compiler_begin(end0, "test")
begin5 = compiler_begin(end1, "test")
node2 = relay.add(begin4, begin5)
end2 = compiler_end(node2, "test")
node3 = relay.subtract(in_5, in_6)
node4 = relay.subtract(in_7, node3)
begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node4, "test")
node5 = relay.add(begin6, begin7)
end3 = compiler_end(node5, "test")
end4 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end3)
begin8 = compiler_begin(in_9, "test")
begin9 = compiler_begin(end4, "test")
node7 = relay.add(begin8, begin9)
end5 = compiler_end(node7, "test")
begin10 = compiler_begin(node6, "test")
begin11 = compiler_begin(end5, "test")
node8 = relay.add(begin10, begin11)
end6 = compiler_end(node8, "test")
begin12 = compiler_begin(in_10, "test")
begin13 = compiler_begin(end6, "test")
node9 = relay.add(begin12, begin13)
end7 = compiler_end(node9, "test")
f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7)
mod = tvm.IRModule.from_expr(f)
return mod
def expected():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')
begin0 = compiler_begin(in_1, "test")
begin1 = compiler_begin(in_2, "test")
begin2 = compiler_begin(in_3, "test")
begin3 = compiler_begin(in_4, "test")
node0 = relay.add(begin0, begin1)
node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1)
begin4 = compiler_begin(in_5, "default")
begin5 = compiler_begin(in_6, "default")
begin6 = compiler_begin(in_7, "default")
node3 = relay.subtract(begin4, begin5)
node4 = relay.subtract(begin6, node3)
end0 = compiler_end(node4, "default")
begin7 = compiler_begin(end0, "test")
begin8 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin7)
end1 = compiler_end(node5, "test")
begin9 = compiler_begin(end1, "default")
begin10 = compiler_begin(in_8, "default")
node6 = relay.subtract(begin10, begin9)
end2 = compiler_end(node6, "default")
node7 = relay.add(begin8, node5)
end3 = compiler_end(node7, "test")
begin11 = compiler_begin(end3, "test")
begin12 = compiler_begin(end2, "test")
node8 = relay.add(begin12, begin11)
begin13 = compiler_begin(in_10, "test")
node9 = relay.add(begin13, node8)
end4 = compiler_end(node9, "test")
f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4)
mod = tvm.IRModule.from_expr(f)
return mod
mod = annotated()
mod = relay.transform.MergeCompilerRegions()(mod)
ref_mod = expected()
assert relay.analysis.alpha_equal(mod, ref_mod)
if __name__ == "__main__":
test_diamond_graph_fanouts()
test_example_graph()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment