/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file partial_eval.cc
 *
 * \brief Perform known computation in compile time.
 *
 * The partial evaluator try to do computation at compile time,
 * so it can generate code that do less work.
 * Additionally, it might open more chance for further optimization,
 * since the high level, structural part of the code (closure, reference, control flow)
 * might get partially evaluated away, and the subsequent optimization (for example, kernel fusion)
 * can reason across those structural code as it got removed.
 * In the extreme case, partial evaluation can even turn the whole program
 * into pure first order computation with no control flow.
 * In such a case, we can compile the whole computation onto SIMD Instruction/GPU/FPGA,
 * and get huge speedup.
 *
 * It works by making the following modifications to the standard relay interpreter:
 *
 * 0: The values become partially static value.
 * Since we cannot know the value of every term at compile time,
 * Term might get partially evaluated to 'Unknown Value'.
 * Every partially static value is, hence,
 * a static fragment that might not be there (partially static),
 * and a dynamic fragment that is semantically equivalent to the original term,
 * so the unknown part will be computed at runtime, using the dynamic fragment.
 *
 * 1: The interpreter holds a LetList, which preserves A Normal Form for the generated code.
 * More specifically, we require that all dynamic is an atom.
 * This avoids code duplication (which is both inefficient and incorrect), as atom has constant size
 * and allow us to not handle capture-avoidance substitution (as atom has no binder).
 *
 * 2: The map of References to partially static values is reified, as described below.
 * Instead of Reference having mutable field, Reference only has an unique identifier.
 * There will be a mutable mapping of id to partially static value, called the store.
 * This allow us to rollback the store:
 * when a path may or may not be executed (as in a conditional), we copy the store,
 * recurse with the copy, and reinstate the original when the call returns
 * so that the effects of the computation are not preserved.
 * We do this in if else, pattern matching, and in function,
 * as, when we see a function, we partially evaluate it with all the argument as dynamic,
 * to generate efficient dynamic for that function.
 *
 * 3: The generated code reuses bindings (although they are not shadowed),
 * so we have to deduplicate them.
 *
 * 4: In the generated code, multiple VarNode might have same Id.
 * While it is permitted, most pass use NodeHash for Var,
 * and having multiple VarNode for same Id break them.
 * Thus we remap them to a single Id for now.
 *
 * Also, It will also generate lots of dead code,
 * so it is a good idea to feed it through the dead code eliminator after partial evaluation.
 *
 * The partial evaluator makes several assumptions, so there is room for improvement:
 *
 * 0: The partial evaluator treats global variables as opaque.
 * Doing PartialEval on a module level will solve this.
 *
 * 1: The partial evaluator assume all functions as terminating.
 * We need to has a max_expand parameter that shrink on every compile time evaluation,
 * to make sure PE does not infinite loop.
 * Additionally, we might add a termination analysis pass that lift this requirement
 * for function that analysis found terminating.
 *
 * 2: Every time an unknown effect happened, we clear the whole store.
 * It is too conservative: if a local reference is created (and do not get passed outside),
 * An unknown global function call/global reference write can not modify it.
 * We can pair PE with escape analysis/alias analysis.
 *
 * 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
 *
 * 4: When doing pattern matching, we can simplify the match even for dynamic case.
 * Right now it is all or nothing: either a complete match, or the original dynamic code.
 * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
 * We then can reify the result.
 *
 * 5: Every time a function is called, it's code will get expanded and partially evaluated.
 * We can do a binding time analysis to cache the result and avoid re-partial evaluation.
 *
 * These assumptions do not affect the correctness of the algorithm, however.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include "pass_util.h"
#include "let_list.h"

namespace tvm {
namespace relay {

using namespace runtime;

/*! \brief Hash Var by it's id.
 * Different VarNode might has same vid, and they are considered to be the same var in such case.
 * Use VarHash to hash Var by id.
 */
struct VarHash {
  size_t operator()(const Var& v) const {
    return v->vid.hash();
  }
};

/*! \brief Compare Var by it's id.
 * Different VarNode might has same vid, and they are considered to be the same var in such case.
 * Use VarEqual to compare Var by id.
 */
struct VarEqual {
  bool operator()(const Var& l, const Var& r) const {
    return l->vid.get() == r->vid.get();
  }
};

/*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode {
 public:
  static constexpr const char* _type_key = "relay.Value";
  TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};

class Static : public NodeRef {
 public:
  Static() {}
  explicit Static(NodePtr<Node> n) : NodeRef(n) {}
  const ValueNode* operator->() const {
    return static_cast<const ValueNode*>(node_.get());
  }

  using ContainerType = StaticNode;
};

struct PStaticNode : Node {
  Static pstatic;  // may be null
  Expr dynamic;
  PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
  explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
  TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
};

RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef);

struct STupleNode : StaticNode {
  std::vector<PStatic> fields;
  explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { }
  TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
};

RELAY_DEFINE_NODE_REF(STuple, STupleNode, Value);

Static MkSTuple(const std::vector<PStatic>& fields) {
  return Static(make_node<STupleNode>(fields));
}

struct STensorNode : StaticNode {
  runtime::NDArray data;
  explicit STensorNode(const NDArray& data) : data(data) { }
  TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
};

RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value);

Static MkSTensor(const NDArray& data) {
  return Static(make_node<STensorNode>(data));
}

struct SConstructorNode : StaticNode {
  Constructor constructor;
  std::vector<PStatic> fields;
  SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
    constructor(constructor), fields(fields) { }
  TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode);
};

RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Value);

Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>& fields) {
  return Static(make_node<SConstructorNode>(constructor, fields));
}

struct SRefNode : StaticNode {
  // we will use the address as the guid for hashing
  TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode);
};

RELAY_DEFINE_NODE_REF(SRef, SRefNode, Value);

Static MkSRef() {
  return Static(make_node<SRefNode>());
}

using Func = std::function<PStatic(const std::vector<PStatic>&,
                                      const Attrs&,
                                      const Array<Type>&,
                                      LetList*)>;

struct SFuncNode : StaticNode {
  Func func;
  explicit SFuncNode(const Func& func) : func(func) { }
  TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode);
};

RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Value);

Static MkSFunc(const Func& func) {
  return Static(make_node<SFuncNode>(func));
}

/*!
 * \brief A stack frame in the Relay interpreter.
 *
 * Contains a mapping from relay::Var to relay::Value.
 */
struct Frame {
  /*! \brief The set of local variables and arguments for the frame. */
  std::unordered_map<Var, PStatic, VarHash, VarEqual> locals;
  Frame() = default;
};

class Environment {
 public:
  Environment() : env_({Frame()}) { }
  Environment(const Environment&) = delete;

  template<typename T>
  T Extend(const std::function<T()>& body) {
    FrameContext fc(this);
    return body();
  }

  void Insert(const Var& v, const PStatic& ps) {
    CHECK(ps.defined());
    env_.back().locals[v] = ps;
  }

  PStatic Lookup(const Var& v) {
    auto rit = env_.rbegin();
    while (rit != env_.rend()) {
      if (rit->locals.find(v) != rit->locals.end()) {
        return rit->locals.find(v)->second;
      }
      ++rit;
    }
    LOG(FATAL) << "Unknown Variable: " << v;
    throw;
  }

 private:
  std::list<Frame> env_;

  struct FrameContext {
    Environment* env_;
    explicit FrameContext(Environment* env) : env_(env) {
      env_->env_.push_back(Frame());
    }
    ~FrameContext() {
      env_->env_.pop_back();
    }
  };
};

/*!
 * \brief As our store require rollback, we implement it as a frame.
 * every time we need to copy the store, a new frame is insert.
 * every time we roll back, a frame is popped.
 */
struct StoreFrame {
  std::unordered_map<const SRefNode*, PStatic> store;
  /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */
  bool history_valid = true;
  explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
  StoreFrame() = default;
};

class Store {
 public:
  Store() : store_({StoreFrame()}) { }
  Store(const Store&) = delete;

  template<typename T>
  T Extend(const std::function<T()>& body) {
    StoreFrameContext sfc(this);
    return body();
  }

  void Insert(const SRefNode* r, const PStatic& ps) {
    store_.back().store[r] = ps;
  }

  // return null if not found
  PStatic Lookup(const SRefNode* r) {
    auto rit = store_.rbegin();
    while (rit != store_.rend()) {
      if (!rit->history_valid) {
        return PStatic();
      }
      if (rit->store.find(r) != rit->store.end()) {
        return rit->store.find(r)->second;
      }
      ++rit;
    }
    return PStatic();
  }

  void Invalidate() {
    store_.back().history_valid = false;
  }

 private:
  std::list<StoreFrame> store_;

  struct StoreFrameContext {
    Store* store_;
    explicit StoreFrameContext(Store* store) : store_(store) {
      store_->store_.push_back(StoreFrame());
    }
    ~StoreFrameContext() {
      store_->store_.pop_back();
    }
  };
};

PStatic HasStatic(const Static& stat, const Expr& dynamic) {
  return PStatic(make_node<PStaticNode>(stat, dynamic));
}

PStatic NoStatic(const Expr& dynamic) {
  return PStatic(make_node<PStaticNode>(dynamic));
}

enum struct MatchStatus {
  Match, NoMatch, Unknown
};

bool StatefulOp(const Expr& e) {
  static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
  struct StatefulOpVisitor : ExprVisitor {
    bool stateful = false;
    void VisitExpr_(const OpNode* op) {
      stateful = stateful || op_stateful.get(GetRef<Op>(op), false);
    }
  };
  StatefulOpVisitor sov;
  sov(e);
  return sov.stateful;
}

using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;

DLContext CPUContext() {
  DLContext ctx;
  ctx.device_type = kDLCPU;
  ctx.device_id = 0;
  return ctx;
}

FInterpreter CPUInterpreter() {
  Target target = Target::Create("llvm");
  // use a fresh build context
  // in case we are already in a build context.
  With<BuildConfig> fresh_build_ctx(BuildConfig::Create());

  return CreateInterpreter(Module(nullptr), CPUContext(), target);
}

class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
                         public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
 public:
  PartialEvaluator(const tvm::Array<Var>& free_vars) {
    for (const Var& v : free_vars) {
      env_.Insert(v, NoStatic(v));
    }
  }

  PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
    return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
  }

  PStatic VisitExpr_(const TupleNode* op, LetList* ll) final {
    std::vector<PStatic> value;
    tvm::Array<Expr> expr;
    for (const Expr& e : op->fields) {
      PStatic ps = VisitExpr(e, ll);
      value.push_back(ps);
      expr.push_back(ps->dynamic);
    }
    return HasStatic(MkSTuple(value), ll->Push(TupleNode::make(expr)));
  }

  PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final {
    PStatic ps = VisitExpr(op->tuple, ll);
    if (ps->pstatic.defined()) {
      return Downcast<STuple>(ps->pstatic)->fields[op->index];
    } else {
      return NoStatic(ll->Push(TupleGetItemNode::make(ps->dynamic, op->index)));
    }
  }

  PStatic VisitExpr_(const VarNode* op, LetList* ll) final {
    return env_.Lookup(GetRef<Var>(op));
  }

  PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
    return NoStatic(GetRef<Expr>(op));
  }

  PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
    env_.Insert(op->var, VisitExpr(op->value, ll));
    return VisitExpr(op->body, ll);
  }

  PStatic VisitExpr_(const IfNode* op, LetList* ll) final {
    PStatic c = VisitExpr(op->cond, ll);
    if (c->pstatic.defined()) {
      NDArray cpu_array = Downcast<STensor>(c->pstatic)->data.CopyTo(CPUContext());
      CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool());
      if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
        return VisitExpr(op->true_branch, ll);
      } else {
        return VisitExpr(op->false_branch, ll);
      }
    } else {
      Expr t = store_.Extend<Expr>([&]() {
          return LetList::With([&](LetList* ll) {
              return VisitExpr(op->true_branch, ll)->dynamic;
            });
        });
      Expr f = store_.Extend<Expr>([&]() {
          return LetList::With([&](LetList* ll) {
              return VisitExpr(op->false_branch, ll)->dynamic;
            });
        });
      store_.Invalidate();
      return NoStatic(ll->Push(IfNode::make(c->dynamic, t, f)));
    }
  }

  PStatic VisitExpr_(const RefCreateNode* op, LetList* ll) final {
    PStatic ps = VisitExpr(op->value, ll);
    Static r = MkSRef();
    store_.Insert(r.as<SRefNode>(), ps);
    return HasStatic(r, ll->Push(RefCreateNode::make(ps->dynamic)));
  }

  PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final {
    PStatic r = VisitExpr(op->ref, ll);
    PStatic v = VisitExpr(op->value, ll);
    if (r->pstatic.defined()) {
      store_.Insert(r->pstatic.as<SRefNode>(), v);
    } else {
      store_.Invalidate();
    }
    return HasStatic(MkSTuple({}), ll->Push(RefWriteNode::make(r->dynamic, v->dynamic)));
  }

  PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final {
    PStatic r = VisitExpr(op->ref, ll);
    if (r->pstatic.defined()) {
      PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
      if (ret) {
        return ret;
      }
    }
    return NoStatic(ll->Push(RefReadNode::make(r->dynamic)));
  }

  PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
    PStatic f = VisitExpr(op->op, ll);
    std::vector<PStatic> x;
    tvm::Array<Expr> x_dyn;
    for (const Expr& e : op->args) {
      PStatic ps = VisitExpr(e, ll);
      x.push_back(ps);
      x_dyn.push_back(ps->dynamic);
    }
    if (f->pstatic.defined()) {
      return Downcast<SFunc>(f->pstatic)->func(x, op->attrs, op->type_args, ll);
    } else {
      store_.Invalidate();
      return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args)));
    }
  }

  PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
    Function func = GetRef<Function>(op);
    if (func->IsPrimitive()) {
      return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
    }
    std::vector<std::pair<Var, PStatic> > free_vars;
    for (const auto& v : FreeVars(GetRef<Expr>(op))) {
      free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
    }
    Func f = [=](const std::vector<PStatic>& pv,
                 const Attrs& attrs,
                 const tvm::Array<Type>& type_args,
                 LetList* ll) {
      return env_.Extend<PStatic>([&]() {
          CHECK_EQ(pv.size(), func->params.size());
          for (size_t i = 0; i < pv.size(); ++i) {
            env_.Insert(func->params[i], pv[i]);
          }
          for (const auto& p : free_vars) {
            env_.Insert(p.first, p.second);
          }
          tvm::Map<TypeVar, Type> subst;
          for (size_t i = 0; i < type_args.size(); ++i) {
            subst.Set(func->type_params[i], type_args[i]);
          }
          for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
            subst.Set(func->type_params[i], Type());
          }
          return VisitExpr(TypeSubst(func->body, subst), ll);
        });
    };
    Expr dyn = store_.Extend<Expr>([&]() {
        store_.Invalidate();
        return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
              std::vector<PStatic> pv;
              for (const auto& v : func->params) {
                pv.push_back(NoStatic(v));
              }
              tvm::Array<Type> type_args;
              for (const auto& tp : func->type_params) {
                type_args.push_back(tp);
              }
              return f(pv, Attrs(), type_args, ll)->dynamic;
            }), func->ret_type, func->type_params, func->attrs);
      });
    return HasStatic(MkSFunc(f), ll->Push(dyn));
  }

  Expr Reflect(const PStatic& st) {
    if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
      return ConstantNode::make(op->data);
    } else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
      tvm::Array<Expr> fields;
      for (const PStatic& field : op->fields) {
        fields.push_back(Reflect(field));
      }
      return TupleNode::make(fields);
    } else {
      LOG(FATAL) << "Unknown case";
      throw;
    }
  }

  PStatic Reify(const Value& v, LetList* ll) const {
    if (const TensorValueNode* op = v.as<TensorValueNode>()) {
      return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data)));
    } else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
      std::vector<PStatic> fields;
      tvm::Array<Expr> fields_dyn;
      for (const Value& field : op->fields) {
        PStatic ps = Reify(field, ll);
        fields.push_back(ps);
        fields_dyn.push_back(ps->dynamic);
      }
      return HasStatic(MkSTuple(fields), ll->Push(TupleNode::make(fields_dyn)));
    } else {
      LOG(FATAL) << "Unknown case";
      throw;
    }
  }

  // Constant evaluate a expression.
  PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
    Expr infered = InferType(expr, Module(nullptr));
    Expr fused = FuseOps(infered, 0, Module(nullptr));
    Expr fused_infered = InferType(fused, Module(nullptr));
    return Reify(executor_(fused_infered), ll);
  }

  Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
    return [=](const std::vector<PStatic>& pv,
               const Attrs& attrs,
               const tvm::Array<Type>& type_args,
               LetList* ll) {
      tvm::Array<Expr> ns_args;
      for (const PStatic& ps : pv) {
        ns_args.push_back(ps->dynamic);
      }
      PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
      if (StatefulOp(expr)) {
        return ns;
      }
      tvm::Array<Expr> args;
      for (const PStatic& ps : pv) {
        if (ps->pstatic.defined()) {
          args.push_back(Reflect(ps));
        } else {
          return ns;
        }
      }
      return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
    };
  }

  PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
    return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
  }

  PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
    Constructor c = GetRef<Constructor>(op);
    Func f = [=](const std::vector<PStatic>& pv,
                 const Attrs& attrs,
                 const tvm::Array<Type>& type_args,
                 LetList* ll) {
      tvm::Array<Expr> dyn;
      for (const PStatic& ps : pv) {
        dyn.push_back(ps->dynamic);
      }
      return HasStatic(MkSConstructor(c, pv), ll->Push(CallNode::make(c, dyn)));
    };
    return HasStatic(MkSFunc(f), GetRef<Expr>(op));
  }

  PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
    PStatic ps = VisitExpr(op->data, ll);
    return env_.Extend<PStatic>([&]() {
        for (const Clause& c : op->clauses) {
          switch (VisitPattern(c->lhs, ps)) {
          case MatchStatus::Match:
            return VisitExpr(c->rhs, ll);
          case MatchStatus::NoMatch:
            continue;
          case MatchStatus::Unknown:
            tvm::Array<Clause> clauses;
            for (const Clause& c : op->clauses) {
              Expr expr = store_.Extend<Expr>([&]() {
                  return LetList::With([&](LetList* ll) {
                      for (const Var& v : BoundVars(c->lhs)) {
                        env_.Insert(v, NoStatic(v));
                      }
                      return VisitExpr(c->rhs, ll)->dynamic;
                    });
                });
              clauses.push_back(ClauseNode::make(c->lhs, expr));
            }
            store_.Invalidate();
            return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses)));
          }
        }
        LOG(FATAL) << "No case Match";
        throw;
      });
  }

  MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
    return MatchStatus::Match;
  }

  MatchStatus VisitPattern_(const PatternVarNode* op, const PStatic& ps) final {
    env_.Insert(op->var, ps);
    return MatchStatus::Match;
  }

  MatchStatus VisitPattern_(const PatternConstructorNode* op, const PStatic& ps) final {
    if (ps->pstatic.defined()) {
      SConstructor scn = Downcast<SConstructor>(ps->pstatic);
      CHECK_NE(op->constructor->tag, -1);
      CHECK_NE(scn->constructor->tag, -1);
      if (op->constructor->tag == scn->constructor->tag) {
        // todo(M.K.): should use ptr equality but it is broken
        CHECK_EQ(op->patterns.size(), scn->fields.size());
        MatchStatus current_match_status = MatchStatus::Match;
        for (size_t i = 0; i < op->patterns.size(); ++i) {
          MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]);
          switch (ms) {
          case MatchStatus::Match:
            continue;
          case MatchStatus::NoMatch:
            return MatchStatus::NoMatch;
          case MatchStatus::Unknown:
            current_match_status = MatchStatus::Unknown;
          }
        }
        return current_match_status;
      }
      return MatchStatus::NoMatch;
    } else {
      return MatchStatus::Unknown;
    }
  }

 private:
  Environment env_;
  Store store_;
  DLContext context_ = CPUContext();
  FInterpreter executor_ = CPUInterpreter();
};

Var DeDupVar(const Var& v) {
  return VarNode::make(v->name_hint(), v->type_annotation);
}

TypeVar DeDupTypeVar(const TypeVar& tv) {
  return TypeVarNode::make(tv->var->name_hint, tv->kind);
}

/*! \brief Use a fresh Id for every Var to make the result well-formed. */
Expr DeDup(const Expr& e) {
  class DeDupMutator : public ExprMutator, public PatternMutator {
   public:
    Var Fresh(const Var& v) {
      Var ret = DeDupVar(v);
      rename_[v] = ret;
      return ret;
    }

    Expr VisitExpr(const Expr& e) final {
      return ExprMutator::VisitExpr(e);
    }

    Expr VisitExpr_(const VarNode* op) final {
      Var v = GetRef<Var>(op);
      return rename_.count(v) != 0 ? rename_.at(v) : v;
    }

    Expr VisitExpr_(const LetNode* op) final {
      return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
    }

    Expr VisitExpr_(const FunctionNode* op) final {
      tvm::Array<Var> params;
      for (const Var& param : op->params) {
        params.push_back(Fresh(param));
      }
      return FunctionNode::make(params,
                                VisitExpr(op->body),
                                op->ret_type,
                                op->type_params,
                                op->attrs);
    }

    Pattern VisitPattern(const Pattern& p) final {
      return PatternMutator::VisitPattern(p);
    }

    Var VisitVar(const Var& v) final {
      return Fresh(v);
    }

   private:
    std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
  };
  return DeDupMutator().VisitExpr(e);
}

/*! \brief Remap multiple Var sharing the same Id into the same Var. */
Expr Remap(const Expr& e) {
  class RemapMutator : public ExprMutator, public PatternMutator {
    Expr VisitExpr_(const VarNode* op) final {
      Var v = GetRef<Var>(op);
      if (remap_.count(v) == 0) {
        remap_.insert({v, v});
      }
      return remap_.at(v);
    }

    Var VisitVar(const Var& v) final {
      return Downcast<Var>(VisitExpr(v));
    }

   private:
    std::unordered_map<Var, Var, VarHash, VarEqual> remap_;
  };
  return RemapMutator().VisitExpr(e);
}

Expr PartialEval(const Expr& e) {
  return TransformF([&](const Expr& e) {
      return LetList::With([&](LetList* ll) {
          PartialEvaluator pe(FreeVars(e));
          return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
        });
    }, e);
}

TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
.set_body_typed(PartialEval);

namespace transform {

Pass PartialEval() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
    return Downcast<Function>(PartialEval(f));
  };
  return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
}

TVM_REGISTER_API("relay._transform.PartialEvaluate")
.set_body_typed(PartialEval);

}  // namespace transform

}  // namespace relay
}  // namespace tvm