/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*
 * \file src/relay/transforms/partition_graph.cc
 *
 * \brief Partition an input function into multiple functions based
 * on the inserted annotation nodes (i.e. compiler_begin and compiler_end).
 * These nodes are used as boundaries to partition the Relay function into
 * multiple regions that can be offloaded to different accelerators/backends.
 *
 * Each of these partitioned functions, a.k.a regions, will be viewed as
 * external functions, and they will use the provided compiler for codegen.
 */

#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"

namespace tvm {
namespace relay {
namespace partitioning {

// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
// Each op is looked up once, by name, when these file-scope statics are
// initialized; the classes below compare `call->op` against them to recognize
// annotation calls.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

/*!
 * \brief The checker that verifies if a Relay program is annotated correctly
 * for partitioning.
 */
class AnnotationChecker : public ExprVisitor {
 public:
  bool Check() {
    if (!found_start_ && !found_end_) {
      LOG(WARNING) << "No compiler annotation found";
    } else if (!found_start_) {
      LOG(ERROR) << "compiler_begin annotation is missing";
      return false;
    } else if (!found_end_) {
      LOG(ERROR) << "compiler_end annotation is missing";
      return false;
    }
    return true;
  }

  void VisitExpr_(const CallNode* call) final {
    auto op_node = call->op.as<OpNode>();
    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
      return;
    } else if (call->op == compiler_begin_op) {
      found_start_ = true;
    } else if (call->op == compiler_end_op) {
      found_end_ = true;
    }
  }

 private:
  bool found_start_{false};
  bool found_end_{false};
};

/*! \brief This class partitions the expr labeled with begin and end annotations
 * into function containing multiple regions. Each region is labeled with
 * a compiler attribute so that it will be handled by any compilers that are not
 * in the TVM stack.
 *
 * Input : A Relay module that has functions with disjoint annotated regions
 *         using compiler_begin and compiler_end. There could be multiple
 *         outputs.
 *
 * Output : A Relay module with global functions for such disjoint annotated
 * regions with calls inserted at the respective location
 *
 * Dependencies : AnnotatedRegionSet Utility class.
 *
 * Methodology :
 *      1) The AnnotatedRegionSet utility class is able to construct a collection
 *      of nodes that are bound by a given annotation -- here we use
 *      compiler_begin and compiler_end
 *      2) Initially, for each function in the module RegionSets are populated.
 *      3) Then, the Visitor pass traverses the function until a compiler_end
 *         node that belongs to a "region" is encountered.
 *      4) When the first compiler_end of a given annotated region is found,
 *         a function is formed and inserted.
 *         a) if the region has multiple outputs, a Tuple node (capturing
 *            all outputs) is returned.
 *      5) Thereafter, if we encounter another output of the same annotated
 *         region, it is important to note that the function is already formed.
 *         Therefore, it will lookup the function and add a TupleGetItemNode.
 *         a) We will use the location index of "rets" of each "Region" of
 *         AnnotatedRegionSet as TupleGetItemNode index.
 *      6) Therefore, functions will be created for all annotated regions.
 *         The name for each global function is created using "Region" id and
 *         the compiler name.
 */

class Partitioner : public ExprMutator {
 public:
  explicit Partitioner(const IRModule& module) : module_(module) {
    for (auto f : module->functions) {
      GlobalVar f_var = f.first;
      BaseFunc f_func = f.second;

      // Creating regionset per function in the module
      auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
                                                   partitioning::compiler_end_op);
      regions_sets_[region_set] = f_func;
    }
  }

  /*!
   * \brief Rewrite a call node.
   *
   * - Calls that are not annotation ops are handled by the default mutator.
   * - A compiler_begin call is replaced by a fresh variable that becomes a
   *   parameter of the region function; the (visited) annotated input is
   *   recorded in region_args as the matching argument.
   * - A compiler_end call is replaced by the output of the region function,
   *   which is created the first time an output of that region is reached.
   *
   * \param call The call node to rewrite.
   * \return The rewritten expression.
   */
  Expr VisitExpr_(const CallNode* call) final {
    auto op_node = call->op.as<OpNode>();
    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
      return ExprMutator::VisitExpr_(call);
    } else if (call->op == compiler_begin_op) {
      // The annotation node is inserted on edge so it must have only one
      // argument.
      CHECK_EQ(call->args.size(), 1U);

      // Traverse the rest graph.
      Expr parent = call->args[0];
      auto input_expr = VisitExpr(parent);

      // Backtrace the parent to find the first ancestor node that is not a begin or end op
      while (const auto* parent_call = parent.as<CallNode>()) {
        if (parent_call->op == compiler_begin_op ||
            parent_call->op == compiler_end_op) {
          parent = parent_call->args[0];
        } else {
          break;
        }
      }

      AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
      // GetArgIdx dereferences the region, so fail with a readable message
      // instead of dereferencing a null region when the annotation does not
      // belong to any region.
      CHECK(sg.defined()) << "Region not defined for " << GetRef<Call>(call);
      int index = GetArgIdx(sg, GetRef<Call>(call));
      CHECK_NE(index, -1);

      // If the same producer already feeds this region through another
      // compiler_begin, reuse the variable created for it.
      if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
        return shared_output_[parent][sg];
      } else {
        // The type of the created variable is the same as the compiler_begin
        // node.
        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
        std::string varname =
            target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
        auto var = Var(varname, GetRef<Call>(call)->checked_type_);

        std::pair<Var, Expr> cand = std::make_pair(var, input_expr);

        // NOTE(review): `var` was just created, so if Var equality is by
        // identity this lookup can never match and the pair is always
        // appended -- confirm before simplifying.
        if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
            region_args[sg].end()) {
          region_args[sg].push_back(cand);
        }
        shared_output_[parent][sg] = var;
        return std::move(var);
      }
    } else {
      CHECK_EQ(call->op, compiler_end_op);
      // The annotation node is inserted on edge so it must have only one
      // argument.
      CHECK_EQ(call->args.size(), 1U);

      AnnotatedRegion region = GetRegion(GetRef<Call>(call));
      CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);

      // TODO(@manupa-arm) : need to use the parent function (to which region
      // belongs to, see GetFunc) name/key for the functions that are created

      // Traverse subgraph inputs. The rewritten expression itself is not used
      // here; visiting the compiler_begin nodes below records the region's
      // arguments in region_args.
      VisitExpr(call->args[0]);

      // functions are created for each annotated regions,
      // when their first output is encountered.
      // If multiple outputs are there, a tuple node is inserted at the end.
      // region_function_calls is map that maintains
      // (each annotated regions) --> created function

      if (region_function_calls.find(region) == region_function_calls.end()) {
        // First time this region is encountered in the traversal.
        // Creating the function.
        CreateFunction(region, call);
      }
      // Retrieve this particular output of function.
      return GetFunctionOutput(region, GetRef<Call>(call));
    }
  }

  /*!
   * \brief Rewrite a tuple. A tuple inside a region is always rebuilt from its
   * visited fields; otherwise the default mutator is used.
   */
  Expr VisitExpr_(const TupleNode* op) final {
    if (GetRegion(GetRef<Tuple>(op)).defined()) {
      Array<Expr> new_fields;
      for (auto old_field : op->fields) {
        Expr new_field = VisitExpr(old_field);
        new_fields.push_back(new_field);
      }
      return Tuple(new_fields);
    }
    return ExprMutator::VisitExpr_(op);
  }

  /*!
   * \brief Rewrite a tuple projection. Inside a region the node is always
   * rebuilt around the visited tuple; otherwise the default mutator is used.
   */
  Expr VisitExpr_(const TupleGetItemNode* g) final {
    if (GetRegion(GetRef<TupleGetItem>(g)).defined()) {
      Expr new_tuple = VisitExpr(g->tuple);
      return TupleGetItem(new_tuple, g->index);
    }
    return ExprMutator::VisitExpr_(g);
  }

  /*!
   * \brief Rewrite a function. Inside a region the function is always rebuilt
   * from its visited parameters and body, keeping the original return type,
   * type parameters and attributes; otherwise the default mutator is used.
   */
  Expr VisitExpr_(const FunctionNode* op) final {
    if (!GetRegion(GetRef<Function>(op)).defined()) {
      return ExprMutator::VisitExpr_(op);
    }
    // Parameters are visited before the body, in declaration order.
    Array<Var> new_params;
    for (auto old_param : op->params) {
      new_params.push_back(Downcast<Var>(VisitExpr(old_param)));
    }
    Expr new_body = VisitExpr(op->body);
    return Function(new_params, new_body, op->ret_type, op->type_params, op->attrs);
  }

  /*!
   * \brief Rewrite a let binding. Inside a region the binding is always
   * rebuilt from its visited variable, value and body (visited in that order);
   * otherwise the default mutator is used.
   */
  Expr VisitExpr_(const LetNode* op) final {
    if (!GetRegion(GetRef<Let>(op)).defined()) {
      return ExprMutator::VisitExpr_(op);
    }
    Var bound_var = Downcast<Var>(VisitExpr(op->var));
    Expr bound_value = VisitExpr(op->value);
    Expr let_body = VisitExpr(op->body);
    return Let(bound_var, bound_value, let_body);
  }

  /*!
   * \brief Rewrite a conditional. Inside a region the node is always rebuilt
   * from its visited condition, true branch and false branch (visited in that
   * order); otherwise the default mutator is used.
   */
  Expr VisitExpr_(const IfNode* op) final {
    if (!GetRegion(GetRef<If>(op)).defined()) {
      return ExprMutator::VisitExpr_(op);
    }
    Expr new_cond = VisitExpr(op->cond);
    Expr new_then = VisitExpr(op->true_branch);
    Expr new_else = VisitExpr(op->false_branch);
    return If(new_cond, new_then, new_else);
  }

  /*!
   * \brief Rewrite a reference creation. Inside a region the node is always
   * rebuilt around the visited value; otherwise the default mutator is used.
   */
  Expr VisitExpr_(const RefCreateNode* op) final {
    if (GetRegion(GetRef<RefCreate>(op)).defined()) {
      Expr new_value = VisitExpr(op->value);
      return RefCreate(new_value);
    }
    return ExprMutator::VisitExpr_(op);
  }

  /*!
   * \brief Rewrite a reference read. Inside a region the node is always
   * rebuilt around the visited reference; otherwise the default mutator is
   * used.
   */
  Expr VisitExpr_(const RefReadNode* op) final {
    if (GetRegion(GetRef<RefRead>(op)).defined()) {
      Expr new_ref = VisitExpr(op->ref);
      return RefRead(new_ref);
    }
    return ExprMutator::VisitExpr_(op);
  }

  /*!
   * \brief Rewrite a reference write. Inside a region the node is always
   * rebuilt from the visited reference and value (visited in that order);
   * otherwise the default mutator is used.
   */
  Expr VisitExpr_(const RefWriteNode* op) final {
    if (!GetRegion(GetRef<RefWrite>(op)).defined()) {
      return ExprMutator::VisitExpr_(op);
    }
    Expr new_ref = VisitExpr(op->ref);
    Expr new_value = VisitExpr(op->value);
    return RefWrite(new_ref, new_value);
  }

  /*!
   * \brief Run the partitioning over the functions of the module.
   *
   * The body of each relay Function is rewritten with this mutator and the
   * function is updated in the module under the same global variable. Entries
   * that are not relay Functions are left untouched.
   *
   * \return The module held by this partitioner after the rewrite.
   */
  IRModule Partition() {
    // Iterate over the function map obtained before the loop; the loop body
    // updates module_ (and visiting may add region functions to it).
    auto glob_funcs = module_->functions;
    for (const auto& pair : glob_funcs) {
      const auto* fn = pair.second.as<FunctionNode>();
      if (fn == nullptr) {
        continue;
      }
      Function old_func = GetRef<Function>(fn);
      Expr new_body = VisitExpr(old_func->body);
      Function new_func = Function(old_func->params, new_body, old_func->ret_type,
                                   old_func->type_params, old_func->attrs);
      module_->Update(pair.first, new_func);
    }
    return module_;
  }

 private:
  /*!
   * \brief Look up the annotated region that contains an expression.
   *
   * Every region set known to the partitioner is queried in turn and the first
   * defined region is returned.
   *
   * \param e The expression to look up.
   * \return The region containing \p e, or an undefined region if no region
   * set contains it.
   */
  AnnotatedRegion GetRegion(const Expr& e) {
    for (const auto& entry : regions_sets_) {
      AnnotatedRegionSet candidate_set = entry.first;
      AnnotatedRegion found = candidate_set->GetRegion(e);
      if (found.defined()) {
        return found;
      }
    }
    return AnnotatedRegion(nullptr);
  }

  /*!
   * \brief Look up the function whose region set contains an expression.
   *
   * \param e The expression to look up.
   * \return The function associated with the first region set that has a
   * region containing \p e, or an undefined BaseFunc if there is none.
   */
  BaseFunc GetFunc(const Expr& e) {
    for (const auto& entry : regions_sets_) {
      AnnotatedRegionSet candidate_set = entry.first;
      if (candidate_set->GetRegion(e).defined()) {
        return entry.second;
      }
    }
    return BaseFunc(nullptr);
  }

  /*!
   * \brief Find the position of an expression among the inputs of a region.
   *
   * \param sg The region whose inputs are searched.
   * \param arg The expression to look for.
   * \return The zero-based position of \p arg in the region's inputs, or -1
   * if it is not one of them.
   */
  int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
    int position = -1;
    int cursor = 0;
    for (auto region_input : sg->GetInputs()) {
      if (arg == region_input) {
        position = cursor;
        break;
      }
      ++cursor;
    }
    return position;
  }

  /*!
   * \brief Create the global function for a region and the call that replaces
   * the region. Called the first time a compiler_end node of the region is
   * encountered.
   *
   * Side effects: fills region_return_indices_ for the region, adds the new
   * function to module_ and stores the call in region_function_calls.
   *
   * \param region The region to lift into a function.
   * \param call The compiler_end call that triggered the creation; its
   * attributes provide the compiler name and its argument provides the return
   * type in the single-output case.
   */
  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
    // Collect the unique outputs of the region. region_return_indices_ maps
    // the parent of each compiler_end node to its position in `fields`.
    // Visiting the outputs must happen before region_args is read below,
    // because the visit is what records the region's arguments.
    Array<Expr> fields;
    int next_index = 0;
    for (auto ret : region->GetOutputs()) {
      auto ret_node = Downcast<Call>(ret)->args[0];
      const bool already_recorded =
          region_return_indices_.count(region) && region_return_indices_[region].count(ret_node);
      if (already_recorded) {
        // Don't duplicate outputs.
        continue;
      }
      fields.push_back(VisitExpr(ret_node));
      region_return_indices_[region][ret_node] = next_index;
      ++next_index;
    }

    // Every recorded argument becomes a parameter. Constant arguments are
    // bound into the function by name; all others are passed at the call site.
    Array<Var> params;
    Array<Expr> param_expr;
    std::unordered_map<std::string, runtime::NDArray> params_bind;
    for (const auto& arg : region_args[region]) {
      params.push_back(arg.first);
      const auto* constant = arg.second.as<ConstantNode>();
      if (constant != nullptr) {
        params_bind[arg.first->name_hint()] = constant->data;
      } else {
        param_expr.push_back(arg.second);
      }
    }

    // A single output is returned directly; several outputs are wrapped in a
    // tuple.
    Expr func_body;
    Type func_ret_type;
    if (fields.size() == 1) {
      func_body = fields[0];
      func_ret_type = call->args[0]->checked_type_;
    } else {
      auto tuple = Tuple(fields);
      func_ret_type = tuple->checked_type_;
      func_body = tuple;
    }
    Function global_region_func = Function(params, func_body, func_ret_type, {}, DictAttrs());

    // The function is named "<compiler>_<region id>".
    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
    std::string fname = target + "_" + std::to_string(region->GetID());

    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
                                  runtime::String(fname));
    global_region_func =
        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
                                  tvm::runtime::String(target));
    global_region_func =
        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

    // Constant propagation
    if (!params_bind.empty()) {
      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
    }

    CHECK(!module_->ContainGlobalVar(fname))
        << "Global function " << fname << " already exists";
    // Create a global function and add it to the IRModule for the region.
    // This way we lift the functions that should be handled by external
    // codegen to the module scope and rely on the pass manager to prevent
    // relay function level passes (i.e. simplify inference and fusion)
    // optimizing it.
    GlobalVar glob_func(fname);
    module_->Add(glob_func, global_region_func);

    // The return type of callnode is the same as the type of the
    // compiler_end node.
    region_function_calls[region] = Call(glob_func, param_expr);
  }

  /*!
   * \brief Get the output of the region function that corresponds to a
   * compiler_end node.
   *
   * \param region The region whose function call is used.
   * \param end_arg The compiler_end call; its single argument identifies the
   * output.
   * \return The call itself when the function has one output, otherwise a
   * TupleGetItem selecting the output (created once and reused afterwards).
   */
  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
    Expr arg = Downcast<Call>(end_arg)->args[0];

    // Single output: the call is the output.
    if (region_return_indices_[region].size() == 1) {
      return region_function_calls[region];
    }

    // Multiple outputs: reuse the projection if it was built before.
    auto cached_region = region_return_tuplegetitem_.find(region);
    if (cached_region != region_return_tuplegetitem_.end()) {
      auto cached_item = cached_region->second.find(arg);
      if (cached_item != cached_region->second.end()) {
        return cached_item->second;
      }
    }

    // Otherwise build a new projection at the index recorded for this output.
    CHECK(region_return_indices_.count(region) &&
          region_return_indices_[region].count(arg));
    int index = region_return_indices_[region][arg];

    auto tuple_get_item_ = TupleGetItem(region_function_calls[region], index);
    // The projection has the same type as the output it stands for.
    tuple_get_item_->checked_type_ = arg->checked_type_;
    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
    return std::move(tuple_get_item_);
  }

  /*!
   * \brief This map maintains the already created function calls, keyed by
   * region. This is required in the multi-output scenario, to link the rest of
   * the outputs to the same call.
   */
  std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;

  /*!
   * \brief This map maintains the arguments of each region collected while
   * visiting its compiler_begin nodes. Each entry pairs the variable created
   * for a compiler_begin with the visited input expression; both are used when
   * creating the function and its call.
   */
  std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
      region_args;

  /*!
   * \brief This map maintains the index of an output in the subgraph function
   * for a given region. If there are multiple entries for a region, then the
   * function has a tuple of multiple outputs for its return.
   * The inner key is the argument of the compiler_end node of that output.
   */
  using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
  std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
      region_return_indices_;

  /*!
   * \brief This map holds already created TupleGetItem nodes for accessing
   * outputs of a function, keyed like region_return_indices_.
   */
  using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
  std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
      region_return_tuplegetitem_;

  /*!
   * \brief Each region set is associated with a function in the module.
   * This map maintains the mapping between a region set and the function it
   * belongs to.
   */
  std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;

  /*!
   * \brief Cache the output that is shared by different nodes.
   * The outer key is the first non-annotation ancestor of a compiler_begin
   * node; the inner map gives, per consuming region, the variable already
   * created for that producer.
   */
  using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
  std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;

  /*!\brief The IRModule used for partitioning. */
  IRModule module_;
};

/*!
 * \brief Mutator that strips annotation calls whose compiler attribute is
 * "default", replacing each such call by its (rewritten) first argument.
 */
class DefaultRemover : public ExprMutator {
 public:
  explicit DefaultRemover(const IRModule& module) : module_(module) {}

  /*!
   * \brief Rewrite the body of every relay Function in the module and update
   * the module in place. Entries that are not relay Functions are skipped.
   *
   * \return The module held by this remover after the rewrite.
   */
  IRModule Remove() {
    // Iterate over the function map obtained before the loop; the loop body
    // updates module_.
    auto glob_funcs = module_->functions;
    for (const auto& pair : glob_funcs) {
      const auto* fn = pair.second.as<FunctionNode>();
      if (fn == nullptr) {
        continue;
      }
      Function old_func = GetRef<Function>(fn);
      Expr new_body = VisitExpr(old_func->body);
      Function new_func = Function(old_func->params, new_body, old_func->ret_type,
                                   old_func->type_params, old_func->attrs);
      module_->Update(pair.first, new_func);
    }
    return module_;
  }

  /*!
   * \brief Drop a call carrying CompilerAttrs with compiler == "default";
   * every other call is handled by the default mutator.
   */
  Expr VisitExpr_(const CallNode* call) final {
    const auto* compiler_attrs = call->attrs.as<CompilerAttrs>();
    const bool is_default_annotation =
        compiler_attrs != nullptr && compiler_attrs->compiler == "default";
    if (!is_default_annotation) {
      return ExprMutator::VisitExpr_(call);
    }
    // Forward the annotated expression itself, after rewriting it.
    return VisitExpr(call->args[0]);
  }

 private:
  /*! \brief The module being rewritten. */
  IRModule module_;
};

}  // namespace partitioning

namespace transform {

/*!
 * \brief Create the graph partitioning pass.
 *
 * \return A sequential pass that runs the module-level "PartitionGraph" pass
 * (default-annotation removal followed by partitioning) and then InferType.
 */
Pass PartitionGraph() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
      [](IRModule m, PassContext pc) {
        // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
        // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
        // all "default" annotations and should be deleted in the future.
        IRModule without_default = partitioning::DefaultRemover(m).Remove();
        partitioning::Partitioner partitioner(without_default);
        return partitioner.Partition();
      };
  Pass partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {});
  return Sequential({partition_pass, InferType()});
}

// Register the pass constructor in the global function registry under the
// name "relay._transform.PartitionGraph".
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);

}  // namespace transform

}  // namespace relay
}  // namespace tvm