/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file let_list.h
 * \brief LetList record let binding and insert let expression implicitly.
 *  using it, one can treat AST as value instead of expression,
 *  and pass them around freely without fear of AST explosion (or effect duplication).
 *  for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 *  if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 *  the AST will contain 2 'a', as b and c are now variables.
 */
#ifndef TVM_RELAY_PASS_LET_LIST_H_
#define TVM_RELAY_PASS_LET_LIST_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <utility>
#include <vector>
#include <tuple>
#include <string>
#include "tvm/relay/type.h"

namespace tvm {
namespace relay {

/*!
 * \brief LetList allow you to transform expression into variables, so you can copy them around.
 *  one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
 *  additionally, there is the 'With' function, which automatically call Get.
 */
class LetList {
 public:
  ~LetList() {
    if (lets_.size() > 0 && !used_) {
      LOG(WARNING) << "letlist not used";
    }
  }
  /*!
   * \brief insert a binding.
   *
   * \param pv the var of the binding.
   *
   * \param expr the value of the binding.
   *
   * \return a Var that hold the inserted expr.
   */
  Var Push(Var pv, Expr expr) {
    CHECK(!used_);
    CHECK(WellFormed(expr));
    lets_.emplace_back(std::make_pair(pv, expr));
    return pv;
  }

  /*!
   * \brief insert a binding.
   *
   * \param expr the value of the binding.
   *
   * \param ty the type of the binding.
   *
   * \return a Var that hold the inserted expr.
   */
  Var Push(Expr expr, Type ty) {
    return Push(VarNode::make("x", ty), expr);
  }

  /*!
   * \brief insert a binding.
   *
   *  \param expr the value of the binding.
   *
   *  \return a Var that hold the inserted expr.
   */
  Var Push(Expr expr) {
    return Push(expr, Type());
  }

  /*!
   * \brief wrap an expr around the LetList.
   *
   *  \param body the Expression to be wrapped around.
   *
   *  \return the wrapped expr.
   */
  Expr Get(const Expr& body) {
    CHECK(!used_);
    Expr ret = body;
    for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
      ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
    }
    used_ = true;
    return ret;
  }

  /*! \brief generate an LetList and wrap the result automatically.
   *
   *  \param f a function that generate the unwrapped Expr.
   *
   *  \code
   *  // Example code that generate `16 * a` using 4 plus instead of 15 plus.
   *  Expr mult_sixteen(const Var& a) {
   *    Op plus = Op::Get("plus");
   *    // Automatically call Get with LetList::With
   *    return LetList::With([&](LetList* ll) {
   *      // Turn a call to plus into a variable to avoid duplication of code
   *      Var b = ll->Push(CallNode::make(plus, {a, a}));
   *      Var c = ll->Push(CallNode::make(plus, {b, b}));
   *      Var d = ll->Push(CallNode::make(plus, {c, c}));
   *      return CallNode::make(plus, {d, d});
   *    });
   *  }
   *  \endcode
   *
   *  \return the wrapped Expr.
   */
  template<typename F>
  static Expr With(F&& f) {
    LetList ll;
    return ll.Get(f(&ll));
  }

  static Expr Let(const Expr& e, const std::function<Expr(const Var&)>& f) {
    return With([&](LetList* ll) {
      return f(ll->Push(e));
    });
  }

 private:
  std::vector<std::pair<Var, Expr> > lets_;
  bool used_ = false;
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_PASS_LET_LIST_H_