Unverified Commit eba50ad8 by Zhi Committed by GitHub

[Relay][pass] call graph for relay (#4922)

* call graph for relay

* CallGraphEntryNode->CallGraphEntry, __getitem__->print_var

* fix typos
parent 61bea507
......@@ -19,6 +19,7 @@
import os
from sys import setrecursionlimit
from ..api import register_func
from . import call_graph
from . import base
from . import ty
from . import expr
......@@ -141,3 +142,6 @@ Sequential = transform.Sequential
# Feature
Feature = feature.Feature
# CallGraph
CallGraph = call_graph.CallGraph
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""
from tvm.ir import IRModule
from .base import Object
from .expr import GlobalVar
from . import _analysis
class CallGraph(Object):
"""Class to represent a call graph."""
def __init__(self, module):
"""Construct a call graph.
Parameters
----------
module : tvm.ir.IRModule
The IR module used to create a call graph
Returns
-------
call_graph: CallGraph
A constructed call graph.
"""
self.__init_handle_by_constructor__(_analysis.CallGraph, module)
@property
def module(self):
"""Return the contained Relay IR module.
Parameters
----------
None
Returns
-------
ret : tvm.ir.IRModule
The contained IRModule
"""
return _analysis.GetModule(self)
def ref_count(self, var):
"""Return the number of references to the global var
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : int
The number reference to the global var
"""
var = self._get_global_var(var)
return _analysis.GetRefCountGlobalVar(self, var)
def global_call_count(self, var):
"""Return the number of global function calls from a given global var.
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : int
The number of global function calls from the given var.
"""
var = self._get_global_var(var)
return _analysis.GetGlobalVarCallCount(self, var)
def is_recursive(self, var):
"""Return if the function corresponding to a var is a recursive
function.
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : Boolean
If the function corresponding to var is recurisve.
"""
var = self._get_global_var(var)
return _analysis.IsRecursive(self, var)
def _get_global_var(self, var):
"""Return the global var using a given name or GlobalVar.
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : tvm.relay.GlobalVar
The global var.
"""
if isinstance(var, str):
mod = self.module
var = mod.get_global_var(var)
if isinstance(var, GlobalVar):
return var
else:
raise TypeError("var should be either a string or GlobalVar")
def print_var(self, var):
"""Print a call graph of a global function by name or by variable.
Parameters
----------
var: Union[String, tvm.relay.GlobalVar]
The name or global variable.
Returns
-------
ret : String
The call graph represented in string.
"""
var = self._get_global_var(var)
return _analysis.PrintCallGraphGlobalVar(self, var)
def __str__(self):
"""Print the call graph in the topological order."""
return _analysis.PrintCallGraph(self)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/pass/call_graph.cc
* \brief Implementation of APIs to handle the call graph of a Relay module.
*/
#include "call_graph.h"
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/object.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <unordered_set>
#include <vector>
namespace tvm {
namespace relay {
CallGraph::CallGraph(IRModule module) {
auto n = make_object<CallGraphNode>();
n->module = std::move(module);
auto gvar_funcs = n->module->functions;
for (const auto& it : gvar_funcs) {
if (const auto* fn = it.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
// Add the global function to gradually build up the call graph.
n->AddToCallGraph(it.first, func);
}
}
data_ = std::move(n);
}
void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
CHECK(func.defined() && gv.defined());
// Add the current global function as an entry to the call grpah.
CallGraphEntry* cg_node = LookupGlobalVar(gv);
// Only GlobalVar nodes need to be handled in a function. It indicates that
// the global function of a callee is called by the function that is being
// processed. An edge will be added from the current global function, cg_node,
// to the node that contains the found callee GlobalVarNode.
//
// This is the major overhead for constructing a call graph because the
// post-order visitor will visit each AST node of the current function to
// figure out the dependencies between functions.
PostOrderVisit(func, [&](const Expr& expr) {
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
auto callee = GetRef<GlobalVar>(gvn);
cg_node->AddCalledGlobal(LookupGlobalVar(callee));
}
});
}
const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const {
const_iterator cit = call_graph_.find(gv);
CHECK(cit != call_graph_.end())
<< "GlobalVar " << gv->name_hint << " not found in the call graph!";
return cit->second.get();
}
CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
const_iterator cit = call_graph_.find(gv);
CHECK(cit != call_graph_.end())
<< "GlobalVar " << gv->name_hint << " not found in the call graph!";
return cit->second.get();
}
// Query the existence of a GlobalVar in the call graph. It creates an entry if
// there is no such node available.
CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) {
CHECK(gv.defined());
// This inserts an element to the call graph if it is not there yet.
auto& call_graph_node = call_graph_[gv];
if (call_graph_node) return call_graph_node.get();
CHECK(module->ContainGlobalVar(gv->name_hint))
<< "GlobalVar " << gv->name_hint << " not found in the current ir module";
// Create the node for the inserted entry.
call_graph_node = std::unique_ptr<CallGraphEntry>(new CallGraphEntry(gv));
return call_graph_node.get();
}
void CallGraphNode::Print(std::ostream& os) const {
// Print the call graph in the topological order.
std::vector<CallGraphEntry*> nodes = TopologicalOrder();
for (const auto* cgn : nodes) {
cgn->Print(os);
}
}
GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node,
bool update_call_graph) {
CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1))
<< "Cannot remove global var " << cg_node->GetNameHint()
<< " from call graph, because it still calls "
<< cg_node->size() << " other global functions";
if (update_call_graph) {
// Update the call graph by removing all edges that point to the node
// `cg_node`.
for (auto& it : *this) {
it.second->RemoveAllCallTo(cg_node);
}
}
GlobalVar gv = cg_node->GetGlobalVar();
call_graph_.erase(gv);
// Update the IR module.
module->Remove(gv);
return gv;
}
std::vector<CallGraphEntry*> CallGraphNode::GetEntryGlobals() const {
std::vector<CallGraphEntry*> ret;
// An entry function in Relay is a function that never called by other
// functions or only called by itself.
for (const auto& it : *this) {
if (it.second->GetRefCount() == 0 || it.second->IsRecursiveEntry()) {
ret.push_back(it.second.get());
}
}
return ret;
}
std::vector<CallGraphEntry*> CallGraphNode::TopologicalOrder() const {
std::vector<CallGraphEntry*> ret;
// Collect all entry nodes.
std::vector<CallGraphEntry*> entries = GetEntryGlobals();
CallGraphEntry::CallGraphEntrySet visited;
for (const auto& it : entries) {
// Keep tracking the nodes that have been visited.
auto topo = it->TopologicalOrder(&visited);
// Prepend the collected items. The intermediate nodes that are shared by
// multiple entries are guaranteed to be collected when visiting the
// previous entries. Therefore, topological order remains.
ret.insert(ret.begin(), topo.begin(), topo.end());
}
// Find out the missing global functions if there are any to help debugging.
if (ret.size() != module->functions.size()) {
for (auto it : module->functions) {
if (visited.find((*this)[it.first]) == visited.end()) {
LOG(WARNING) << "Missing global:" << it.first->name_hint
<< " with # refs = " << (*this)[it.first]->GetRefCount();
}
}
LOG(FATAL) << "Expected " << module->functions.size()
<< " globals, but received "
<< ret.size();
}
return ret;
}
// BSF traversal is used to collect the nodes in a CallGraphEntry. The nodes
// that are visited by previous CallGraphEntry entries can be memoized. This
// helps us to make sure no entry will be visited multiple times when collecting
// the nodes for an entire call graph.
std::vector<CallGraphEntry*> CallGraphEntry::TopologicalOrder(
CallGraphEntrySet* visited) const {
std::vector<CallGraphEntry*> ret;
std::vector<CallGraphEntry*> current_nodes;
if (visited->find(this) == visited->end()) {
visited->emplace(this);
current_nodes.emplace_back(const_cast<CallGraphEntry*>(this));
}
std::vector<CallGraphEntry*> next_nodes;
while (!current_nodes.empty()) {
for (const auto& node : current_nodes) {
ret.push_back(node);
// Iterate through the called entries.
for (auto git = node->begin(); git != node->end(); ++git) {
if (visited->find(git->second) == visited->end()) {
next_nodes.push_back(git->second);
visited->emplace(git->second);
}
}
}
// Update the current level and clean the next level.
current_nodes = next_nodes;
next_nodes.clear();
}
return ret;
}
void CallGraphEntry::CleanCallGraphEntries() {
while (!called_globals_.empty()) {
// Decrement the reference counter
called_globals_.back().second->DecRef();
called_globals_.pop_back();
}
}
inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) {
called_globals_.emplace_back(global_, cg_node);
// Increment the reference to indicate that another call site is found for
// the callee in `cg_node`.
cg_node->IncRef();
// Mark the global function as recursive if it calls itself.
if (global_ == cg_node->GetGlobalVar()) {
cg_node->is_recursive_ = true;
}
}
// Remove an edge from the current global function to the callee.
void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) {
for (auto it = begin();; ++it) {
CHECK(it != end()) << "Cannot find global function "
<< callee->name_hint << " to remove!";
if (it->second->GetGlobalVar() == callee) {
// Only remove one occurrence of the call site.
it->second->DecRef();
*it = called_globals_.back();
called_globals_.pop_back();
return;
}
}
}
// Remove all edges from the current global function to the callee.
void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) {
for (uint32_t i = 0, e = size(); i != e;) {
if (called_globals_[i].second == callee) {
callee->DecRef();
called_globals_[i] = called_globals_.back();
called_globals_.pop_back();
--e;
} else {
++i;
}
}
// Make sure all references to the callee are removed.
CHECK_EQ(callee->GetRefCount(), 0U)
<< "All references to " << callee->GetNameHint()
<< " should have been removed";
}
void CallGraphEntry::Print(std::ostream& os) const {
if (!global_.defined()) {
os << "GlobalVar is not defined\n";
return;
}
os << "Call graph node: " << global_->name_hint;
os << " at: " << this << ", #refs = " << GetRefCount() << "\n";
for (const auto& it : *this) {
os << " call site: <" << it.first->name_hint << "> calls ";
os << it.second->GetNameHint() << "\n";
}
os << "\n";
}
std::ostream& operator<<(std::ostream& os, const CallGraph& cg) {
cg->Print(os);
return os;
}
std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) {
cgn.Print(os);
return os;
}
TVM_REGISTER_NODE_TYPE(CallGraphNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallGraphNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const CallGraphNode*>(ref.get());
CHECK(node);
p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
});
TVM_REGISTER_GLOBAL("relay._analysis.CallGraph")
.set_body_typed([](IRModule module) {
return CallGraph(module);
});
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph")
.set_body_typed([](CallGraph call_graph) {
std::stringstream ss;
ss << call_graph;
return ss.str();
});
TVM_REGISTER_GLOBAL("relay._analysis.GetModule")
.set_body_typed([](CallGraph call_graph) {
return call_graph->GetModule();
});
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
std::stringstream ss;
ss << *entry_node;
return ss.str();
});
TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
return static_cast<int>(entry_node->GetRefCount());
});
TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
return static_cast<int>(entry_node->size());
});
TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
return entry_node->IsRecursive();
});
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/pass/call_graph.h
* \brief Define data structures for the call graph of a IRModule. It borrows
* the idea how LLVM constructs CallGraph.
*
* https://llvm.org/doxygen/CallGraph_8h_source.html
*/
#ifndef TVM_RELAY_PASS_CALL_GRAPH_H_
#define TVM_RELAY_PASS_CALL_GRAPH_H_
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tvm {
namespace relay {
class CallGraphEntry;
class CallGraph;
class CallGraphNode : public Object {
using CallGraphMap =
std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash,
ObjectEqual>;
// Create iterator alias for a CallGraphNode object.
using iterator = CallGraphMap::iterator;
using const_iterator = CallGraphMap::const_iterator;
public:
/*! \brief The IR module for creating a CallGraphNode. */
IRModule module;
/*! \brief Default constructor. */
CallGraphNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("module", &module);
}
/*!
* \brief Print the call graph.
*
* \param os The stream for printing.
*/
void Print(std::ostream& os) const;
/*! \return The begin iterator. */
iterator begin() {
return call_graph_.begin();
}
/*! \return The end iterator. */
iterator end() {
return call_graph_.end();
}
/*! \return The begin iterator. */
const_iterator begin() const {
return call_graph_.begin();
}
/*! \return The end iterator. */
const_iterator end() const {
return call_graph_.end();
}
/*!
* \brief Get an element from the CallGraphNode using a GlobalVar.
*
* \param gv The GlobalVar used for indexing.
*
* \return The fetched element.
*/
const CallGraphEntry* operator[](const GlobalVar& gv) const;
/*!
* \brief Get an element from the CallGraphNode using a GlobalVar.
*
* \param gv The GlobalVar used for indexing.
*
* \return The fetched element.
*/
CallGraphEntry* operator[](const GlobalVar& gv);
/*!
* \brief Get an element from the CallGraphNode using the global function name.
*
* \param gvar_name The global function name used for indexing.
*
* \return The fetched element.
*/
const CallGraphEntry* operator[](const std::string& gvar_name) const {
return (*this)[module->GetGlobalVar(gvar_name)];
}
/*!
* \brief Get an element from the CallGraphNode using the global function name.
*
* \param gvar_name The global function name used for indexing.
*
* \return The fetched element.
*/
CallGraphEntry* operator[](const std::string& gvar_name) {
return (*this)[module->GetGlobalVar(gvar_name)];
}
/*! \brief Return the IR module. */
IRModule GetModule() const {
return module;
}
/*!
* \brief Get the entries/root nodes of CallGraphNode.
*
* Entry functions are never referenced by other functions.
* Note these functions can be recursive as well.
*
* \return The list of CallGraphEntry that represent entry nodes.
*/
std::vector<CallGraphEntry*> GetEntryGlobals() const;
/*!
* \brief Remove a GlobalVar in a given CallGraphEntry from the current
* IR module.
*
* \param cg_node The CallGraphEntry that contains a global function to be
* removed.
* \param update_call_graph Indicate if we will update the CallGraph as well
* since updating is costly. We are only able to remove a leaf function
* when update_call_graph is disabled because the edges pointing to
* functions being removed are not updated.
*
* \return The GlobalVar removed from the current module.
*/
GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node,
bool update_call_graph = false);
/*!
* \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for
* the GlobalVar if it doesn't exist.
*
* \param gv The GlobalVar for query.
*
* \return The queried entry.
*/
CallGraphEntry* LookupGlobalVar(const GlobalVar& gv);
/*!
* \brief Get the entries from the CallGraphNode in the topological order.
*
* This is useful for various module-level optimizations/analysis. For example,
* inlining requires the correct order of the functions being processed, i.e.
* callee should be always handled before callers.
*
* \return The list of collected entries that are sorted in the topological order.
*/
std::vector<CallGraphEntry*> TopologicalOrder() const;
static constexpr const char* _type_key = "relay.CallGraph";
TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object);
private:
/*!
* \brief Create a CallGraphEntry for a global function and add it to the
* CallGraphNode.
*
* \param gv The global var.
* \param func The global function corresponding to `gv`.
*/
void AddToCallGraph(const GlobalVar& gv, const Function& func);
/*! \brief A record contains GlobalVar to CallGraphEntry mapping. */
CallGraphMap call_graph_;
friend CallGraph;
};
/*!
* \brief The class that represents the call graph of a Relay IR module. It also
* provides a variety of utility functions for users to query, view, and update
* a call graph.
*/
class CallGraph : public ObjectRef {
using CallGraphMap =
std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash,
ObjectEqual>;
// Create iterator alias for a CallGraph object.
using iterator = CallGraphMap::iterator;
using const_iterator = CallGraphMap::const_iterator;
public:
/*!
* \brief Construct a CallGraph from a IR module.
*
* \param module The IR module
*/
explicit CallGraph(IRModule module);
/*!
* \brief Construct from an object pointer.
* \param n The object pointer.
*/
explicit CallGraph(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The begin iterator. */
iterator begin() {
auto* n = operator->();
CHECK(n);
return n->begin();
}
/*! \return The end iterator. */
iterator end() {
auto* n = operator->();
CHECK(n);
return n->end();
}
/*! \return The begin iterator. */
const_iterator begin() const {
const auto* n = operator->();
CHECK(n);
return n->begin();
}
/*! \return The end iterator. */
const_iterator end() const {
const auto* n = operator->();
CHECK(n);
return n->end();
}
/*!
* \brief Get an element from the CallGraph using a GlobalVar.
*
* \param gv The GlobalVar used for indexing.
*
* \return The fetched element.
*/
const CallGraphEntry* operator[](const GlobalVar& gv) const {
const auto* n = operator->();
CHECK(n);
return (*n)[gv];
}
/*!
* \brief Get an element from the CallGraph using a GlobalVar.
*
* \param gv The GlobalVar used for indexing.
*
* \return The fetched element.
*/
CallGraphEntry* operator[](const GlobalVar& gv) {
auto* n = operator->();
CHECK(n);
return (*n)[gv];
}
/*!
* \brief Get an element from the CallGraph using the global function name.
*
* \param gvar_name The global function name used for indexing.
*
* \return The fetched element.
*/
const CallGraphEntry* operator[](const std::string& gvar_name) const {
const auto* n = operator->();
CHECK(n);
return (*n)[gvar_name];
}
/*!
* \brief Get an element from the CallGraph using the global function name.
*
* \param gvar_name The global function name used for indexing.
*
* \return The fetched element.
*/
CallGraphEntry* operator[](const std::string& gvar_name) {
auto* n = operator->();
CHECK(n);
return (*n)[gvar_name];
}
/*! \return mutable pointers to the node. */
CallGraphNode* operator->() const {
auto* ptr = get_mutable();
CHECK(ptr != nullptr);
return static_cast<CallGraphNode*>(ptr);
}
private:
/*! \brief Overload the << operator to print a call graph. */
friend std::ostream& operator<<(std::ostream& os, const CallGraph&);
};
/*!
* \brief A node in the call graph. It maintains the edges from a caller to
* all callees.
*/
class CallGraphEntry {
public:
using CallGraphEntryPair = std::pair<GlobalVar, CallGraphEntry*>;
using CallGraphEntryVector = std::vector<CallGraphEntryPair>;
using CallGraphEntrySet = std::unordered_set<const CallGraphEntry*>;
// Create iterator alias for a CallGraphEntry object.
using iterator = std::vector<CallGraphEntryPair>::iterator;
using const_iterator = std::vector<CallGraphEntryPair>::const_iterator;
/*!
* \brief Construct from a GlobalVar.
*
* \param gv The GlobalVar to create a CallGraphEntry.
*/
explicit CallGraphEntry(const GlobalVar& gv) : global_(gv) {}
/*!
* \brief Delete copy constructor.
*/
CallGraphEntry(const CallGraphEntry&) = delete;
/*! \brief Delete assignment. */
CallGraphEntry& operator=(const CallGraphEntry&) = delete;
/*! \return The begin iterator */
iterator begin() {
return called_globals_.begin();
}
/*! \return The end iterator */
iterator end() {
return called_globals_.end();
}
/*! \return The const begin iterator */
const_iterator begin() const {
return called_globals_.begin();
}
/*! \return The const end iterator */
const_iterator end() const {
return called_globals_.end();
}
/*!
* \brief Return if the list of called nodes is empty.
*
* \return true if the list is empty. Otherwise, false.
*/
bool empty() const {
return called_globals_.empty();
}
/*!
* \brief Return the size of the list that represents the nodes are called by
* the current node.
*
* \return The number of called nodes.
*/
uint32_t size() const {
return static_cast<uint32_t>(called_globals_.size());
}
/*!
* \brief Fetch the i-th CallGraphEntry from the list of nodes that are called
* by the current function.
*
* \param i The index.
*
* \return The fetched CallGraphEntry.
*/
CallGraphEntry* operator[](size_t i) const {
CHECK_LT(i, called_globals_.size()) << "Invalid Index";
return called_globals_[i].second;
}
/*!
* \brief Print the call graph that is stemmed from the current CallGraphEntry.
*
* \param os The stream for printing.
*/
void Print(std::ostream& os) const;
/*!
* \brief Return the number of times the global function is referenced.
*
* \return The count.
*/
uint32_t GetRefCount() const {
return ref_cnt_;
}
/*!
* \brief Return the GlobalVar stored in the current CallGraphEntry.
*
* \return The GlobalVar.
*/
GlobalVar GetGlobalVar() const {
return global_;
}
/*!
* \brief Return the name hint of the GlobalVar stored in the CallGraphEntry.
*
* \return The name hint of the global function.
*/
std::string GetNameHint() const {
return global_->name_hint;
}
/*!
* \brief Return if the global function corresponding to the current
* CallGraphEntry is a recursive function.
*
* \return true if it is recursive. Otherwise, false.
*/
bool IsRecursive() const {
return is_recursive_;
}
/*!
* \brief Return if the global function corresponding to the current
* CallGraphEntry is both a recursive function and an entry function. This type
* of function only has one reference which is called by itself.
*
* \return true if it is both a recursive function and an entry. Otherwise, false.
*/
bool IsRecursiveEntry() const {
return GetRefCount() == 1 && IsRecursive();
}
/*!
* \brief Return the topological order of the CallGraphEntry.
*
* \param visited A set of CallGraphEntry objects that have been visited.
*
* \return The list of CallGraphEntry that is represented in topological order.
*/
std::vector<CallGraphEntry*> TopologicalOrder(
CallGraphEntrySet* visited = new CallGraphEntrySet()) const;
/*!
* \brief Remove all edges from the current CallGraphEntry to any global
* function it calls.
*/
void CleanCallGraphEntries();
/*!
* \brief Add a node to the list of nodes that are being called by the current
* global function.
*
* \param cg_node The CallGraphEntry that will be added to the call list.
*/
void AddCalledGlobal(CallGraphEntry* cg_node);
/*!
* \brief Remove a call edge to the global function from the current
* function.
*
* \param callee The function that is being called.
*/
void RemoveCallTo(const GlobalVar& callee);
/*!
* \brief Remove all the edges that represent that calls to the global function
* stored in a given CallGraphEntry.
*
* \param callee The function that is being called.
*/
void RemoveAllCallTo(CallGraphEntry* callee);
private:
/*! \brief Decrement the reference counter by 1. */
void DecRef() {
CHECK_GT(ref_cnt_, 0);
--ref_cnt_;
}
/*! \brief Increment the reference counter by 1. */
void IncRef() { ++ref_cnt_; }
/*!
* \brief Mark if the global function stored in the CallGraphEntry is
* recursive function.
*/
bool is_recursive_{false};
/*! \brief Count the number of times the global function is referenced. */
uint32_t ref_cnt_{0};
/*! \brief The GlobalVar stored in the current CallGraphEntry. */
GlobalVar global_;
/*! \brief The list of entries called by the current CallGraphEntry. */
CallGraphEntryVector called_globals_;
friend class CallGraph;
/*! \brief Overload the << operator to print a call graph node. */
friend std::ostream& operator<<(std::ostream& os, const CallGraphEntry&);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_CALL_GRAPH_H_
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import pytest
import tvm
from tvm import relay
def test_callgraph_construct():
mod = tvm.IRModule({})
x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))
mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.CallGraph(mod)
assert "g1" in str(call_graph)
assert relay.alpha_equal(mod, call_graph.module)
def test_print_element():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
mod["g0"] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
mod["g1"] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
assert "#refs = 0" in str(call_graph.print_var("g0"))
assert "#refs = 0" in str(call_graph.print_var("g1"))
def test_global_call_count():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
g0 = relay.GlobalVar("g0")
mod[g0] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
assert call_graph.global_call_count(g0) == 0
assert call_graph.global_call_count(g1) == 1
assert call_graph.global_call_count("main") == 2
def test_ref_count():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
g0 = relay.GlobalVar("g0")
mod[g0] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
assert call_graph.ref_count(g0) == 1
assert call_graph.ref_count(g1) == 1
assert call_graph.ref_count("main") == 0
def test_nested_ref():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
g0 = relay.GlobalVar("g0")
mod[g0] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
assert call_graph.ref_count(g0) == 2
assert call_graph.ref_count(g1) == 1
assert call_graph.ref_count("main") == 0
def test_recursive_func():
mod = tvm.IRModule({})
x = relay.var('x', shape=[], dtype='int32')
fn0 = relay.Function([x], x)
gx = relay.GlobalVar("gx")
mod[gx] = fn0
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = relay.ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, dtype='int32'))
global_call = gx(i)
rec_call = relay.Call(sum_up, [one_less]) + global_call
sb.ret(relay.add(rec_call, i))
func = relay.Function([i],
sb.get(),
ret_type=relay.TensorType([], 'int32'))
func = func.set_attribute("Compiler", tvm.tir.StringImm("a"))
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
call_graph = relay.CallGraph(mod)
assert call_graph.is_recursive(sum_up)
assert call_graph.ref_count(sum_up) == 2
assert call_graph.ref_count(gx) == 1
assert call_graph.ref_count("main") == 0
if __name__ == "__main__":
pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment