Unverified Commit b5ec0711 by mbaret Committed by GitHub

[RELAY] Added a AnnotatedRegion utility class (#5030)

* [RELAY] Added an AnnotatedRegionSet utility class

In many of the passes involved in graph partitioning,
we need to extract and manipulate annotated regions.
This class simplifies the extraction of regions from a relay
expression containing region begin and end annotations
as well as providing utility functions to query these
regions and merge them.

Co-authored-by: Ramana Radhakrishnan  <ramana.radhakrishnan@arm.com>

Change-Id: Ia912fea0b99f64b6a7197aa6da2347e58f469fbb

* Rename fix

* Update MakeRegions

* Fix __init__

* Indentation

* Code style

* Remove 'Region' from docs

* Overload [] to get region

* Use src/dest for MergeRegions

* Simplify merge

* Tidy const loop vars
parent 314f31b0
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
# Analysis passes # Analysis passes
from .analysis import * from .analysis import *
# Annotations
from .annotated_regions import AnnotatedRegionSet
# Call graph # Call graph
from . import call_graph from . import call_graph
from .call_graph import CallGraph from .call_graph import CallGraph
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Regions used in Relay."""
from tvm.runtime import Object
from . import _ffi_api
class AnnotatedRegionSet(Object):
"""Class to represent a relay expression split into regions."""
def __init__(self, expr, region_begin_op, region_end_op):
"""Construct regions from an expression.
Parameters
----------
expr : tvm.relay.Expr
The expression from which to construct the regions.
region_begin_op : tvm.relay.Op
The region begin annotation.
region_end_op : tvm.relay.Op
The region end annotation.
"""
self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet,
expr,
region_begin_op,
region_end_op)
def __len__(self):
return len(self.regions)
def get_region(self, expr):
"""Get the region an expression belongs to.
Parameters
----------
expr : tvm.relay.Expr
The expression.
Returns
-------
region
The region containing the expression.
None if not found.
"""
return _ffi_api.GetRegion(self, expr)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "annotated_region_set.h"
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <algorithm>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) {
return candidate;
}
}
return AnnotatedRegion(nullptr);
}
void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
AnnotatedRegion dest) {
if (dest == src) {
return;
}
// Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end());
for (const auto& input : src->ins) {
dest->ins.push_back(input);
}
for (const auto& output : src->outs) {
dest->outs.push_back(output);
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input);
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
if (it != src->outs.end()) {
dest->outs.remove(*it);
dest->ins.remove(input);
}
}
regions_.erase(src);
}
void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
auto region2 = GetRegion(expr);
if (region2.defined()) {
MergeRegions(region, region2);
} else {
region->nodes.insert(expr);
}
}
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++;
return *ret.first;
}
class AnnotatedRegionSet::Creator : public ExprVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op) :
begin_op_(region_begin_op), end_op_(region_end_op) {}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Propagate region to arguments
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (region.defined()) {
for (auto arg : call->args) {
region_set_->AddToRegion(region, arg);
}
}
} else if (call->op == begin_op_) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (!region.defined()) {
throw Error(ErrorBuilder()
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
region->ins.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
region = region_set_->MakeRegion();
region->nodes.insert(call->args[0]);
}
region->nodes.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}
void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
for (auto field : op->fields) {
region_set_->AddToRegion(region, field);
}
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const TupleGetItemNode* g) {
auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
if (region.defined()) {
region_set_->AddToRegion(region, g->tuple);
}
ExprVisitor::VisitExpr_(g);
}
void VisitExpr_(const FunctionNode* op) {
auto region = region_set_->GetRegion(GetRef<Function>(op));
if (region.defined()) {
for (auto param : op->params) {
region_set_->AddToRegion(region, param);
}
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const LetNode* op) {
auto region = region_set_->GetRegion(GetRef<Let>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->var);
region_set_->AddToRegion(region, op->value);
region_set_->AddToRegion(region, op->body);
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const IfNode* op) {
auto region = region_set_->GetRegion(GetRef<If>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->cond);
region_set_->AddToRegion(region, op->true_branch);
region_set_->AddToRegion(region, op->false_branch);
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const RefCreateNode* op) {
auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->value);
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const RefReadNode* op) {
auto region = region_set_->GetRegion(GetRef<RefRead>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const RefWriteNode* op) {
auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
ExprVisitor::VisitExpr_(op);
}
private:
/*! \brief The region set being constructed.*/
AnnotatedRegionSet region_set_;
/*! \brief Region 'begin' annotation operator. */
const Op begin_op_;
/*! \brief Region 'end' annotation operator. */
const Op end_op_;
};
AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) {
return Creator(begin, end).Create(expr);
}
TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode);
TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode);
TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet")
.set_body_typed([](Expr expr, Op begin, Op end) {
return AnnotatedRegionSet::Create(expr, begin, end);
});
TVM_REGISTER_GLOBAL("relay.analysis.GetRegion")
.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) {
return region_set->GetRegion(expr);
});
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/pass/annotated_region_set.h
* \brief Define data structures to extract and manipulate regions from
* a relay function. Regions are denoted by region_begin and region_end
* annotations that exist on all the input and output edges of the region.
*/
#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <list>
namespace tvm {
namespace relay {
class AnnotatedRegion;
class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id);
Array<Expr> nodes_array(nodes.begin(), nodes.end());
v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end());
v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end());
v->Visit("rets", &rets_array);
}
/*! \brief Get the region ID. */
int GetID() const {
return id;
}
/*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const {
return ins;
}
/*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const {
return outs;
}
/*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes;
}
static constexpr const char* _type_key = "relay.AnnotatedRegion";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
protected:
/*! \brief The region ID. */
int id{-1};
/*! \brief The inputs to this region. */
std::list<Expr> ins;
/*! \brief The outputs of this region */
std::list<Expr> outs;
/*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode;
};
/*!
* \brief An object to hold the properties of a region as used by the
* AnnotatedRegionSet class. This should be considered read-only.
*/
class AnnotatedRegion : public ObjectRef {
public:
AnnotatedRegion() {
auto n = make_object<AnnotatedRegionNode>();
data_ = std::move(n);
}
/*!
* \brief Construct from an object pointer.
* \param n The object pointer.
*/
explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return Mutable pointers to the node. */
AnnotatedRegionNode* operator->() const {
auto* ptr = get_mutable();
CHECK(ptr != nullptr);
return static_cast<AnnotatedRegionNode*>(ptr);
}
};
class AnnotatedRegionSetNode : public Object {
using UnorderedRegionSet =
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
// Create iterator alias for a RegionSet object.
using iterator = UnorderedRegionSet::iterator;
using const_iterator = UnorderedRegionSet::const_iterator;
public:
/*! \brief Default constructor. */
AnnotatedRegionSetNode() = default;
/*! \return The begin iterator */
iterator begin() {
return regions_.begin();
}
/*! \return The end iterator */
iterator end() {
return regions_.end();
}
/*! \return The const begin iterator */
const_iterator begin() const {
return regions_.begin();
}
/*! \return The const end iterator */
const_iterator end() const {
return regions_.end();
}
/*!
* \brief Get the region that an expression belongs to.
*
* \param expr Which expr to get the region for.
*
* \return A pointer to the region, nullptr if the expression
* doesn't belong to a region.
*/
AnnotatedRegion GetRegion(const Expr& expr) const;
/*!
* \brief Merge src region into dest region.
*
* \param src The region to merge - will be erased.
* \param dest The region into which src will be merged.
*/
void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest);
void VisitAttrs(AttrVisitor* v) {
Array<AnnotatedRegion> regions_array(regions_.begin(), regions_.end());
v->Visit("regions", &regions_array);
}
static constexpr const char* _type_key = "relay.AnnotatedRegionSet";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object);
private:
/*!
* \brief Add an expression to a region.
*
* \param region The region to add the expression to.
* \param expr The expression.
*/
void AddToRegion(AnnotatedRegion region, const Expr& expr);
/*!
* \brief Make a new region.
*
* \return The new region.
*/
AnnotatedRegion MakeRegion();
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */
int region_id_{0};
friend class AnnotatedRegionSet;
};
/*!
* \brief A class to hold a set of regions produced from a relay expression
* that contains 'region_begin' and 'region_end' style annotations. The
* regions should be disjoint. The class provides both a method to construct
* the region set of a given relay expression as well as additional methods
* to update and query regions.
*/
class AnnotatedRegionSet : public ObjectRef {
using UnorderedRegionSet =
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
// Create iterator alias for a RegionSet object.
using iterator = UnorderedRegionSet::iterator;
using const_iterator = UnorderedRegionSet::const_iterator;
public:
AnnotatedRegionSet() {
auto n = make_object<AnnotatedRegionSetNode>();
data_ = std::move(n);
}
/*!
* \brief Construct from an object pointer.
*
* \param n The object pointer.
*/
explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The begin iterator. */
iterator begin() {
auto* n = operator->();
CHECK(n);
return n->begin();
}
/*! \return The end iterator. */
iterator end() {
auto* n = operator->();
CHECK(n);
return n->end();
}
/*! \return The begin iterator. */
const_iterator begin() const {
const auto* n = operator->();
CHECK(n);
return n->begin();
}
/*! \return The end iterator. */
const_iterator end() const {
const auto *n = operator->();
CHECK(n);
return n->end();
}
/*! \return mutable pointers to the node. */
AnnotatedRegionSetNode* operator->() const {
auto* ptr = get_mutable();
CHECK(ptr != nullptr);
return static_cast<AnnotatedRegionSetNode*>(ptr);
}
/*! \return The region an expression belongs to. */
AnnotatedRegion operator[](const Expr& expr) {
const auto *n = operator->();
CHECK(n);
return n->GetRegion(expr);
}
/*! \brief Create a RegionSet from a relay expression.
*
* \param expr The relay expr from which to construct the set.
* \param begin Region begin annotation operator.
* \param end Region end annotation operator.
*
* \return The created RegionSet for the expression.
*/
static AnnotatedRegionSet Create(const Expr& expr,
const Op& begin,
const Op& end);
private:
/*! \brief Helper class to construct a RegionSet from an expr.*/
class Creator;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
def check_region(region_set, args, nodes, rets):
region = region_set.get_region(args[0])
assert region
assert set(args) == set(region.args)
assert set(nodes) == set(region.nodes)
assert set(rets) == set(region.rets)
def test_region_set_creator_diamond():
data = relay.var('data', shape=(10, 10))
cb_1 = compiler_begin(data, 'test_target')
O_1 = relay.abs(cb_1)
ce_1 = compiler_end(O_1, 'test_target')
ce_2 = compiler_end(O_1, 'test_target')
cb_2 = compiler_begin(ce_1, 'test_target')
O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, 'test_target')
cb_d = compiler_begin(ce_2, "default")
X = relay.tanh(cb_d)
ce_d = compiler_end(X, 'default')
cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target')
diamond = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(diamond,
relay.op.get("annotation.compiler_begin"),
relay.op.get("annotation.compiler_end"))
assert len(region_set) == 4
check_region(
region_set,
[cb_1],
[cb_1, O_1, ce_1, ce_2],
[ce_1, ce_2],
)
check_region(
region_set,
[cb_2],
[cb_2, O_2, ce_3],
[ce_3],
)
check_region(
region_set,
[cb_d],
[cb_d, X, ce_d],
[ce_d],
)
check_region(
region_set,
[cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4],
[ce_4],
)
def test_region_set_creator_merged():
data = relay.var('data', shape=(10, 10))
cb_1 = compiler_begin(data, 'test_target')
O_1 = relay.abs(cb_1)
ce_2 = compiler_end(O_1, 'test_target')
O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, 'test_target')
cb_d = compiler_begin(ce_2, "default")
X = relay.tanh(cb_d)
ce_d = compiler_end(X, 'default')
cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target')
merged = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(merged,
relay.op.get("annotation.compiler_begin"),
relay.op.get("annotation.compiler_end"))
assert len(region_set) == 3
check_region(
region_set,
[cb_1],
[cb_1, O_1, O_2, ce_2, ce_3],
[ce_2, ce_3],
)
check_region(
region_set,
[cb_d],
[cb_d, X, ce_d],
[ce_d],
)
check_region(
region_set,
[cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4],
[ce_4],
)
if __name__ == "__main__":
test_region_set_creator_diamond()
test_region_set_creator_merged()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment