Unverified Commit 02121383 by mbaret Committed by GitHub

[RELAY] Add MergeCompilerRegions pass (#5134)

* [RELAY] Add MergeCompilerRegions pass

This pass is part of the flow to support creating compiler
regions with multiple outputs. It should be called after
AnnotateTarget and merges regions that share the same target
into larger compiler regions that can be offloaded to external
codegens (see the usage sketch below).

This pass implements an algorithm which ensures that the merging
never introduces a cyclic data dependency between regions. See the
tests for an example of such a case.

Co-authored-by: Ramana Radhakrishnan  <ramana.radhakrishnan@arm.com>
Co-authored-by: Manupa Karunaratne    <manupa.karunaratne@arm.com>

Change-Id: Ibd99083564608d888482f57c5080109f3eefec88
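
As a rough usage sketch (not part of this patch; the "dnnl" target name and the
PartitionGraph step are assumptions about the wider BYOC flow), the new pass is
intended to sit between annotation and partitioning:

    from tvm import relay

    # Hypothetical pipeline; assumes an IRModule `mod` and an external
    # codegen registered under the name "dnnl".
    mod = relay.transform.AnnotateTarget("dnnl")(mod)    # insert compiler_begin/compiler_end
    mod = relay.transform.MergeCompilerRegions()(mod)    # merge regions with the same target
    mod = relay.transform.PartitionGraph()(mod)          # lift each region into an external function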

* [RELAY] Annotate compiler_ends on each edge

This alters the behaviour of the AnnotateTarget
pass to enforce the property that every compiler
annotation lies on a single data-flow edge.
Specifically, each annotation has exactly one
parent and one child (see the sketch below).

Change-Id: I0e74803a77767f4f377d17755a13a74a30909797
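
As a minimal sketch of the resulting structure (it mirrors the new
test_multiple_ends unit test added below), a supported call with two consumers
now gets one compiler_end per outgoing edge:

    from tvm import relay

    x = relay.var("x", shape=(10, 10))
    r = relay.nn.relu(relay.annotation.compiler_begin(x, "test"))
    # one compiler_end per consumer edge of `r`, rather than one shared annotation
    a_1 = relay.abs(relay.annotation.compiler_end(r, "test"))
    a_2 = relay.abs(relay.annotation.compiler_end(r, "test"))
    out = relay.add(a_1, a_2)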

* Fix comment

* Rebase *Node::make

* Moved block outside for loop

* Code style

* Update make API

* Remove comment

* Remove redundant 'else's

* Make one line

* Fix comment

* RefWrite

* Fix merge ordering

* Add the RFC example as a test

* [FIX] Fixed merging behaviour in AnnotatedRegionSet

Deleting items from a list while iterating over it invalidates the
iterator, which is undefined behaviour and sometimes segfaults. This
change collects the items to delete during the iteration and removes
them afterwards.
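
The same collect-then-remove pattern, sketched in Python purely for
illustration (the real fix is the C++ change to MergeRegions further down):

    # Hypothetical stand-in for dest->ins; not TVM code.
    inputs = ["a", "b", "c", "d"]
    to_remove = [x for x in inputs if x in ("b", "c")]  # first pass: collect
    for x in to_remove:                                 # second pass: remove
        inputs.remove(x)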

* Added checks

* Move comment

* Update comments
@@ -397,6 +397,17 @@ def MergeComposite(pattern_table):
     return _ffi_api.MergeComposite(pattern_names, patterns)
 
 
+def MergeCompilerRegions():
+    """Merge together compiler regions.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that merges compiler regions.
+    """
+    return _ffi_api.MergeCompilerRegions()
+
+
 def RewriteAnnotatedOps(fallback_device):
     """Rewrite the annotated program where annotation operators, e.g.
     `on_device`, mark which device an expression should be scheduled to.
...
@@ -55,14 +55,18 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
   }
   // if any of the outputs of src are inputs of dest, they become internal nodes
   // so remove them from outs
+  std::vector<Expr> ins_to_remove;
   for (const auto& input : dest->ins) {
     auto call = Downcast<Call>(input);
     auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
     if (it != src->outs.end()) {
       dest->outs.remove(*it);
-      dest->ins.remove(input);
+      ins_to_remove.push_back(input);
     }
   }
+  for (const auto& input : ins_to_remove) {
+    dest->ins.remove(input);
+  }
   regions_.erase(src);
 }
...
@@ -38,38 +38,136 @@ class AnnotateTargetWrapper : public ExprMutator {
  public:
   explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
 
+  Expr Annotate(const Expr& expr) {
+    return InsertEnd(Mutate(expr));
+  }
+
+  bool IsSupported(const Expr& expr) {
+    if (expr->IsInstance<CallNode>()) {
+      Call call = Downcast<Call>(expr);
+      auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
+      Op op = Downcast<Op>(call->op);
+      CHECK(op.defined());
+      if (fannotate.count(op)) {
+        return fannotate[op](call->attrs, call->args);
+      }
+    }
+    return false;
+  }
+
+  Expr InsertEnd(const Expr& arg) {
+    if (IsSupported(arg)) {
+      const auto *end_op =
+        runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+      CHECK(end_op);
+      Expr end = (*end_op)(arg, target_);
+      return end;
+    }
+    return arg;
+  }
+
   Expr VisitExpr_(const CallNode* cn) {
     // TODO(@zhiics, @comaniac) Handle composite functions.
     auto new_e = ExprMutator::VisitExpr_(cn);
 
     Call call = Downcast<Call>(new_e);
-    auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
-    Op op = Downcast<Op>(call->op);
-    CHECK(op.defined());
-
-    if (fannotate.count(op)) {
-      bool external = fannotate[op](call->attrs, call->args);
-      if (external) {
-        tvm::Array<tvm::relay::Expr> compiler_begins;
-        for (const auto& it : call->args) {
-          const auto* begin_op =
-            runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
-          CHECK(begin_op);
-          Expr begin = (*begin_op)(it, target_);
-          compiler_begins.push_back(begin);
-        }
-        Expr update_call = Call(call->op, compiler_begins, call->attrs);
-        const auto* end_op =
-          runtime::Registry::Get("relay.op.annotation._make.compiler_end");
-        CHECK(end_op);
-        Expr end = (*end_op)(update_call, target_);
-        return end;
-      }
-    } else {
-      LOG(WARNING) << op->name << " in " << target_
-                   << " is not registered. It will be executed on CPU.";
-    }
-    return new_e;
+
+    // add end annotations if the args are supported
+    Array<Expr> compiler_ends;
+    for (const auto& it : call->args) {
+      compiler_ends.push_back(InsertEnd(it));
+    }
+    call = Call(call->op, compiler_ends, call->attrs);
+
+    // add begin annotations if the call node is supported
+    if (IsSupported(call)) {
+      tvm::Array<tvm::relay::Expr> compiler_begins;
+      const auto* begin_op =
+        runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+      for (const auto& it : call->args) {
+        CHECK(begin_op);
+        Expr begin = (*begin_op)(it, target_);
+        compiler_begins.push_back(begin);
+      }
+      call = Call(call->op, compiler_begins, call->attrs);
+    }
+
+    return std::move(call);
+  }
+
+  Expr VisitExpr_(const TupleNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto tup = Downcast<Tuple>(new_e);
+    Array<Expr> new_fields;
+    for (auto field : tup->fields) {
+      new_fields.push_back(InsertEnd(field));
+    }
+    return Tuple(new_fields);
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto get = Downcast<TupleGetItem>(new_e);
+    return TupleGetItem(
+      InsertEnd(get->tuple),
+      get->index);
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto func = Downcast<Function>(new_e);
+    return Function(
+      func->params,
+      InsertEnd(func->body),
+      func->ret_type,
+      func->type_params,
+      func->attrs);
+  }
+
+  Expr VisitExpr_(const LetNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto let = Downcast<Let>(new_e);
+    return Let(
+      let->var,
+      InsertEnd(let->value),
+      InsertEnd(let->body));
+  }
+
+  Expr VisitExpr_(const IfNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto iff = Downcast<If>(new_e);
+    return If(
+      InsertEnd(iff->cond),
+      InsertEnd(iff->true_branch),
+      InsertEnd(iff->false_branch));
+  }
+
+  Expr VisitExpr_(const RefCreateNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto create = Downcast<RefCreate>(new_e);
+    return RefCreate(InsertEnd(create->value));
+  }
+
+  Expr VisitExpr_(const RefReadNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto read = Downcast<RefRead>(new_e);
+    return RefRead(InsertEnd(read->ref));
+  }
+
+  Expr VisitExpr_(const RefWriteNode* op) {
+    auto new_e = ExprMutator::VisitExpr_(op);
+
+    auto write = Downcast<RefWrite>(new_e);
+    return RefWrite(
+      InsertEnd(write->ref),
+      InsertEnd(write->value));
   }
 
  private:
@@ -77,7 +175,7 @@ class AnnotateTargetWrapper : public ExprMutator {
 };
 
 Expr AnnotateTarget(const Expr& expr, const std::string& target) {
-  return AnnotateTargetWrapper(target).Mutate(expr);
+  return AnnotateTargetWrapper(target).Annotate(expr);
 }
 
 }  // namespace annotate_target
...
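
For context, IsSupported above looks up the operator attribute "target.<target name>",
which backends register from Python. A minimal sketch of such a registration (the
function name is arbitrary; "test" is the toy target name also used by the unit
tests below):

    import tvm.relay.op as reg

    # Declare nn.relu as supported by the "test" external target; the registered
    # checker can inspect attrs/args and return True or False per call site.
    @reg.register("nn.relu", "target.test")
    def _relu_supported(attrs, args):
        return True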
@@ -22,6 +22,7 @@ import pytest
 import tvm
 import tvm.relay.testing
+import tvm.relay.op as reg
 import tvm.relay.transform as transform
 from tvm import relay
 from tvm import runtime
@@ -183,6 +184,41 @@ def test_extern_dnnl_mobilenet():
                        (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
 
 
+@reg.register("nn.relu", "target.test")
+def relu(attrs, args):
+    return True
+
+
+def test_multiple_ends():
+    def before():
+        x = relay.var("x", shape=(10, 10))
+        r = relay.nn.relu(x)
+        a_1 = relay.abs(r)
+        a_2 = relay.abs(r)
+        out = relay.add(a_1, a_2)
+        f = relay.Function([x], out)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    def after():
+        x = relay.var("x", shape=(10, 10))
+        cb_1 = relay.annotation.compiler_begin(x, "test")
+        r = relay.nn.relu(cb_1)
+        ce_1 = relay.annotation.compiler_end(r, "test")
+        ce_2 = relay.annotation.compiler_end(r, "test")
+        a_1 = relay.abs(ce_1)
+        a_2 = relay.abs(ce_2)
+        out = relay.add(a_1, a_2)
+        f = relay.Function([x], out)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    result = transform.AnnotateTarget("test")(before())
+    expected = transform.InferType()(after())
+    assert relay.analysis.alpha_equal(expected, result)
+
+
 if __name__ == "__main__":
+    test_multiple_ends()
     test_extern_dnnl()
     test_extern_dnnl_mobilenet()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge compiler regions."""
import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass


def test_diamond_graph_fanouts():
    """
    This tests that the data dependencies present in a diamond-shaped
    graph are correctly resolved by the merging pass.

    O = supported by target
    X = not supported by target

       O         O
      / \       / \
     O   X --> O + + X
      \ /       \ /
       O         O

    Note that we can't just merge the three supported operators together,
    otherwise both subgraphs would depend on the other.
    """
    def diamond_graph_fanouts():
        data = relay.var('data', shape=(10, 10))
        cb_1 = compiler_begin(data, "test")
        O_1 = relay.abs(cb_1)
        ce_1 = compiler_end(O_1, "test")
        ce_2 = compiler_end(O_1, "test")
        cb_2 = compiler_begin(ce_1, "test")
        O_2 = relay.nn.relu(cb_2)
        ce_3 = compiler_end(O_2, "test")

        X = relay.tanh(ce_2)

        cb_3 = compiler_begin(ce_3, "test")
        cb_4 = compiler_begin(X, "test")
        O_3 = relay.add(cb_3, cb_4)
        ce_4 = compiler_end(O_3, "test")

        diamond = relay.Function([data], ce_4)
        return diamond

    def expected():
        data = relay.var('data', shape=(10, 10))
        cb_1 = compiler_begin(data, "test")
        O_1 = relay.abs(cb_1)
        ce_2 = compiler_end(O_1, "test")
        O_2 = relay.nn.relu(O_1)
        ce_3 = compiler_end(O_2, "test")

        cb_x = compiler_begin(ce_2, "default")
        X = relay.tanh(cb_x)
        ce_x1 = compiler_end(X, "default")
        ce_x2 = compiler_end(X, "default")

        cb_3 = compiler_begin(ce_3, "test")
        cb_4 = compiler_begin(ce_x1, "test")
        O_3 = relay.add(cb_3, cb_4)
        ce_4 = compiler_end(O_3, "test")

        func = relay.Function([data], ce_4)
        return func

    result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
    golden = run_opt_pass(expected(), relay.transform.InferType())
    assert relay.analysis.alpha_equal(result, golden)


def test_example_graph():
    """This tests the merging algorithm on the example used in the RFC.
    See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
    Blue nodes are adds, red nodes are subtracts.
    """
    def annotated():
        in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
        in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
        in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
        in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
        in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
        in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
        in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
        in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
        in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
        in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')

        begin0 = compiler_begin(in_1, "test")
        begin1 = compiler_begin(in_2, "test")
        begin2 = compiler_begin(in_3, "test")
        begin3 = compiler_begin(in_4, "test")
        node0 = relay.add(begin0, begin1)
        node1 = relay.add(begin2, begin3)
        end0 = compiler_end(node0, "test")
        end1 = compiler_end(node1, "test")
        begin4 = compiler_begin(end0, "test")
        begin5 = compiler_begin(end1, "test")
        node2 = relay.add(begin4, begin5)
        end2 = compiler_end(node2, "test")

        node3 = relay.subtract(in_5, in_6)
        node4 = relay.subtract(in_7, node3)

        begin6 = compiler_begin(end2, "test")
        begin7 = compiler_begin(node4, "test")
        node5 = relay.add(begin6, begin7)
        end3 = compiler_end(node5, "test")
        end4 = compiler_end(node5, "test")

        node6 = relay.subtract(in_8, end3)
        begin8 = compiler_begin(in_9, "test")
        begin9 = compiler_begin(end4, "test")
        node7 = relay.add(begin8, begin9)
        end5 = compiler_end(node7, "test")

        begin10 = compiler_begin(node6, "test")
        begin11 = compiler_begin(end5, "test")
        node8 = relay.add(begin10, begin11)
        end6 = compiler_end(node8, "test")

        begin12 = compiler_begin(in_10, "test")
        begin13 = compiler_begin(end6, "test")
        node9 = relay.add(begin12, begin13)
        end7 = compiler_end(node9, "test")

        f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7)
        mod = tvm.IRModule.from_expr(f)
        return mod

    def expected():
        in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
        in_2 = relay.var('in_2', shape=(10, 10), dtype='float32')
        in_3 = relay.var('in_3', shape=(10, 10), dtype='float32')
        in_4 = relay.var('in_4', shape=(10, 10), dtype='float32')
        in_5 = relay.var('in_5', shape=(10, 10), dtype='float32')
        in_6 = relay.var('in_6', shape=(10, 10), dtype='float32')
        in_7 = relay.var('in_7', shape=(10, 10), dtype='float32')
        in_8 = relay.var('in_8', shape=(10, 10), dtype='float32')
        in_9 = relay.var('in_9', shape=(10, 10), dtype='float32')
        in_10 = relay.var('in_10', shape=(10, 10), dtype='float32')

        begin0 = compiler_begin(in_1, "test")
        begin1 = compiler_begin(in_2, "test")
        begin2 = compiler_begin(in_3, "test")
        begin3 = compiler_begin(in_4, "test")
        node0 = relay.add(begin0, begin1)
        node1 = relay.add(begin2, begin3)
        node2 = relay.add(node0, node1)

        begin4 = compiler_begin(in_5, "default")
        begin5 = compiler_begin(in_6, "default")
        begin6 = compiler_begin(in_7, "default")
        node3 = relay.subtract(begin4, begin5)
        node4 = relay.subtract(begin6, node3)
        end0 = compiler_end(node4, "default")
        begin7 = compiler_begin(end0, "test")
        begin8 = compiler_begin(in_9, "test")

        node5 = relay.add(node2, begin7)
        end1 = compiler_end(node5, "test")

        begin9 = compiler_begin(end1, "default")
        begin10 = compiler_begin(in_8, "default")
        node6 = relay.subtract(begin10, begin9)
        end2 = compiler_end(node6, "default")

        node7 = relay.add(begin8, node5)
        end3 = compiler_end(node7, "test")

        begin11 = compiler_begin(end3, "test")
        begin12 = compiler_begin(end2, "test")
        node8 = relay.add(begin12, begin11)

        begin13 = compiler_begin(in_10, "test")
        node9 = relay.add(begin13, node8)
        end4 = compiler_end(node9, "test")

        f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4)
        mod = tvm.IRModule.from_expr(f)
        return mod

    mod = annotated()
    mod = relay.transform.MergeCompilerRegions()(mod)
    ref_mod = expected()
    assert relay.analysis.alpha_equal(mod, ref_mod)


if __name__ == "__main__":
    test_diamond_graph_fanouts()
    test_example_graph()