/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file simple_passes.cc
 * \brief Implementation of simple passes
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>

namespace tvm {
namespace tir {

/*!
 * \brief Expression visitor that records whether an expression contains a
 *        call whose is_pure() is false.
 *
 * Once such a call has been found, no further sub-expressions are visited.
 */
class IRSideEffect : public ExprVisitor {
 public:
  void VisitExpr(const PrimExpr& e) final {
    // The answer cannot change after a side effect is found; skip the rest.
    if (!has_side_effect_) {
      ExprVisitor::VisitExpr(e);
    }
  }

  void VisitExpr_(const CallNode* op) final {
    if (op->is_pure()) {
      // Pure call: keep visiting its sub-expressions.
      ExprVisitor::VisitExpr_(op);
      return;
    }
    has_side_effect_ = true;
  }

  /*! \brief Set to true once an impure call has been visited. */
  bool has_side_effect_{false};
};

bool HasSideEffect(const PrimExpr& e) {
  IRSideEffect v;
  v(e);
  return v.has_side_effect_;
}

/*!
 * \brief Mutator that replaces variables according to a lookup table.
 *
 * Plain variable references are replaced by their mapped expression. The
 * buffer variable of a Load or Store is also remapped; in that position the
 * replacement is downcast to Var.
 */
class IRSubstitue : public StmtExprMutator {
 public:
  /*!
   * \param smap Mapping from variable to replacement expression. Only a
   *        reference is stored, so the map must outlive this mutator.
   */
  explicit IRSubstitue(
      const std::unordered_map<const VarNode*, PrimExpr>& smap)
      : smap_(smap) {}

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = smap_.find(op);
    if (it == smap_.end()) {
      return GetRef<PrimExpr>(op);
    }
    return it->second;
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    // buffer_var is not recursively mutated; let the base mutator rewrite the
    // node first, then remap buffer_var explicitly.
    PrimExpr mutated = StmtExprMutator::VisitExpr_(op);
    const LoadNode* load = mutated.as<LoadNode>();
    auto it = smap_.find(load->buffer_var.get());
    if (it == smap_.end()) {
      return mutated;
    }
    return LoadNode::make(
        load->dtype, Downcast<Var>(it->second), load->index, load->predicate);
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    // buffer_var is not recursively mutated; let the base mutator rewrite the
    // node first, then remap buffer_var explicitly.
    Stmt mutated = StmtExprMutator::VisitStmt_(op);
    const StoreNode* store = mutated.as<StoreNode>();
    auto it = smap_.find(store->buffer_var.get());
    if (it == smap_.end()) {
      return mutated;
    }
    return StoreNode::make(
        Downcast<Var>(it->second), store->value, store->index, store->predicate);
  }

 private:
  /*! \brief Borrowed substitution table (not owned). */
  const std::unordered_map<const VarNode*, PrimExpr>& smap_;
};

/*!
 * \brief Substitute variables inside a statement.
 * \param stmt The statement to rewrite.
 * \param value_map Mapping from variable to replacement expression.
 * \return The rewritten statement, or stmt unchanged when value_map is empty.
 */
Stmt Substitute(Stmt stmt,
                const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
  if (value_map.empty()) {
    return stmt;
  }
  IRSubstitue substituter(value_map);
  return substituter(std::move(stmt));
}

/*!
 * \brief Substitute variables inside an expression.
 * \param expr The expression to rewrite.
 * \param value_map Mapping from variable to replacement expression.
 * \return The rewritten expression, or expr unchanged when value_map is empty.
 */
PrimExpr Substitute(PrimExpr expr,
                    const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
  if (value_map.empty()) {
    return expr;
  }
  IRSubstitue substituter(value_map);
  return substituter(std::move(expr));
}

/*!
 * \brief Substitute variables inside a statement, using a TVM Map.
 * \param stmt The statement to rewrite.
 * \param value_map Mapping from variable to replacement expression.
 * \return The rewritten statement, or stmt unchanged when value_map is empty.
 */
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map) {
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  for (const auto& kv : value_map) {
    vmap[kv.first.get()] = kv.second;
  }
  // Move rather than copy, matching the std::move in the hash-map overload.
  // An extra live reference here would keep the statement shared for the
  // whole rewrite, which (per TVM's StmtMutator docs) prevents its
  // copy-on-write path and forces nodes to be cloned.
  return Substitute(std::move(stmt), vmap);
}

/*!
 * \brief Substitute variables inside an expression, using a TVM Map.
 * \param expr The expression to rewrite.
 * \param value_map Mapping from variable to replacement expression.
 * \return The rewritten expression, or expr unchanged when value_map is empty.
 */
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map) {
  // Re-key the Map by raw VarNode pointer for the hash-map overload.
  std::unordered_map<const VarNode*, PrimExpr> var_to_value;
  for (const auto& entry : value_map) {
    const VarNode* key = entry.first.get();
    var_to_value[key] = entry.second;
  }
  return Substitute(expr, var_to_value);
}

/*!
 * \brief Abstract visitor that reports each variable reference to Handle().
 *
 * Both plain variable references and the buffer variable of a Load are
 * reported. A subclass sets use_var_ to true to stop further traversal.
 */
class VarTouchVisitor : public ExprVisitor {
 public:
  void VisitExpr(const PrimExpr& e) final {
    // Stop descending once a match has been recorded.
    if (!use_var_) {
      ExprVisitor::VisitExpr(e);
    }
  }

  void VisitExpr_(const VarNode* op) final { Handle(op); }

  void VisitExpr_(const LoadNode* op) final {
    // Report the buffer variable first, then visit the rest of the load.
    Handle(op->buffer_var.get());
    ExprVisitor::VisitExpr_(op);
  }

  /*! \brief Called for every variable encountered during traversal. */
  virtual void Handle(const VarNode* var) = 0;

  /*! \brief Whether a matching variable has been found. */
  bool use_var_{false};
};

/*! \brief Detects whether a single specific variable is referenced. */
class ExprUseVarVisitor : public VarTouchVisitor {
 public:
  /*! \param var The variable to look for (compared by node identity). */
  explicit ExprUseVarVisitor(const VarNode* var) : var_(var) {}

  void Handle(const VarNode* var) final {
    if (var != var_) return;
    use_var_ = true;
  }

 private:
  /*! \brief Target variable. */
  const VarNode* var_;
};

class ExprUseVSetVisitor : public VarTouchVisitor {
 public:
  explicit ExprUseVSetVisitor(
      const std::unordered_set<const VarNode*>& vset)
      : vset_(vset) {}

  void Handle(const VarNode* var) final {
    if (vset_.count(var)) use_var_ = true;
  }
 private:
  const std::unordered_set<const VarNode*>& vset_;
};

bool ExprUseVar(const PrimExpr& e, const Var& v) {
  ExprUseVarVisitor visitor(v.get());
  visitor(e);
  return visitor.use_var_;
}

bool ExprUseVar(const PrimExpr& e,
                const std::unordered_set<const VarNode*>& vset) {
  ExprUseVSetVisitor visitor(vset);
  visitor(e);
  return visitor.use_var_;
}

}  // namespace tir
}  // namespace tvm
