/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file dead_code.cc
 *
 * \brief Remove code that does not affect the program result.
 *
 * The algorithm is implemented by two visitors:
 * CalcDep counts how many uses each let-bound variable has,
 * Eliminator rewrites the expr, keeping only the let bindings that are still needed.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "let_list.h"

namespace tvm {
namespace relay {

/*!
 * \brief Use-count analysis for let-bound variables, together with the
 *  rewrite (Eliminator) that removes the bindings nobody needs.
 *
 * Calculate() drives the same visitor in two phases:
 *  - counting (count_ == true): every let binding is recorded in expr_map_
 *    and every reference to an already recorded variable increments its
 *    entry in use_map_. A reference met while the variable's own bound value
 *    is being visited is not counted; it only puts the variable into
 *    letrec_set_.
 *  - propagation (count_ == false): variables whose count is zero are taken
 *    off dead_worklist_ one by one and their bound value is visited again,
 *    this time decrementing the count of every variable it references, which
 *    can push further variables onto the worklist.
 */
class CalcDep : private ExprVisitor {
 public:
  /*!
   * \brief Rewrite \p e without the let bindings whose use count ended at zero.
   * \param e The expression to rewrite.
   * \param inline_once When true, a binding that has exactly one counted use
   *  and is not in the letrec set is substituted at its use site instead of
   *  being kept as a let.
   * \return The rewritten expression.
   */
  static Expr Eliminate(const Expr& e, bool inline_once) {
    CalcDep cd;
    cd.Calculate(e);
    // el only borrows cd's containers; cd is declared first and therefore
    // outlives el for the whole call below.
    Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
    return el(e);
  }

 private:
  template<typename X>
  using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
  using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
  /*! \brief let-bound variable -> the value it is bound to. */
  VarMap<Expr> expr_map_;
  /*! \brief let-bound variable -> number of counted references to it. */
  VarMap<size_t> use_map_;
  /*! \brief variables that are referenced from inside their own bound value. */
  VarSet letrec_set_;
  /*! \brief true during the counting phase, false during propagation. */
  bool count_ = true;
  /*! \brief variables with a zero count whose bound value was not released yet. */
  VarSet dead_worklist_;
  /*! \brief variables whose bound value is being visited right now. */
  VarSet current_letrec_;

  /*!
   * \brief Run \p func while \p v is registered in current_letrec_.
   * \param func The traversal to run.
   * \param v The variable whose bound value \p func visits.
   */
  void LetRec(const std::function<void()>& func, const Var& v) {
    current_letrec_.insert(v);
    func();
    current_letrec_.erase(v);
  }

  /*!
   * \brief Record a let binding (counting phase only) and descend into it.
   *
   * During propagation the bound value is skipped on purpose: it is only
   * visited through the worklist, once the variable itself turned out dead.
   */
  void VisitExpr_(const LetNode* l) final {
    if (count_) {
      // Each variable may be bound by at most one let.
      CHECK_EQ(expr_map_.count(l->var), 0);
      CHECK_EQ(use_map_.count(l->var), 0);
      expr_map_[l->var] = l->value;
      use_map_[l->var] = 0;
      // Dead until a reference proves otherwise.
      dead_worklist_.insert(l->var);
      LetRec([&]() { VisitExpr(l->value); }, l->var);
    }
    VisitExpr(l->body);
  }

  /*!
   * \brief Dispatch directly through ExprFunctor instead of ExprVisitor::VisitExpr.
   *
   * NOTE(review): presumably this skips ExprVisitor's per-node bookkeeping so
   * that a node is traversed again in the propagation phase - confirm against
   * ExprVisitor::VisitExpr.
   */
  void VisitExpr(const Expr& e) final {
    ExprFunctor<void(const Expr&)>::VisitExpr(e);
  }

  /*!
   * \brief Account for one reference to a variable.
   *
   * Variables without a recorded let binding are ignored. A reference made
   * from inside the variable's own bound value only marks it as recursive.
   * Any other reference increments the count while counting and decrements
   * it while propagating.
   */
  void VisitExpr_(const VarNode* v) final {
    Var var = GetRef<Var>(v);
    if (expr_map_.count(var) == 0) {
      return;
    }
    if (current_letrec_.count(var) != 0) {
      letrec_set_.insert(var);
      return;
    }
    // Every key of expr_map_ is also a key of use_map_, so this never inserts.
    size_t& uses = use_map_[var];
    if (count_) {
      uses += 1;
      dead_worklist_.erase(var);
    } else {
      CHECK_GT(uses, 0) << var;
      uses -= 1;
      if (uses == 0) {
        dead_worklist_.insert(var);
      }
    }
  }

  /*!
   * \brief Fill expr_map_, use_map_ and letrec_set_ for \p v.
   * \param v The expression to analyse.
   */
  void Calculate(const Expr& v) {
    VisitExpr(v);
    count_ = false;
    while (!dead_worklist_.empty()) {
      auto head = dead_worklist_.begin();
      Var dead = *head;
      dead_worklist_.erase(head);
      CHECK_EQ(use_map_[dead], 0);
      auto bound = expr_map_.find(dead);
      if (bound != expr_map_.end()) {
        // expr_map_ is not modified while count_ is false, so the iterator
        // stays valid for the duration of the traversal.
        LetRec([&]() { VisitExpr(bound->second); }, dead);
      }
    }
  }

  /*!
   * \brief Rewrites an expression using the counts computed by CalcDep.
   *
   * The containers are held by reference and never modified; the caller has
   * to keep them alive for as long as the Eliminator is used.
   */
  class Eliminator : private ExprMutator {
   private:
    const VarMap<Expr>& expr_map_;
    const VarMap<size_t>& use_map_;
    const VarSet& letrec_set_;
    bool inline_once_;
    explicit Eliminator(const VarMap<Expr>& expr_map,
                        const VarMap<size_t>& use_map,
                        const VarSet& letrec_set,
                        bool inline_once) :
      expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
    friend CalcDep;

    /*!
     * \brief Decide whether the binding of \p v stays in the output as a let.
     * \param v The let-bound variable in question.
     * \return false for zero uses; for exactly one use, true only when the
     *  variable is recursive or single-use inlining is turned off; true for
     *  two or more uses.
     */
    bool HasLet(const Var& v) const {
      auto it = use_map_.find(v);
      // A variable without an entry is treated like one with zero uses.
      const size_t uses = (it == use_map_.end()) ? 0 : it->second;
      if (uses == 0) {
        return false;
      }
      if (uses == 1) {
        return letrec_set_.count(v) > 0 || !inline_once_;
      }
      return true;
    }

    /*!
     * \brief Keep the variable, or replace it by its rewritten bound value
     *  when its let binding is being dropped.
     */
    Expr VisitExpr_(const VarNode* op) final {
      Var v = GetRef<Var>(op);
      auto bound = expr_map_.find(v);
      if (bound == expr_map_.end() || HasLet(v)) {
        return v;
      }
      return VisitExpr(bound->second);
    }

    /*!
     * \brief Rebuild the let when its binding is kept, otherwise return only
     *  the rewritten body.
     */
    Expr VisitExpr_(const LetNode* op) final {
      Var v = op->var;
      if (HasLet(v)) {
        return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
      } else {
        return VisitExpr(op->body);
      }
    }
  };
};

/*!
 * \brief Run dead code elimination on an expression via CalcDep::Eliminate.
 * \param e The expression to clean up.
 * \param inline_once Forwarded unchanged to CalcDep::Eliminate.
 * \return The expression produced by CalcDep::Eliminate for \p e.
 */
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
  Expr cleaned = CalcDep::Eliminate(e, inline_once);
  return cleaned;
}

// Register the expression-level DeadCodeElimination function under the
// global name "relay._ir_pass.dead_code_elimination".
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);

namespace transform {

/*!
 * \brief Build a function pass named "DeadCodeElimination".
 *
 * The pass applies the expression-level relay::DeadCodeElimination to each
 * function it is given and casts the result back to Function. The module and
 * pass context arguments are not used.
 *
 * \param inline_once Captured by the pass and forwarded to
 *  relay::DeadCodeElimination.
 * \return The pass created by CreateFunctionPass.
 */
Pass DeadCodeElimination(bool inline_once) {
  using PassFunc = runtime::TypedPackedFunc<Function(Function, Module, PassContext)>;
  PassFunc eliminate = [inline_once](Function func, Module mod, PassContext ctx) {
    Expr cleaned = relay::DeadCodeElimination(func, inline_once);
    return Downcast<Function>(cleaned);
  };
  // NOTE(review): 1 is presumably the optimization level and {} the list of
  // required passes - confirm against CreateFunctionPass.
  return CreateFunctionPass(eliminate, 1, "DeadCodeElimination", {});
}

// Register the pass factory DeadCodeElimination of this namespace under the
// global name "relay._transform.DeadCodeElimination".
TVM_REGISTER_API("relay._transform.DeadCodeElimination")
.set_body_typed(DeadCodeElimination);

}  // namespace transform

}  // namespace relay
}  // namespace tvm