/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/analysis/call_graph.h
 * \brief Defines data structures for the call graph of an IRModule. The design
 * borrows the idea of how LLVM constructs its CallGraph.
 *
 * https://llvm.org/doxygen/CallGraph_8h_source.html
 */

#ifndef TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
#define TVM_RELAY_ANALYSIS_CALL_GRAPH_H_

#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/object.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

class CallGraphEntry;
class CallGraph;

/*!
 * \brief The object backing a CallGraph. It holds the IR module and owns one
 *        CallGraphEntry per GlobalVar recorded in the graph.
 */
class CallGraphNode : public Object {
  /*! \brief Owning map from a global function to its call graph entry. */
  using CallGraphMap =
      std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash,
                         ObjectEqual>;

 public:
  // Iterator aliases over (GlobalVar, entry) pairs. They are public because the
  // public begin()/end() return them, so callers must be able to name them.
  using iterator = CallGraphMap::iterator;
  using const_iterator = CallGraphMap::const_iterator;

  /*! \brief The IR module for creating a CallGraphNode. */
  IRModule module;

  /*! \brief Default constructor. */
  CallGraphNode() {}

  /*!
   * \brief Visit the reflected attributes of this node.
   *
   * \param v The attribute visitor. Only `module` is exposed.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("module", &module);
  }

  /*!
   * \brief Print the call graph.
   *
   * \param os The stream for printing.
   */
  void Print(std::ostream& os) const;

  /*! \return The begin iterator over (GlobalVar, entry) pairs. */
  iterator begin() {
    return call_graph_.begin();
  }
  /*! \return The end iterator over (GlobalVar, entry) pairs. */
  iterator end() {
    return call_graph_.end();
  }
  /*! \return The const begin iterator over (GlobalVar, entry) pairs. */
  const_iterator begin() const {
    return call_graph_.begin();
  }
  /*! \return The const end iterator over (GlobalVar, entry) pairs. */
  const_iterator end() const {
    return call_graph_.end();
  }

  /*!
   * \brief Get an element from the CallGraphNode using a GlobalVar.
   *
   * \param gv The GlobalVar used for indexing.
   *
   * \return The fetched element.
   */
  const CallGraphEntry* operator[](const GlobalVar& gv) const;
  /*!
   * \brief Get an element from the CallGraphNode using a GlobalVar.
   *
   * \param gv The GlobalVar used for indexing.
   *
   * \return The fetched element.
   */
  CallGraphEntry* operator[](const GlobalVar& gv);
  /*!
   * \brief Get an element from the CallGraphNode using the global function name.
   *
   * The name is first resolved to a GlobalVar via `module->GetGlobalVar`.
   *
   * \param gvar_name The global function name used for indexing.
   *
   * \return The fetched element.
   */
  const CallGraphEntry* operator[](const std::string& gvar_name) const {
    return (*this)[module->GetGlobalVar(gvar_name)];
  }
  /*!
   * \brief Get an element from the CallGraphNode using the global function name.
   *
   * The name is first resolved to a GlobalVar via `module->GetGlobalVar`.
   *
   * \param gvar_name The global function name used for indexing.
   *
   * \return The fetched element.
   */
  CallGraphEntry* operator[](const std::string& gvar_name) {
    return (*this)[module->GetGlobalVar(gvar_name)];
  }

  /*!
   * \brief Get the global function corresponding to the variable.
   *
   * \param var The global variable.
   *
   * \return The found global function.
   */
  BaseFunc GetGlobalFunction(const GlobalVar& var) const;

  /*!
   * \brief Get the entries/root nodes of CallGraphNode.
   *
   *  Entry functions are never referenced by other functions.
   *  Note these functions can be recursive as well.
   *
   * \return The list of CallGraphEntry that represent entry nodes.
   */
  std::vector<CallGraphEntry*> GetEntryGlobals() const;

  /*!
   * \brief Remove a GlobalVar in a given CallGraphEntry from the current
   *        IR module.
   *
   * \param cg_node The CallGraphEntry that contains a global function to be
   *        removed.
   * \param update_call_graph Indicate if we will update the CallGraph as well
   *        since updating is costly. We are only able to remove a leaf function
   *        when update_call_graph is disabled because the edges pointing to
   *        functions being removed are not updated.
   *
   * \return The GlobalVar removed from the current module.
   */
  GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node,
                                      bool update_call_graph = false);

  /*!
   * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for
   *        the GlobalVar if it doesn't exist.
   *
   * \param gv The GlobalVar for query.
   *
   * \return The queried entry.
   */
  CallGraphEntry* LookupGlobalVar(const GlobalVar& gv);

  /*!
   * \brief Get the entries from the CallGraphNode in the topological order.
   *
   *  This is useful for various module-level optimizations/analysis. For example,
   *  inlining requires the correct order of the functions being processed, i.e.
   *  callee should be always handled before callers.
   *
   * \return The list of collected entries that are sorted in the topological order.
   */
  std::vector<CallGraphEntry*> TopologicalOrder() const;

  static constexpr const char* _type_key = "relay.CallGraph";
  TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object);

 private:
  /*!
   * \brief Create a CallGraphEntry for a global function and add it to the
   *        CallGraphNode.
   *
   * \param gv The global var.
   * \param func The global function corresponding to `gv`.
   */
  void AddToCallGraph(const GlobalVar& gv, const Function& func);

  /*! \brief A record contains GlobalVar to CallGraphEntry mapping. */
  CallGraphMap call_graph_;

  friend CallGraph;
};

/*!
 * \brief The class that represents the call graph of a Relay IR module. It also
 * provides a variety of utility functions for users to query, view, and update
 * a call graph.
 */
class CallGraph : public ObjectRef {
  using CallGraphMap =
      std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectHash,
                         ObjectEqual>;

 public:
  // Iterator aliases over (GlobalVar, entry) pairs. They are public because the
  // public begin()/end() return them, so callers must be able to name them.
  using iterator = CallGraphMap::iterator;
  using const_iterator = CallGraphMap::const_iterator;

  /*!
   * \brief Construct a CallGraph from an IR module.
   *
   * \param module The IR module.
   */
  explicit CallGraph(IRModule module);

  /*!
   * \brief Construct from an object pointer.
   * \param n The object pointer.
   */
  explicit CallGraph(ObjectPtr<Object> n) : ObjectRef(n) {}

  // All accessors below go through operator->(), which already CHECKs that the
  // underlying node is non-null, so no additional null checks are needed here.

  /*! \return The begin iterator. */
  iterator begin() {
    return operator->()->begin();
  }
  /*! \return The end iterator. */
  iterator end() {
    return operator->()->end();
  }
  /*! \return The const begin iterator. */
  const_iterator begin() const {
    const CallGraphNode* n = operator->();
    return n->begin();
  }
  /*! \return The const end iterator. */
  const_iterator end() const {
    const CallGraphNode* n = operator->();
    return n->end();
  }

  /*!
   * \brief Get an element from the CallGraph using a GlobalVar.
   *
   * \param gv The GlobalVar used for indexing.
   *
   * \return The fetched element.
   */
  const CallGraphEntry* operator[](const GlobalVar& gv) const {
    // Bind as const so the const overload of CallGraphNode::operator[] is used.
    const CallGraphNode* n = operator->();
    return (*n)[gv];
  }
  /*!
   * \brief Get an element from the CallGraph using a GlobalVar.
   *
   * \param gv The GlobalVar used for indexing.
   *
   * \return The fetched element.
   */
  CallGraphEntry* operator[](const GlobalVar& gv) {
    CallGraphNode* n = operator->();
    return (*n)[gv];
  }
  /*!
   * \brief Get an element from the CallGraph using the global function name.
   *
   * \param gvar_name The global function name used for indexing.
   *
   * \return The fetched element.
   */
  const CallGraphEntry* operator[](const std::string& gvar_name) const {
    // Bind as const so the const overload of CallGraphNode::operator[] is used.
    const CallGraphNode* n = operator->();
    return (*n)[gvar_name];
  }
  /*!
   * \brief Get an element from the CallGraph using the global function name.
   *
   * \param gvar_name The global function name used for indexing.
   *
   * \return The fetched element.
   */
  CallGraphEntry* operator[](const std::string& gvar_name) {
    CallGraphNode* n = operator->();
    return (*n)[gvar_name];
  }

  /*!
   * \return A mutable pointer to the underlying CallGraphNode. Fails a CHECK
   *         if this reference is null.
   */
  CallGraphNode* operator->() const {
    auto* ptr = get_mutable();
    CHECK(ptr != nullptr);
    return static_cast<CallGraphNode*>(ptr);
  }

 private:
  /*! \brief Overload the << operator to print a call graph. */
  friend std::ostream& operator<<(std::ostream& os, const CallGraph&);
};

/*!
 * \brief A node in the call graph. It maintains the edges from a caller to
 * all callees.
 */
class CallGraphEntry {
 public:
  using CallGraphEntryPair = std::pair<GlobalVar, CallGraphEntry*>;
  using CallGraphEntryVector = std::vector<CallGraphEntryPair>;
  using CallGraphEntrySet = std::unordered_set<const CallGraphEntry*>;
  // Iterator aliases over the (callee GlobalVar, callee entry) edge list.
  using iterator = CallGraphEntryVector::iterator;
  using const_iterator = CallGraphEntryVector::const_iterator;

  /*!
   * \brief Construct from a GlobalVar.
   *
   * \param gv The GlobalVar to create a CallGraphEntry.
   */
  explicit CallGraphEntry(const GlobalVar& gv) : global_(gv) {}
  /*!
   * \brief Delete copy constructor.
   */
  CallGraphEntry(const CallGraphEntry&) = delete;
  /*! \brief Delete assignment. */
  CallGraphEntry& operator=(const CallGraphEntry&) = delete;

  /*! \return The begin iterator */
  iterator begin() {
    return called_globals_.begin();
  }
  /*! \return The end iterator */
  iterator end() {
    return called_globals_.end();
  }
  /*! \return The const begin iterator */
  const_iterator begin() const {
    return called_globals_.begin();
  }
  /*! \return The const end iterator */
  const_iterator end() const {
    return called_globals_.end();
  }

  /*!
   * \brief Return if the list of called nodes is empty.
   *
   * \return true if the list is empty. Otherwise, false.
   */
  bool empty() const {
    return called_globals_.empty();
  }

  /*!
   * \brief Return the size of the list of nodes that are called by the current
   * node.
   *
   * \return The number of called nodes.
   */
  uint32_t size() const {
    return static_cast<uint32_t>(called_globals_.size());
  }

  /*!
   * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called
   * by the current function.
   *
   * \param i The index. Must be smaller than size(); otherwise a CHECK fails.
   *
   * \return The fetched CallGraphEntry.
   */
  CallGraphEntry* operator[](size_t i) const {
    CHECK_LT(i, called_globals_.size()) << "Invalid Index";
    return called_globals_[i].second;
  }

  /*!
   * \brief Print the call graph that is stemmed from the current CallGraphEntry.
   *
   * \param os The stream for printing.
   */
  void Print(std::ostream& os) const;

  /*!
   * \brief Return the number of times the global function is referenced.
   *
   * \return The count.
   */
  uint32_t GetRefCount() const {
    return ref_cnt_;
  }

  /*!
   * \brief Return the GlobalVar stored in the current CallGraphEntry.
   *
   * \return The GlobalVar.
   */
  GlobalVar GetGlobalVar() const {
    return global_;
  }

  /*!
   * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry.
   *
   * \return The name hint of the global function.
   */
  std::string GetNameHint() const {
    return global_->name_hint;
  }

  /*!
   * \brief Return if the global function corresponding to the current
   * CallGraphEntry is a recursive function.
   *
   * \return true if it is recursive. Otherwise, false.
   */
  bool IsRecursive() const {
    return is_recursive_;
  }

  /*!
   * \brief Return if the global function corresponding to the current
   * CallGraphEntry is both a recursive function and an entry function. This type
   * of function only has one reference which is called by itself.
   *
   * \return true if it is both a recursive function and an entry. Otherwise, false.
   */
  bool IsRecursiveEntry() const {
    return GetRefCount() == 1 && IsRecursive();
  }

  /*!
   * \brief Return the topological order of the CallGraphEntry, starting from an
   * empty set of visited entries.
   *
   * \return The list of CallGraphEntry that is represented in topological order.
   */
  std::vector<CallGraphEntry*> TopologicalOrder() const {
    // This overload replaces a former `= new CallGraphEntrySet()` default
    // argument, which heap-allocated a set on every call that nobody freed.
    // A stack-local set gives the same "fresh visited set" semantics.
    CallGraphEntrySet visited;
    return TopologicalOrder(&visited);
  }

  /*!
   * \brief Return the topological order of the CallGraphEntry.
   *
   * \param visited A caller-owned set of CallGraphEntry objects that have been
   *        visited.
   *
   * \return The list of CallGraphEntry that is represented in topological order.
   */
  std::vector<CallGraphEntry*> TopologicalOrder(CallGraphEntrySet* visited) const;

  /*!
   * \brief Remove all edges from the current CallGraphEntry to any global
   * function it calls.
   */
  void CleanCallGraphEntries();

  /*!
   * \brief Add a node to the list of nodes that are being called by the current
   * global function.
   *
   * \param cg_node The CallGraphEntry that will be added to the call list.
   */
  void AddCalledGlobal(CallGraphEntry* cg_node);

  /*!
   * \brief Remove a call edge to the global function from the current
   * function.
   *
   * \param callee The function that is being called.
   */
  void RemoveCallTo(const GlobalVar& callee);

  /*!
   * \brief Remove all the edges that represent calls to the global function
   * stored in a given CallGraphEntry.
   *
   * \param callee The function that is being called.
   */
  void RemoveAllCallTo(CallGraphEntry* callee);

 private:
  /*! \brief Decrement the reference counter by 1. The counter must be positive. */
  void DecRef() {
    CHECK_GT(ref_cnt_, 0);
    --ref_cnt_;
  }
  /*! \brief Increment the reference counter by 1. */
  void IncRef() { ++ref_cnt_; }

  /*!
   * \brief Mark if the global function stored in the CallGraphEntry is
   * recursive function.
   */
  bool is_recursive_{false};
  /*! \brief Count the number of times the global function is referenced. */
  uint32_t ref_cnt_{0};
  /*! \brief The GlobalVar stored in the current CallGraphEntry. */
  GlobalVar global_;
  /*! \brief The list of entries called by the current CallGraphEntry. */
  CallGraphEntryVector called_globals_;

  friend class CallGraph;
  /*! \brief Overload the << operator to print a call graph node. */
  friend std::ostream& operator<<(std::ostream& os, const CallGraphEntry&);
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ANALYSIS_CALL_GRAPH_H_