/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file stmt_simplify.cc
 * \brief Statement simplifier based on analyzer
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>

namespace tvm {
namespace arith {

using namespace ir;

/*!
 * \brief Statement mutator that simplifies expressions with an arith::Analyzer.
 *
 *  Every expression reached by the mutator is passed through
 *  Analyzer::Simplify. While walking the statement, the analyzer is fed with
 *  the facts the visited nodes provide: loop ranges, side-effect free let
 *  bindings, branch/assert conditions and thread extents.
 */
class StmtSimplifier : public IRMutator {
 public:
  using IRMutator::Mutate;

  /*!
   * \brief Simplify a single expression with the analyzer.
   * \param expr The expression to simplify.
   * \return The result of Analyzer::Simplify on expr.
   */
  Expr Mutate(Expr expr) final {
    return analyzer_.Simplify(expr);
  }

  /*!
   * \brief Entry point: bind the given variable ranges, then simplify stmt.
   * \param stmt The statement to simplify.
   * \param vrange Ranges of variables, bound in the analyzer before mutation.
   * \return The mutated statement.
   */
  Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
    for (auto kv : vrange) {
      analyzer_.Bind(kv.first, kv.second);
    }
    return Mutate(stmt);
  }

  /*!
   * \brief Bind the loop variable to the range built from the loop's min and
   *  extent, then mutate the loop with the default mutator.
   */
  Stmt Mutate_(const For* op, const Stmt& s) final {
    analyzer_.Bind(op->loop_var,
                   Range::make_by_min_extent(op->min, op->extent));
    return IRMutator::Mutate_(op, s);
  }

  /*!
   * \brief Simplify a let statement.
   *
   *  When the simplified value has no side effect (ir::HasSideEffect), the
   *  variable is bound to that value in the analyzer and only the mutated
   *  body is returned, i.e. the LetStmt node itself is dropped.
   *  NOTE(review): dropping the node relies on the analyzer replacing the
   *  variable wherever the body uses it -- confirm for uses that are not
   *  visited as expressions.
   */
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    Expr value = this->Mutate(op->value);
    if (!ir::HasSideEffect(value)) {
      analyzer_.Bind(op->var, value);
      return this->Mutate(op->body);
    }
    Stmt body = this->Mutate(op->body);
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return LetStmt::make(op->var, value, body);
    }
  }

  /*!
   * \brief Simplify an if-then-else.
   *
   *  The then branch is mutated inside a ConstraintContext built from the
   *  simplified condition; the else branch (when defined) inside one built
   *  from the simplified negation of the condition. If the condition
   *  simplifies to a constant, the already mutated branch that would execute
   *  is returned instead of the IfThenElse node; Evaluate::make(0) is
   *  returned when the condition is zero and there is no else branch.
   */
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
    Expr condition = this->Mutate(op->condition);
    Stmt then_case, else_case;
    {
      // The constraint is only active for the lifetime of ctx.
      With<ConstraintContext> ctx(&analyzer_, condition);
      then_case = this->Mutate(op->then_case);
    }
    if (op->else_case.defined()) {
      With<ConstraintContext> ctx(&analyzer_, Mutate(Not::make(condition)));
      else_case = this->Mutate(op->else_case);
    }
    if (is_one(condition)) return then_case;
    if (is_zero(condition)) {
      if (else_case.defined()) {
        return else_case;
      }
      return Evaluate::make(0);
    }

    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return s;
    } else {
      return IfThenElse::make(condition, then_case, else_case);
    }
  }

  /*!
   * \brief Record thread extents before mutating an attribute statement.
   *
   *  For attr::thread_extent and attr::virtual_thread, the IterVar's variable
   *  is bound to the range built from min 0 and extent op->value the first
   *  time it is seen; var_dom_ remembers which variables were already bound.
   *  All attribute statements are then mutated by the default mutator.
   */
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::virtual_thread) {
      IterVar iv(op->node.node_);
      CHECK_NE(iv->thread_tag.length(), 0U);
      const Variable* key = iv->var.get();
      if (var_dom_.count(key) == 0) {
        Range dom = Range::make_by_min_extent(0, op->value);
        var_dom_.emplace(key, dom);
        analyzer_.Bind(iv->var, dom);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  /*!
   * \brief Simplify an assert statement.
   *
   *  The body is mutated inside a ConstraintContext built from the
   *  simplified condition.
   */
  Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
    Expr condition = this->Mutate(op->condition);
    Expr message = this->Mutate(op->message);
    With<ConstraintContext> ctx(&analyzer_, condition);
    Stmt body = this->Mutate(op->body);

    if (condition.same_as(op->condition) &&
        message.same_as(op->message) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return AssertStmt::make(condition, message, body);
    }
  }

  /*!
   * \brief Eliminate useless stores.
   *
   *  After the default mutation, a store whose value is a load from the same
   *  buffer variable at an index that Equal() reports equal to the store
   *  index is replaced by Evaluate::make(0).
   */
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    // Inspect the mutated node, not the original one.
    op = stmt.as<Store>();
    if (const Load* load = op->value.as<Load>()) {
      if (load->buffer_var.same_as(op->buffer_var) &&
          Equal(load->index, op->index)) {
        return Evaluate::make(0);
      }
    }
    return stmt;
  }

 protected:
  /*! \brief Analyzer holding the bindings and constraints collected so far. */
  Analyzer analyzer_;
  /*! \brief Domains of the thread variables already bound in analyzer_. */
  std::unordered_map<const Variable*, Range> var_dom_;
};

}  // namespace arith

namespace ir {

/*!
 * \brief Simplify a statement with arith::StmtSimplifier.
 * \param stmt The statement to simplify.
 * \param vrange Variable ranges handed to the simplifier.
 * \return The simplified statement.
 */
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
  arith::StmtSimplifier simplifier;
  return simplifier.Simplify(stmt, vrange);
}

/*!
 * \brief Apply the analyzer's canonical_simplify to an expression.
 * \param expr The expression to simplify.
 * \param vrange Variable ranges bound in the analyzer before simplifying.
 * \return The simplified expression.
 */
Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
  arith::Analyzer bound_analyzer;
  for (const auto& entry : vrange) {
    bound_analyzer.Bind(entry.first, entry.second);
  }
  return bound_analyzer.canonical_simplify(expr);
}

/*!
 * \brief Simplify an expression with Analyzer::Simplify.
 * \param expr The expression to simplify.
 * \param vrange Variable ranges bound in the analyzer before simplifying.
 * \return The simplified expression.
 */
Expr Simplify(Expr expr, Map<Var, Range> vrange) {
  arith::Analyzer bound_analyzer;
  for (const auto& entry : vrange) {
    bound_analyzer.Bind(entry.first, entry.second);
  }
  return bound_analyzer.Simplify(expr);
}

/*!
 * \brief Simplify a statement with arith::StmtSimplifier.
 * \param stmt The statement to simplify.
 * \param vrange Variable ranges handed to the simplifier.
 * \return The simplified statement.
 */
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
  arith::StmtSimplifier simplifier;
  return simplifier.Simplify(stmt, vrange);
}
}  // namespace ir
}  // namespace tvm