/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file forward_rewrite.cc
 * \brief Apply rewriting rules in a forward fashion.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

#include "pass_util.h"

namespace tvm {
namespace relay {

/*!
 * \brief Rewrites an expression so that every TempExpr node it meets is
 *  replaced by the result of that node's Realize() method.
 *
 *  Non-temporary nodes go through the default MixedModeMutator traversal.
 *  Realize may be called repeatedly on related expressions: the underlying
 *  mutator memoizes what it has already visited, so repeated calls stay cheap.
 */
class TempRealizer : private MixedModeMutator {
 public:
  Expr Realize(Expr expr) { return Mutate(expr); }

 private:
  Expr DispatchVisitExpr(const Expr& expr) final {
    const auto* temp_node = expr.as<TempExprNode>();
    if (temp_node == nullptr) {
      // Ordinary node: defer to the standard post-order dispatch.
      return MixedModeMutator::DispatchVisitExpr(expr);
    }
    // Temporary node: materialize it into a concrete expression.
    return temp_node->Realize();
  }
};

/*!
 * \brief Mutator that applies forward rewrite rules to call nodes in post-order.
 *
 *  Rules come either from a per-op attribute map or from a single rewrite
 *  function used for every call. A rule may return a TempExpr. Such values
 *  are handed unrealized to the rules of their consumers. The internal
 *  TempRealizer realizes them in three cases: when a consumer has no rule,
 *  when a rule declines to rewrite, and on the final result.
 */
class ForwardRewriter : private MixedModeMutator {
 public:
  /*!
   * \brief Construct a rewriter that looks up a rule per op.
   * \param rewrite_map Op attribute map of rewrite rules. It is not owned and must outlive the rewriter.
   * \param fcontext Optional (may be null) callback producing the context object passed to each rule.
   * \param fmulti_ref_trigger Optional (may be null) callback applied to sub-expressions
   *  that are referenced more than once.
   */
  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
                  std::function<ObjectRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_map_(rewrite_map),
        fcontext_(std::move(fcontext)),
        fmulti_ref_trigger_(std::move(fmulti_ref_trigger)) {}

  /*!
   * \brief Construct a rewriter that applies one rule to every call.
   * \param rewrite_func The rewrite rule. It is not owned and must outlive the rewriter.
   * \param fcontext Optional (may be null) callback producing the context object passed to each rule.
   * \param fmulti_ref_trigger Optional (may be null) callback applied to sub-expressions
   *  that are referenced more than once.
   */
  ForwardRewriter(const FForwardRewrite* rewrite_func,
                  std::function<ObjectRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(std::move(fcontext)),
        fmulti_ref_trigger_(std::move(fmulti_ref_trigger)) {}

  /*!
   * \brief Rewrite an expression.
   *
   *  When a multi-reference trigger is set, reference counts are computed
   *  first. The result has all remaining TempExpr nodes realized.
   * \param expr The expression to rewrite.
   * \return The rewritten expression.
   */
  Expr Rewrite(const Expr& expr) {
    if (fmulti_ref_trigger_ != nullptr) {
      ref_counter_ = GetExprRefCount(expr);
    }
    return realizer_.Realize(this->VisitExpr(expr));
  }

 private:
  // Per-op rewrite rules. Used only when rewrite_func_ is null. Not owned.
  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
  // Single rewrite rule for every call. Takes precedence over rewrite_map_. Not owned.
  const FForwardRewrite* rewrite_func_{nullptr};
  // Optional callback producing the per-call context object handed to each rule.
  std::function<ObjectRef(const Call&)> fcontext_{nullptr};
  // Optional callback applied to sub-expressions with more than one reference.
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Per-node reference counts from GetExprRefCount. Filled only when fmulti_ref_trigger_ is set.
  std::unordered_map<const Object*, size_t> ref_counter_;
  // Realizes TempExpr nodes that must not leak into plain (non-rewritten) nodes.
  TempRealizer realizer_;

  /*!
   * \brief Get the visited form of a child without realizing it.
   * \param expr The original child. Its node is the key into ref_counter_.
   * \param post The child after visiting.
   * \return post, passed through fmulti_ref_trigger_ first when expr is referenced more than once.
   */
  Expr GetTempExpr(const Expr& expr, const Expr& post) {
    if (fmulti_ref_trigger_ == nullptr) {
      return post;
    }
    auto it = ref_counter_.find(expr.get());
    CHECK(it != ref_counter_.end());
    if (it->second > 1) {
      return fmulti_ref_trigger_(post);
    }
    return post;
  }

  // Fold TupleGetItem directly when the visited tuple is a literal Tuple.
  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
    Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    }
    if (tuple.same_as(op->tuple)) {
      return GetRef<Expr>(op);
    }
    return TupleGetItem(tuple, op->index);
  }

  // Rebuild a Tuple from its visited fields. Fields are left unrealized,
  // and the original node is returned when nothing changed.
  Expr Rewrite_(const TupleNode* op, const Expr& post) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    const auto* post_node = post.as<TupleNode>();
    for (size_t i = 0; i < op->fields.size(); ++i) {
      auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(op->fields[i]);
    }

    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    }
    return Tuple(fields);
  }

  // Apply the rewrite rule for this call, falling back to rebuilding the
  // call with realized arguments when there is no rule or the rule declines.
  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
    // Held by value: GetRef returns a fresh reference, not a stored one.
    const Call ref_call = GetRef<Call>(call_node);
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
    const auto* post_node = post.as<CallNode>();
    auto new_op = post_node->op;
    bool unchanged = call_node->op.same_as(new_op);

    Array<Expr> call_args;
    for (size_t i = 0; i < call_node->args.size(); ++i) {
      Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]);
      // No rule for this op: the result is a plain Call, so realize its arguments now.
      if (frewrite == nullptr) {
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(call_node->args[i]);
      call_args.push_back(new_arg);
    }
    if (frewrite != nullptr) {
      // The rule sees the original call and the (possibly TempExpr) new arguments.
      // An undefined result means the rule declined to rewrite.
      Expr res = frewrite(ref_call, call_args,
                          fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr));
      if (res.defined()) return res;
      // Declined: rebuild the call, so any TempExpr argument must be realized.
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return Call(new_op, call_args, call_node->attrs, call_node->type_args);
  }
};

/*!
 * \brief Rewrite an expression using the per-op rules stored under an op attribute.
 * \param expr The expression to rewrite.
 * \param rewrite_map_name Name of the op attribute holding the FForwardRewrite rules.
 * \param fcontext Optional callback producing the context object passed to each rule.
 * \param fmulti_ref_trigger Optional callback applied to sub-expressions referenced more than once.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<ObjectRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The rewriter keeps only a pointer to the map, so the map must stay
  // alive in this scope for the whole rewrite.
  const auto op_rule_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  ForwardRewriter rewriter(&op_rule_map, fcontext, fmulti_ref_trigger);
  return rewriter.Rewrite(expr);
}

/*!
 * \brief Rewrite an expression by applying one rewrite rule to every call.
 * \param expr The expression to rewrite.
 * \param rewrite_func The rewrite rule. It is borrowed for the duration of the call.
 * \param fcontext Optional callback producing the context object passed to the rule.
 * \param fmulti_ref_trigger Optional callback applied to sub-expressions referenced more than once.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<ObjectRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  ForwardRewriter rewriter(&rewrite_func, fcontext, fmulti_ref_trigger);
  return rewriter.Rewrite(expr);
}

}  // namespace relay
}  // namespace tvm