/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file remove_no_op.cc
 * \brief Remove no-op statements from a TIR statement.
 */
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>

namespace tvm {
namespace tir {

/*!
 * \brief Statement mutator that strips statements with no observable effect.
 *
 * Most visitors first let StmtMutator rewrite their children, then use
 * is_no_op to check whether the remaining body still does anything. If it
 * does not, the construct is replaced by Evaluate statements for those of
 * its own expressions that have side effects (per HasSideEffect), or by the
 * canonical no-op Evaluate(0) when none do.
 */
class NoOpRemover : public StmtMutator {
 public:
  // `let var = value; body`: when the body is a no-op, only the side effects
  // of `value` (if any) need to survive.
  Stmt VisitStmt_(const LetStmtNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<LetStmtNode>();
    return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
  }

  // A region tagged "pragma_debug_skip_region" is dropped wholesale without
  // being visited. Any other attribute whose body is a no-op reduces to the
  // side effects of its value.
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == "pragma_debug_skip_region") {
      return MakeEvaluate(0);
    }
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<AttrStmtNode>();
    return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
  }

  // Removes a no-op else branch. A conditional with no effective branch left
  // reduces to the side effects of its condition.
  Stmt VisitStmt_(const IfThenElseNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<IfThenElseNode>();
    if (op->else_case.defined() && !is_no_op(op->else_case)) {
      // The else branch does real work, so the conditional must stay intact.
      return stmt;
    }
    // From here on, the else branch is either absent or a no-op.
    if (is_no_op(op->then_case)) {
      return MakeEvaluate(op->condition);
    }
    if (op->else_case.defined()) {
      // Keep the then branch and strip the no-op else branch.
      return IfThenElseNode::make(op->condition, op->then_case);
    }
    return stmt;
  }

  // A loop with constant zero extent never executes its body, and a loop
  // with a no-op body does nothing per iteration. In both cases only side
  // effects of the loop bounds are kept.
  Stmt VisitStmt_(const ForNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    if (is_zero(op->extent)) {
      // A constant zero extent has no side effects; `min` still might.
      return MakeEvaluate(op->min);
    }
    return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
  }

  // An allocation whose body is a no-op is dropped. Only side effects of
  // the extent expressions are kept.
  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
  }

  // A producer/consumer wrapper around a no-op is replaced by that no-op body.
  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<ProducerConsumerNode>();
    return is_no_op(op->body) ? op->body : stmt;
  }

  // A realize wrapper around a no-op is replaced by that no-op body.
  Stmt VisitStmt_(const RealizeNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<RealizeNode>();
    return is_no_op(op->body) ? op->body : stmt;
  }

  // An Evaluate survives only if its value has side effects.
  Stmt VisitStmt_(const EvaluateNode* op) final {
    if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
    return EvaluateNode::make(0);
  }

  // Flattens nested sequences, removes no-op elements, and unwraps the
  // result. An empty result collapses to Evaluate(0) so that callers never
  // receive an empty SeqStmt; a single element is returned directly.
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    Stmt ret = StmtMutator::VisitSeqStmt_(op, true);
    op = ret.as<SeqStmtNode>();
    CHECK(op != nullptr);

    bool need_compact = false;
    for (size_t i = 0; i < op->size(); ++i) {
      if (is_no_op(op->seq[i])) {
        need_compact = true;
        break;
      }
    }
    if (need_compact) {
      // Shift the surviving statements to the front (stable), then truncate.
      auto n = CopyOnWrite(op);
      size_t top = 0;
      for (size_t i = 0; i < n->seq.size(); ++i) {
        if (!is_no_op(n->seq[i])) {
          n->seq.Set(top++, n->seq[i]);
        }
      }
      n->seq.resize(top);
      ret = Stmt(n);
      op = ret.as<SeqStmtNode>();
    }

    if (op->size() == 0) {
      // Everything was removed: use the same canonical no-op as the other
      // visitors instead of an empty SeqStmt.
      return EvaluateNode::make(0);
    }
    if (op->size() == 1) {
      return op->seq[0];
    }
    return ret;
  }

 private:
  /*!
   * \brief Keep a single expression only if it has side effects.
   * \param value The expression that would otherwise be discarded.
   * \return Evaluate(value) if it has side effects, otherwise Evaluate(0).
   */
  Stmt MakeEvaluate(PrimExpr value) {
    if (HasSideEffect(value)) {
      return EvaluateNode::make(value);
    } else {
      return EvaluateNode::make(0);
    }
  }

  /*!
   * \brief Keep, in order, only those expressions that have side effects.
   * \param values The expressions that would otherwise be discarded.
   * \return Evaluate(0) if none have side effects, the single Evaluate if
   *         exactly one does, otherwise one flat SeqStmt of Evaluates.
   */
  Stmt MakeEvaluate(const Array<PrimExpr>& values) {
    Array<Stmt> effects;
    for (const PrimExpr& e : values) {
      if (HasSideEffect(e)) {
        effects.push_back(EvaluateNode::make(e));
      }
    }
    if (effects.size() == 0) return EvaluateNode::make(0);
    if (effects.size() == 1) return effects[0];
    // Build one flat sequence rather than nesting a SeqStmt per element.
    return SeqStmt(effects);
  }
};

/*!
 * \brief Strip statements with no observable effect from a TIR statement.
 * \param stmt The statement tree to simplify; it is moved into the mutator.
 * \return The simplified statement produced by NoOpRemover.
 */
Stmt RemoveNoOp(Stmt stmt) {
  NoOpRemover remover;
  return remover(std::move(stmt));
}
}  // namespace tir
}  // namespace tvm