/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 * \file src/relay/pass/pass_manager.cc
 * \brief Relay pass manager implementation.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {
namespace pass {

using tvm::IRPrinter;

class ModulePass;

/*!
 * \brief Module-level passes are designed to implement global
 * analysis/optimizations, e.g. interprocedural optimizations (IPO). Passes
 * at this level have full control of a given Relay program, including the
 * addition and deletion of functions.
 */
class ModulePassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*!
   * \brief The pass function that performs the actual optimization. For
   * example, we may need to perform dead code elimination on the module level.
   * We could implement the algorithm in `pass_func` and let it run on a
   * module; it would then remove the dead code, including the unused
   * functions in the module.
   */
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;

  ModulePassNode() = default;

  /*!
   * \brief Visit the reflected attributes of this node. Only `pass_info` is
   * exposed; `pass_func` and the pass context are not visited.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("pass_info", &pass_info);
  }

  /*!
   * \brief Run a module pass on a certain module.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return Return the updated module.
   */
  Module operator()(const Module& mod) const final;

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const final { return pass_info; }

  /*!
   * \brief Set the context information for a module pass.
   *
   * \param pass_ctx The context information for a module pass.
   */
  void SetContext(const PassContext& pass_ctx) final;

  /*!
   * \brief Create a module pass.
   *
   * \param pass_func The function that transforms a whole module.
   * \param pass_info The meta data of the pass.
   *
   * \return The created module pass.
   */
  TVM_DLL static ModulePass make(
      runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
      PassInfo pass_info);

  static constexpr const char* _type_key = "relay.ModulePass";
  TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode);

 private:
  /*!
   * \brief The context information that is used to help perform a module pass.
   */
  PassContext pass_ctx_;
};

RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass);

class FunctionPass;

/*!
 * \brief Function-level passes implement optimizations that are applied to
 * each function of a given Relay module individually. The pass fetches one
 * function at a time from the function list of the module for optimization.
 *
 * Note that the scope of passes at this level is a Relay function. Therefore,
 * we cannot add or delete a function through these passes as they are not aware
 * of the global information.
 */
class FunctionPassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*!
   * \brief The packed pass function that performs the actual optimization.
   * For instance, we can implement a pass that works on a Relay function as a
   * `pass_func` and let it run on a given module. The same `pass_func` will
   * then be applied on each function in the module.
   */
  runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func;

  FunctionPassNode() = default;

  /*!
   * \brief Visit the reflected attributes of this node. Only `pass_info` is
   * exposed; `pass_func` and the pass context are not visited.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("pass_info", &pass_info);
  }

  /*!
   * \brief Run a function pass on a certain module.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return Return the updated module.
   */
  Module operator()(const Module& mod) const final;

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const final { return pass_info; }

  /*!
   * \brief Set the context information for a function-level pass.
   *
   * \param pass_ctx The context information for a function-level pass.
   */
  void SetContext(const PassContext& pass_ctx) final;

  /*!
   * \brief Create a function-level pass.
   *
   * \param pass_func The function that transforms a single Relay function.
   * \param pass_info The meta data of the pass.
   *
   * \return The created function pass.
   */
  TVM_DLL static FunctionPass make(
      runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
      PassInfo pass_info);

  static constexpr const char* _type_key = "relay.FunctionPass";
  TVM_DECLARE_NODE_TYPE_INFO(FunctionPassNode, PassNode);

 private:
  /*!
   * \brief Check if a function should be skipped for optimization.
   *
   * \param func The target function to be checked.
   *
   * \return Return true if the function will be skipped, otherwise false.
   */
  bool SkipFunction(const Function& func) const;

  /*!
   * \brief The context information that is used to help perform a
   * function-level pass.
   */
  PassContext pass_ctx_;
};

RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);

class SequentialPass;

/*!
 * \brief The SequentialPassNode contains a set of passes that transform Relay
 * programs from one AST to another semantically equivalent one.
 *
 * It is the level at which the pass manager is expected to orchestrate a
 * series of optimizations, taking into account a given optimization level
 * and a list of disabled passes.
 */
class SequentialPassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*! \brief The list of passes that compose this sequential pass. */
  tvm::Array<Pass> passes;
  /*!
   * \brief A list of disabled passes that should be excluded when executing the
   * sequential pass. Each entry is expected to be a string pass name.
   */
  tvm::Array<tvm::Expr> disabled;

  /*!
   * \brief Visit the reflected attributes of this node: the pass meta data,
   * the pass list and the disabled pass names.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("pass_info", &pass_info);
    v->Visit("passes", &passes);
    v->Visit("disabled", &disabled);
  }

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const final { return pass_info; }

  /*!
   * \brief Add a pass to the pass list.
   *
   * \param pass The candidate pass to be added.
   */
  void AddPass(const Pass& pass) {
    passes.push_back(pass);
  }

  /*!
   * \brief Create a sequential pass.
   *
   * \param passes The passes to run, in order.
   * \param pass_info The meta data of the sequential pass.
   * \param disabled The names of the passes to exclude.
   *
   * \return The created sequential pass.
   */
  TVM_DLL static SequentialPass make(tvm::Array<Pass> passes,
                                     PassInfo pass_info,
                                     tvm::Array<tvm::Expr> disabled);

  /*!
   * \brief Resolve the pass dependency. It is intended to glob all passes
   *        required by a given pass and execute them. Not implemented yet.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * TODO(zhiics) Build a dependency graph among the passes using provided
   * metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
   * PassInfo, to store the relevant information including the parent passes.
   */
  void ResolveDependency(const Module& mod);

  /*!
   * \brief Get the names of the disabled passes.
   *
   * \return The entries of `disabled` as strings. Every entry must be a
   *         string immediate; otherwise the check fails.
   */
  TVM_DLL std::vector<std::string> DisabledPasses() const;

  /*!
   * \brief Perform optimizations on a series of passes. The orchestration
   *        described on this class is expected to happen here. This function
   *        could be overloaded to focus on different metrics, e.g.
   *        performance, memory footprint, etc.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * \return Return the updated module.
   */
  Module operator()(const Module& mod) const final;

  /*!
   * \brief Set the context information for a sequential pass.
   *
   * \param pass_ctx The context information for a sequential pass.
   */
  void SetContext(const PassContext& pass_ctx) final;

  static constexpr const char* _type_key = "relay.SequentialPass";
  TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode);

 private:
  /*!
   * \brief The context information that is used to help perform a
   * sequential pass.
   */
  PassContext pass_ctx_;
};

RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass);

/*!
 * Build a PassInfo holding the given optimization level, pass name and the
 * list of required passes. The name and the list are moved into the node.
 */
PassInfo PassInfoNode::make(int opt_level, std::string name,
                            tvm::Array<tvm::Expr> required) {
  auto node = make_node<PassInfoNode>();
  node->name = std::move(name);
  node->required = std::move(required);
  node->opt_level = opt_level;
  return PassInfo(node);
}

/*! Create a fresh PassContext backed by a newly allocated node. */
PassContext PassContextNode::make() {
  return PassContext(make_node<PassContextNode>());
}

/*!
 * Wrap a module-level transformation and its meta data into a ModulePass.
 * Both arguments are moved into the new node.
 */
ModulePass ModulePassNode::make(
    runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
    PassInfo pass_info) {
  auto node = make_node<ModulePassNode>();
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  return ModulePass(node);
}

// Module -> Module optimizations.
// TODO(zhiics) 1. Check and handle the required passes.
//              2. Probably use CoW for all places that use module instead of
//              returning the updated one.
/*!
 * Run `pass_func` on the whole module with this pass's context and return the
 * module it produces. Both the input module and the result must be defined.
 */
Module ModulePassNode::operator()(const Module& mod) const {
  const PassInfo pass_info = Info();
  // LOG terminates each record with a newline itself; no trailing "\n".
  LOG(INFO) << "Executing module pass : " << pass_info->name
            << " with opt level: " << pass_info->opt_level;
  CHECK(mod.defined()) << "Module pass " << pass_info->name
                       << " received an undefined module.";
  Module updated_mod = pass_func(mod, pass_ctx_);
  CHECK(updated_mod.defined()) << "Module pass " << pass_info->name
                               << " returned an undefined module.";
  return updated_mod;
}

void ModulePassNode::SetContext(const PassContext& pass_ctx) {
  pass_ctx_ = pass_ctx;
}

/*!
 * Wrap a per-function transformation and its meta data into a FunctionPass.
 * Both arguments are moved into the new node.
 */
FunctionPass FunctionPassNode::make(
    runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
    PassInfo pass_info) {
  auto node = make_node<FunctionPassNode>();
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  return FunctionPass(node);
}

// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
/*!
 * Apply `pass_func` to every function of `mod` that is not marked to be
 * skipped (see SkipFunction), then write the rewritten functions back.
 *
 * NOTE: the functions are replaced in place on the module node that `mod`
 * refers to, and a reference to that same module is returned.
 */
Module FunctionPassNode::operator()(const Module& mod) const {
  const PassInfo pass_info = Info();
  // LOG terminates each record with a newline itself; no trailing "\n".
  LOG(INFO) << "Executing function pass : " << pass_info->name
            << " with opt level: " << pass_info->opt_level;
  CHECK(mod.defined()) << "Function pass " << pass_info->name
                       << " received an undefined module.";
  ModuleNode* mod_node = mod.operator->();

  // Collect the rewritten functions first so that the module's function map
  // is not mutated while it is being iterated.
  std::vector<std::pair<GlobalVar, Function>> updated_funcs;
  for (const auto& it : mod_node->functions) {
    if (SkipFunction(it.second)) {
      continue;
    }
    Function updated_func = pass_func(it.second, pass_ctx_);
    CHECK(updated_func.defined())
        << "Function pass " << pass_info->name
        << " returned an undefined function.";
    // `it` is a const reference, so its GlobalVar is copied (moving from it
    // would silently copy anyway).
    updated_funcs.emplace_back(it.first, std::move(updated_func));
  }

  // Update the optimized functions.
  for (const auto& it : updated_funcs) {
    mod_node->Update(it.first, it.second);
  }

  return GetRef<Module>(mod_node);
}

void FunctionPassNode::SetContext(const PassContext& pass_ctx) {
  pass_ctx_ = pass_ctx;
}

// TODO(zhiics) Create an enum attribute for FunctionNode
// enum Attribute {kPrimitive, kSkipOptimization}
bool FunctionPassNode::SkipFunction(const Function& func) const {
  NodeRef res = FunctionGetAttr(func, "SkipOptimization");
  const ir::IntImm* pval = res.as<ir::IntImm>();
  return pval && pval->value != 0;
}

/*!
 * Bundle an ordered list of passes, their shared meta data and the names of
 * the passes to exclude into a SequentialPass. All arguments are moved.
 */
SequentialPass SequentialPassNode::make(tvm::Array<Pass> passes,
                                        PassInfo pass_info,
                                        tvm::Array<tvm::Expr> disabled) {
  auto node = make_node<SequentialPassNode>();
  node->pass_info = std::move(pass_info);
  node->disabled = std::move(disabled);
  node->passes = std::move(passes);
  return SequentialPass(node);
}

// TODO(jroesch, zhiics): we currently only execute the passes of a
// SequentialPass one after another in list order, without considering their
// dependencies. The phase ordering problem needs to be handled in the future.
/*!
 * Run the passes in `passes` in list order, feeding the module produced by
 * one pass into the next. Passes whose name appears in `disabled` are skipped,
 * as documented on the `disabled` field.
 */
Module SequentialPassNode::operator()(const Module& module) const {
  // Resolve the disabled names once instead of once per pass.
  const std::vector<std::string> disabled_passes = DisabledPasses();
  Module mod = module;
  for (const Pass& pass : passes) {
    CHECK(pass.defined()) << "Found undefined pass for optimization.";
    const auto* pn = pass.operator->();
    const PassInfo info = pn->Info();
    if (std::find(disabled_passes.begin(), disabled_passes.end(),
                  info->name) != disabled_passes.end()) {
      LOG(INFO) << "Skipping disabled pass : " << info->name;
      continue;
    }
    mod = (*pn)(mod);
  }
  return mod;
}

/*!
 * Placeholder for pass dependency resolution. It currently always fails with
 * a fatal log; `mod` is not inspected yet.
 */
void SequentialPassNode::ResolveDependency(const Module& mod) {
  // TODO(zhiics) Implement it. The planned steps are:
  //   1. take the required passes of each pass into account;
  //   2. resolve only the passes that are enabled;
  //   3. build a dependency graph, probably updating the pass list.
  LOG(FATAL) << "Pass dependency has not been resolved yet."
             << "\n";
}

std::vector<std::string> SequentialPassNode::DisabledPasses() const {
  std::vector<std::string> ret;
  for (const auto& it : disabled) {
    const auto* str = it.as<tvm::ir::StringImm>();
    CHECK(str) << "disabled passes must be string.";
    ret.push_back(str->value);
  }
  return ret;
}

void SequentialPassNode::SetContext(const PassContext& pass_ctx) {
  pass_ctx_ = pass_ctx;
}

/*!
 * Convenience factory: build the PassInfo from the given level, name and
 * required passes, then wrap `pass_func` into a module pass.
 */
Pass CreateModulePass(
    const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required) {
  return ModulePassNode::make(pass_func,
                              PassInfoNode::make(opt_level, name, required));
}

/*!
 * Convenience factory: build the PassInfo from the given level, name and
 * required passes, then wrap `pass_func` into a function-level pass.
 */
Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required) {
  return FunctionPassNode::make(pass_func,
                                PassInfoNode::make(opt_level, name, required));
}

/*!
 * Convenience factory: build the PassInfo from the given level, name and
 * required passes, then bundle `passes` and `disabled` into a sequential pass.
 */
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
                          int opt_level,
                          const std::string& name,
                          const tvm::Array<tvm::Expr>& required,
                          const tvm::Array<tvm::Expr>& disabled) {
  return SequentialPassNode::make(
      passes, PassInfoNode::make(opt_level, name, required), disabled);
}

TVM_REGISTER_NODE_TYPE(PassInfoNode);

TVM_REGISTER_API("relay._ir_pass.PassInfo")
.set_body_typed(PassInfoNode::make);

// Return the meta data of a pass. The pass must be defined; dereferencing an
// undefined pass would be a null-pointer access.
TVM_REGISTER_API("relay._ir_pass.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Pass pass = args[0];
  CHECK(pass.defined())
      << "Querying the info of an undefined pass is not allowed.";
  *ret = pass->Info();
});

// Print the meta data of a pass as
//   "The meta data of the pass: pass name: N, opt_level: L, required passes: [a, b]"
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassInfoNode>([](const PassInfoNode* node,
                                tvm::IRPrinter* p) {
  p->stream << "The meta data of the pass: ";
  p->stream << "pass name: " << node->name;
  p->stream << ", opt_level: " << node->opt_level;
  p->stream << ", required passes: [";
  bool first = true;
  for (const auto& it : node->required) {
    const auto* str = it.as<tvm::ir::StringImm>();
    // Guard the dereference below; non-string entries are malformed input.
    CHECK(str) << "required passes must be string.";
    if (!first) {
      p->stream << ", ";
    }
    first = false;
    p->stream << str->value;
  }
  p->stream << "]\n";
});

TVM_REGISTER_NODE_TYPE(ModulePassNode);

TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
.set_body_typed(CreateModulePass);

// Execute any kind of pass on a module through the virtual call operator of
// the concrete pass node.
TVM_REGISTER_API("relay._ir_pass.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  const Pass pass = args[0];
  const Module mod = args[1];
  CHECK(pass.defined())
      << "Running an undefined pass is not allowed."
      << "\n";

  const PassNode* pass_node = pass.operator->();
  *ret = (*pass_node)(mod);
});

// Print a module pass as its name and optimization level.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ModulePassNode* node,
                                 tvm::IRPrinter* p) {
  const PassInfo info = node->Info();
  p->stream << "Run Module pass: " << info->name;
  p->stream << " at the optimization level " << info->opt_level;
});

TVM_REGISTER_NODE_TYPE(FunctionPassNode);

TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
.set_body_typed(CreateFunctionPass);

// Print a function pass as its name and optimization level.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
                                   tvm::IRPrinter* p) {
  const PassInfo info = node->Info();
  p->stream << "Run Function pass: " << info->name;
  p->stream << " at the optimization level " << info->opt_level;
});

TVM_REGISTER_NODE_TYPE(SequentialPassNode);

// Reuse CreateSequentialPass rather than re-implementing it in a lambda,
// matching the CreateModulePass / CreateFunctionPass registrations.
TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass")
.set_body_typed(CreateSequentialPass);

// Print a sequential pass as its name, optimization level and the
// comma-separated names of the passes it contains.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SequentialPassNode>([](const SequentialPassNode* node,
                                     tvm::IRPrinter* p) {
  const PassInfo seq_info = node->Info();
  p->stream << "Run SequentialPass pass: " << seq_info->name
            << " at the optimization level " << seq_info->opt_level << ". ";
  p->stream << "The passes to be executed are: [";
  bool first = true;
  for (const auto& it : node->passes) {
    if (!first) {
      p->stream << ", ";
    }
    first = false;
    // Do not dereference an undefined entry; printing must not crash.
    if (!it.defined()) {
      p->stream << "<undefined>";
      continue;
    }
    p->stream << it->Info()->name;
  }
  p->stream << "]";
});

// Attach a context to a pass. The pass must be defined; calling SetContext
// through an undefined reference would be a null-pointer access.
TVM_REGISTER_API("relay._ir_pass.SetContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Pass pass = args[0];
  PassContext pass_ctx = args[1];
  CHECK(pass.defined())
      << "Setting the context of an undefined pass is not allowed.";
  pass->SetContext(pass_ctx);
});

TVM_REGISTER_NODE_TYPE(PassContextNode);

TVM_REGISTER_API("relay._ir_pass.PassContext")
.set_body_typed(PassContextNode::make);

// TODO(zhiics): print the content of the context.
// Until then, print a placeholder instead of failing: a fatal error here would
// make every attempt to display or log a PassContext raise an error.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* /*node*/,
                                  tvm::IRPrinter* p) {
  p->stream << "relay.PassContext";
});

}  // namespace pass
}  // namespace relay
}  // namespace tvm