/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/analysis/call_graph.cc
 * \brief Implementation of APIs to handle the call graph of a Relay module.
 */

#include "call_graph.h"

#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/object.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace relay {

// Build the call graph of `module`. Only module entries that are Relay
// `Function`s are traversed. Any other kind of BaseFunc only gets an entry when
// some traversed function references its GlobalVar (via LookupGlobalVar).
CallGraph::CallGraph(IRModule module) {
  auto n = make_object<CallGraphNode>();
  n->module = std::move(module);
  auto gvar_funcs = n->module->functions;
  for (const auto& it : gvar_funcs) {
    if (const auto* fn = it.second.as<FunctionNode>()) {
      auto func = GetRef<Function>(fn);
      // Add the global function to gradually build up the call graph.
      n->AddToCallGraph(it.first, func);
    }
  }
  data_ = std::move(n);
}

// Register `gv` in the call graph and add one edge per GlobalVar occurrence
// found inside `func`'s body.
void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
  CHECK(func.defined() && gv.defined());
  // Add the current global function as an entry to the call graph.
  CallGraphEntry* cg_node = LookupGlobalVar(gv);
  // Only GlobalVar nodes need to be handled in a function. It indicates that
  // the global function of a callee is called by the function that is being
  // processed. An edge will be added from the current global function, cg_node,
  // to the node that contains the found callee GlobalVarNode.
  //
  // This is the major overhead for constructing a call graph because the
  // post-order visitor will visit each AST node of the current function to
  // figure out the dependencies between functions.
  PostOrderVisit(func, [&](const Expr& expr) {
    if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
      auto callee = GetRef<GlobalVar>(gvn);
      cg_node->AddCalledGlobal(LookupGlobalVar(callee));
    }
  });
}

// Read-only access to the entry of `gv`; fails if `gv` has no entry.
const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const {
  const_iterator cit = call_graph_.find(gv);
  CHECK(cit != call_graph_.end())
      << "GlobalVar " << gv->name_hint << " not found in the call graph!";
  return cit->second.get();
}

// Mutable access to the entry of `gv`; fails if `gv` has no entry.
CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
  auto it = call_graph_.find(gv);
  CHECK(it != call_graph_.end())
      << "GlobalVar " << gv->name_hint << " not found in the call graph!";
  return it->second.get();
}

// Return the function bound to `var` in the underlying IR module.
BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const {
  CHECK(module->ContainGlobalVar(var->name_hint))
      << "GlobalVar " << var->name_hint << " not found in the current ir module";
  return module->Lookup(var);
}

// Query the existence of a GlobalVar in the call graph. It creates an entry if
// there is no such node available.
//
// The module-membership CHECK runs before anything is inserted. If the CHECK
// fails (and throws), the graph is therefore not left holding a null entry.
CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) {
  CHECK(gv.defined());
  auto it = call_graph_.find(gv);
  if (it != call_graph_.end()) return it->second.get();

  CHECK(module->ContainGlobalVar(gv->name_hint))
      << "GlobalVar " << gv->name_hint << " not found in the current ir module";
  // Create the node for the new entry only after validation has passed.
  auto inserted = call_graph_.emplace(gv, std::unique_ptr<CallGraphEntry>(new CallGraphEntry(gv)));
  return inserted.first->second.get();
}

void CallGraphNode::Print(std::ostream& os) const {
  // Print the call graph in the topological order.
  std::vector<CallGraphEntry*> nodes = TopologicalOrder();
  for (const auto* cgn : nodes) {
    cgn->Print(os);
  }
}

// Remove `cg_node` from both the call graph and the IR module, and return its
// GlobalVar. The node may not call any other global; a single self-call is
// allowed.
//
// When `update_call_graph` is true, every edge pointing at `cg_node` is removed
// first, and the node's reference count must then be zero. When it is false,
// NOTE(review): entries that still call `cg_node` keep raw pointers to the
// freed entry. The caller must guarantee that there are no such callers.
GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node,
                                                   bool update_call_graph) {
  CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1))
      << "Cannot remove global var " << cg_node->GetNameHint()
      << " from call graph, because it still calls " << cg_node->size()
      << " other global functions";

  if (update_call_graph) {
    // Update the call graph by removing all edges that point to the node
    // `cg_node`.
    for (auto& it : *this) {
      it.second->RemoveAllCallTo(cg_node);
    }
    // The reference count aggregates call sites from every entry. Only after
    // the full sweep above can it be expected to reach zero.
    CHECK_EQ(cg_node->GetRefCount(), 0U)
        << "All references to " << cg_node->GetNameHint() << " should have been removed";
  }
  GlobalVar gv = cg_node->GetGlobalVar();
  call_graph_.erase(gv);
  // Update the IR module.
  module->Remove(gv);
  return gv;
}

std::vector<CallGraphEntry*> CallGraphNode::GetEntryGlobals() const {
  std::vector<CallGraphEntry*> ret;
  // An entry function in Relay is a function that never called by other
  // functions or only called by itself.
  for (const auto& it : *this) {
    if (it.second->GetRefCount() == 0 || it.second->IsRecursiveEntry()) {
      ret.push_back(it.second.get());
    }
  }
  return ret;
}

// Collect every entry reachable from the entry globals. The result is checked
// against the number of functions in the module, and a mismatch is fatal.
std::vector<CallGraphEntry*> CallGraphNode::TopologicalOrder() const {
  std::vector<CallGraphEntry*> ret;
  // Collect all entry nodes.
  std::vector<CallGraphEntry*> entries = GetEntryGlobals();
  CallGraphEntry::CallGraphEntrySet visited;
  for (const auto& it : entries) {
    // Keep tracking the nodes that have been visited.
    auto topo = it->TopologicalOrder(&visited);
    // Prepend the collected items. The intermediate nodes that are shared by
    // multiple entries are guaranteed to be collected when visiting the
    // previous entries. Therefore, topological order remains.
    ret.insert(ret.begin(), topo.begin(), topo.end());
  }

  // Find out the missing global functions if there are any to help debugging.
  if (ret.size() != module->functions.size()) {
    for (const auto& it : module->functions) {
      // Use find() rather than operator[] so that module functions without a
      // call graph entry (e.g. unreferenced non-Relay functions) are reported
      // instead of aborting the diagnostic itself.
      auto cit = call_graph_.find(it.first);
      if (cit == call_graph_.end()) {
        LOG(WARNING) << "Missing global:" << it.first->name_hint << " has no call graph entry";
      } else if (visited.find(cit->second.get()) == visited.end()) {
        LOG(WARNING) << "Missing global:" << it.first->name_hint
                     << " with # refs = " << cit->second->GetRefCount();
      }
    }
    LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received "
               << ret.size();
  }

  return ret;
}

// BFS traversal is used to collect the nodes in a CallGraphEntry. The nodes
// that are visited by previous CallGraphEntry entries can be memoized. This
// helps us to make sure no entry will be visited multiple times when collecting
// the nodes for an entire call graph.
//
// NOTE(review): BFS level order is not a strict topological order. Consider A
// calling B and C, with C also calling B. B is emitted before its caller C.
// Confirm that consumers only need the weaker "reachable, each node once" order.
std::vector<CallGraphEntry*> CallGraphEntry::TopologicalOrder(
    CallGraphEntrySet* visited) const {
  std::vector<CallGraphEntry*> ret;
  std::vector<CallGraphEntry*> current_nodes;
  if (visited->find(this) == visited->end()) {
    visited->emplace(this);
    current_nodes.emplace_back(const_cast<CallGraphEntry*>(this));
  }

  std::vector<CallGraphEntry*> next_nodes;
  while (!current_nodes.empty()) {
    for (const auto& node : current_nodes) {
      ret.push_back(node);
      // Iterate through the called entries.
      for (auto git = node->begin(); git != node->end(); ++git) {
        if (visited->find(git->second) == visited->end()) {
          next_nodes.push_back(git->second);
          visited->emplace(git->second);
        }
      }
    }
    // Move on to the next level and reuse the old buffer for the level after.
    current_nodes.swap(next_nodes);
    next_nodes.clear();
  }
  return ret;
}

// Drop every outgoing edge of this entry, decrementing each callee's count.
void CallGraphEntry::CleanCallGraphEntries() {
  while (!called_globals_.empty()) {
    // Decrement the reference counter
    called_globals_.back().second->DecRef();
    called_globals_.pop_back();
  }
}

inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) {
  called_globals_.emplace_back(global_, cg_node);
  // Increment the reference to indicate that another call site is found for
  // the callee in `cg_node`.
  cg_node->IncRef();
  // Mark the global function as recursive if it calls itself.
  if (global_ == cg_node->GetGlobalVar()) {
    cg_node->is_recursive_ = true;
  }
}

// Remove an edge from the current global function to the callee.
// NOTE(review): removing a self-edge here does not reset is_recursive_.
void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) {
  for (auto it = begin();; ++it) {
    CHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!";
    if (it->second->GetGlobalVar() == callee) {
      // Only remove one occurrence of the call site.
      it->second->DecRef();
      *it = called_globals_.back();
      called_globals_.pop_back();
      return;
    }
  }
}

// Remove all edges from the current global function to the callee.
//
// Only this entry's edges are removed. Other entries may still call `callee`,
// so its reference count is not required to be zero afterwards. Callers that
// remove every edge to `callee` should verify the count themselves, as
// RemoveGlobalVarFromModule does.
void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) {
  for (uint32_t i = 0, e = size(); i != e;) {
    if (called_globals_[i].second == callee) {
      callee->DecRef();
      // Swap-remove: move the last edge into slot i and re-examine slot i.
      called_globals_[i] = called_globals_.back();
      called_globals_.pop_back();
      --e;
    } else {
      ++i;
    }
  }
}

void CallGraphEntry::Print(std::ostream& os) const {
  if (!global_.defined()) {
    os << "GlobalVar is not defined\n";
    return;
  }

  os << "Call graph node: " << global_->name_hint;
  os << " at: " << this << ",  #refs = " << GetRefCount() << "\n";

  for (const auto& it : *this) {
    os << "  call site: <" << it.first->name_hint << "> calls ";
    os << it.second->GetNameHint() << "\n";
  }
  os << "\n";
}

std::ostream& operator<<(std::ostream& os, const CallGraph& cg) {
  cg->Print(os);
  return os;
}

std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) {
  cgn.Print(os);
  return os;
}

TVM_REGISTER_NODE_TYPE(CallGraphNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<CallGraphNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const CallGraphNode*>(ref.get());
      CHECK(node);
      p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
    });

TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) {
  return CallGraph(module);
});

TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) {
  std::stringstream ss;
  ss << call_graph;
  return ss.str();
});

TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) {
  return call_graph->module;
});

TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      std::stringstream ss;
      ss << *entry_node;
      return ss.str();
    });

TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      return static_cast<int>(entry_node->GetRefCount());
    });

TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      return static_cast<int>(entry_node->size());
    });

TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      return entry_node->IsRecursive();
    });

}  // namespace relay
}  // namespace tvm