/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

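/*!
 *
 * \file dead_code.cc
 *
 * \brief Remove code that does not affect the program result.
 *
 * The algorithm is implemented by three visitors:
 * FindDef maps every let-bound variable to the value it is bound to,
 * CalcDep counts how many times each variable is used in the part of the expression
 * that is reachable from the result,
 * Eliminator rewrites the expression, dropping unused let bindings and (optionally)
 * inlining bindings that are used exactly once.
 */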
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include "let_list.h"

namespace tvm {
namespace relay {

// Hash containers keyed by Relay variables, hashed and compared with
// ObjectHash / ObjectEqual.
template <typename V>
using VarMap = std::unordered_map<Var, V, ObjectHash, ObjectEqual>;
using VarSet = std::unordered_set<Var, ObjectHash, ObjectEqual>;

// Declared early so that FindDef and Eliminator can name it as a friend.
class CalcDep;
class FindDef : private ExprVisitor {
 private:
  VarMap<Expr> expr_map_;

  void VisitExpr_(const LetNode* l) final {
    CHECK_EQ(expr_map_.count(l->var), 0);
    expr_map_[l->var] = l->value;
    VisitExpr(l->value);
    VisitExpr(l->body);
  }

  friend CalcDep;
};

/*!
 * \brief Rewrites an expression using the use counts computed by CalcDep.
 *
 * For a let-bound variable:
 *  - used 0 times: its binding is removed;
 *  - used exactly once: its binding is removed and its (rewritten) value is
 *    substituted at the use site when inline_once is set, otherwise kept;
 *  - used more often: its binding is kept.
 */
class Eliminator : private ExprMutator {
 private:
  /*! \brief Let-bound variable -> bound value (collected by FindDef). */
  VarMap<Expr> expr_map_;
  /*! \brief Variable -> number of counted uses (computed by CalcDep). */
  VarMap<size_t> use_map_;
  /*! \brief Whether bindings used exactly once are inlined at their use site. */
  bool inline_once_;
  explicit Eliminator(const VarMap<Expr>& expr_map,
                      const VarMap<size_t>& use_map,
                      bool inline_once) :
    expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
  friend CalcDep;

  /*!
   * \brief Decide whether the let binding of v has to be kept.
   * \param v The let-bound variable.
   * \return false if v is unused, !inline_once_ if v is used exactly once,
   *  true otherwise.
   */
  bool HasLet(const Var& v) const {
    // find() instead of operator[]: a query must not insert zero-count
    // entries into use_map_ as a side effect.
    auto it = use_map_.find(v);
    const size_t uses = (it == use_map_.end()) ? 0 : it->second;
    switch (uses) {
    case 0:
      return false;
    case 1:
      return !inline_once_;
    default:
      return true;
    }
  }

  Expr VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    auto it = expr_map_.find(v);
    // Variables that are not let-bound, or whose binding is kept, stay as is;
    // otherwise the binding is dropped and its rewritten value is substituted.
    if (it == expr_map_.end() || HasLet(v)) {
      return v;
    }
    return VisitExpr(it->second);
  }

  Expr VisitExpr_(const LetNode* op) final {
    Var v = op->var;
    if (HasLet(v)) {
      return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
    } else {
      // Binding removed: any remaining use of v is inlined by VisitExpr_(VarNode).
      return VisitExpr(op->body);
    }
  }
};

/*!
 * \brief Counts how often each variable is used in the part of an expression
 *  that is reachable from its result.
 *
 * The value of a let binding is traversed only when its variable is first
 * used, so the values of dead bindings contribute no uses.
 */
class CalcDep : private ExprVisitor {
 public:
  /*!
   * \brief Run dead code elimination on e.
   * \param e The expression to simplify.
   * \param inline_once Whether let-bound variables used exactly once are
   *  inlined at their use site.
   * \return The simplified expression.
   */
  static Expr Eliminate(const Expr& e, bool inline_once) {
    FindDef fd;
    fd(e);
    CalcDep cd(fd.expr_map_);
    cd(e);
    Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
    return el(e);
  }

 private:
  // Borrows the definition map instead of copying it. The map must outlive
  // this object; in Eliminate it belongs to the local FindDef, which does.
  explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
  /*! \brief Let-bound variable -> bound value (owned by FindDef). */
  const VarMap<Expr>& expr_map_;
  /*! \brief Variable -> number of counted uses. */
  VarMap<size_t> use_map_;

  void VisitExpr(const Expr& e) final {
    const size_t visits = ++visit_counter_[e.get()];
    // Dead code elimination sorts let-bound variables into three classes:
    //   used 0 times     -> binding removed
    //   used 1 time      -> binding inlined (when inline_once is set)
    //   used 2+ times    -> binding kept
    // Only that distinction matters, so a node is traversed at most twice,
    // which keeps repeated traversal of shared sub-expressions bounded.
    if (visits <= 2) {
      using TParent = ExprFunctor<void(const Expr&)>;
      TParent::VisitExpr(e);
    }
  }

  // Only the body is traversed here; the bound value is traversed lazily by
  // VisitExpr_(const VarNode*) when the variable is used for the first time.
  void VisitExpr_(const LetNode* l) final {
    VisitExpr(l->body);
  }

  void VisitExpr_(const VarNode* v) final {
    Var var = GetRef<Var>(v);
    const size_t uses = ++use_map_[var];
    if (uses == 1) {
      // First use of a let-bound variable: its value becomes reachable.
      auto it = expr_map_.find(var);
      if (it != expr_map_.end()) {
        VisitExpr(it->second);
      }
    }
  }
};

/*!
 * \brief Remove let bindings whose values do not contribute to the result of e.
 * \param e The expression to simplify.
 * \param inline_once Whether let-bound variables used exactly once are
 *  inlined at their single use site.
 * \return The simplified expression.
 */
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
  Expr simplified = CalcDep::Eliminate(e, inline_once);
  return simplified;
}

namespace transform {

/*!
 * \brief Create a function pass that runs dead code elimination on each function.
 * \param inline_once Forwarded to relay::DeadCodeElimination.
 * \return A function pass named "DeadCodeElimination" at opt level 1 with no
 *  required passes.
 */
Pass DeadCodeElimination(bool inline_once) {
  // The module and pass context are not needed by this transformation.
  auto eliminate = [inline_once](Function func, IRModule /*mod*/, PassContext /*ctx*/) {
    return Downcast<Function>(relay::DeadCodeElimination(func, inline_once));
  };
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = eliminate;
  return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}

// Register the pass constructor as the global packed function
// "relay._transform.DeadCodeElimination". Inside namespace transform the
// unqualified name refers to the Pass-creating DeadCodeElimination(bool).
TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination")
.set_body_typed(DeadCodeElimination);

}  // namespace transform

}  // namespace relay
}  // namespace tvm