/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file alter_op_layout.cc
 * \brief Alter the layouts of operators or replace primitive operators with
 *        other expressions. This pass can be used for computing convolution in
 *        custom layouts or other general weight pre-transformation.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>
#include <tuple>
#include <vector>
#include <functional>
#include <string>
#include <utility>
#include <unordered_map>

#include "transform_layout.h"
#include "pattern_util.h"

namespace tvm {
namespace relay {

namespace alter_op_layout {

/*!
 * \brief Object node behind AlterTransformMemorizer. It adds no fields of its own
 *        to TransformMemorizerNode; it only declares a distinct _type_key.
 */
struct AlterTransformMemorizerNode : TransformMemorizerNode {
  static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode";
};

/*!
 * \brief Container that provides the transformation function for alter layout.
 */
class AlterTransformMemorizer : public TransformMemorizer {
 public:
  AlterTransformMemorizer() {}
  explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

  AlterTransformMemorizerNode* operator->() {
    return static_cast<AlterTransformMemorizerNode*>(get_mutable());
  }

  /*!
   * \brief Defines the call transformation for the AlterOpLayout pass. The new layouts are
   * chosen by the FTVMAlterOpLayout packed func registered for the call's operator
   * (registered differently for different targets).
   * \param ref_call The original call.
   * \param new_args The traversed/recursed args to the call.
   * \return The Call produced by the registered packed func. If no func is registered, or
   * it returns an undefined expression, this is a call to the original operator with
   * new_args and the original attrs. Fails a CHECK if the replacement is not a call node.
   */
  Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
    static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
    Op op = Downcast<Op>(ref_call->op);

    Expr new_e;
    if (falter_layout.count(op)) {
      tvm::Array<tvm::te::Tensor> tinfos;
      for (const auto& expr : ref_call->args) {
        // Tuple-typed arguments (e.g. the input list of concatenate) cannot go through
        // type_as<TensorTypeNode>(). Flatten them so that every tensor field gets its own
        // placeholder. Nested tuples remain unsupported and are rejected explicitly.
        if (const auto* tuple_type = expr->checked_type().as<TupleTypeNode>()) {
          for (const auto& field : tuple_type->fields) {
            const auto* field_type = field.as<TensorTypeNode>();
            CHECK(field_type) << "AlterOpLayout does not support nested tuple arguments";
            tinfos.push_back(tvm::te::placeholder(field_type->shape, field_type->dtype));
          }
        } else {
          const auto* ttype = expr->type_as<TensorTypeNode>();
          tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
        }
      }
      // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes.
      //   Probably we need to disable the AlterOpLayout when compiling dynamic models.
      new_e = falter_layout[op](ref_call->attrs, new_args, tinfos, ref_call->checked_type());
    }
    // No alteration applied: keep the original operator, but with the rewritten arguments.
    if (!new_e.defined()) {
      new_e = Call(ref_call->op, new_args, ref_call->attrs);
    }

    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call) << "Can only replace the original operator with another call node";
    return GetRef<Call>(new_call);
  }

  using ContainerType = AlterTransformMemorizerNode;
};

/*!
 * \brief Rewrite every call in expr through LayoutRewriter<AlterTransformMemorizer>,
 *        handing one shared memorizer to each call as its rewrite context.
 *
 * Limitations:
 * 1. The altered op should have the same number of arguments as the previous one.
 * 2. Do not support nested tuple arguments.
 */
Expr AlterOpLayout(const Expr& expr) {
  AlterTransformMemorizer memorizer(make_object<AlterTransformMemorizerNode>());
  // The same memorizer instance is returned for every call, so all rewrites share its state.
  auto context_of = [&memorizer](const Call&) -> ObjectRef { return memorizer; };
  return ForwardRewrite(expr, LayoutRewriter<AlterTransformMemorizer>, context_of);
}

}  // namespace alter_op_layout

namespace transform {

Pass AlterOpLayout() {
  // Per-function rewrite; the module and pass context are not consulted.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function func, IRModule mod, PassContext ctx) {
        Expr rewritten = relay::alter_op_layout::AlterOpLayout(func);
        return Downcast<Function>(rewritten);
      };
  // Opt level 3, with InferType declared as a required pass because the rewrite reads
  // checked types from the calls it visits.
  return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
                            {tir::StringImmNode::make("InferType")});
}

// Register the pass constructor in the global function registry under this name.
TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout").set_body_typed(AlterOpLayout);

}  // namespace transform

}  // namespace relay
}  // namespace tvm