/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file de_duplicate.cc
 * \brief Use a fresh Id for every Var to make the result well-formed.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/pattern_functor.h>

namespace tvm {
namespace relay {

Expr DeDup(const Expr& e) {
  class DeDupMutator : public TypeMutator,
                       public ExprMutator,
                       public PatternMutator {
   public:
    TypeVar Fresh(const TypeVar& tv) {
      TypeVar ret = TypeVar(tv->name_hint, tv->kind);
      type_rename_[tv] = ret;
      return ret;
    }

    Var Fresh(const Var& v) {
      CHECK_EQ(rename_.count(v), 0);
      CHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
      Var ret = Var(v->name_hint(), VisitType(v->type_annotation));
      rename_[v] = ret;
      return ret;
    }

    Expr VisitExpr(const Expr& e) final {
      auto ret = ExprMutator::VisitExpr(e);
      ret->checked_type_ = e->checked_type_;
      return ret;
    }

    Expr VisitExpr_(const VarNode* op) final {
      Var v = GetRef<Var>(op);
      return rename_.count(v) != 0 ? rename_.at(v) : v;
    }

    Expr VisitExpr_(const LetNode* op) final {
      Var v = Fresh(op->var);
      return Let(v, VisitExpr(op->value), VisitExpr(op->body));
    }

    Type VisitType(const Type& t) final {
      return t.defined() ? TypeMutator::VisitType(t) : t;
    }

    Expr VisitExpr_(const FunctionNode* op) final {
      tvm::Array<TypeVar> type_params;
      for (const TypeVar& type_param : op->type_params) {
        type_params.push_back(Fresh(type_param));
      }
      tvm::Array<Var> params;
      for (const Var& param : op->params) {
        params.push_back(Fresh(param));
      }
      return Function(params,
                                VisitExpr(op->body),
                                VisitType(op->ret_type),
                                type_params,
                                op->attrs);
    }

    Pattern VisitPattern(const Pattern& p) final {
      return PatternFunctor::VisitPattern(p);
    }

    Pattern VisitPattern_(const PatternVarNode* op) final {
      return PatternVar(Fresh(op->var));
    }

    Type VisitType_(const TypeVarNode* op) final {
      TypeVar v = GetRef<TypeVar>(op);
      return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
    }

    Var VisitVar(const Var& v) final {
      return Fresh(v);
    }

   private:
    std::unordered_map<Var, Var, ObjectHash, ObjectEqual> rename_;
    std::unordered_map<TypeVar, TypeVar, ObjectHash, ObjectEqual> type_rename_;
  };
  CHECK(WellFormed(e)) << AsText(e, false);
  Expr ret = DeDupMutator().VisitExpr(e);
  CHECK(WellFormed(ret));
  CHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
  return ret;
}

TVM_REGISTER_GLOBAL("relay._transform.dedup")
.set_body_typed(DeDup);

}  // namespace relay
}  // namespace tvm