let_list.h 3.24 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
/*!
 *  Copyright (c) 2018 by Contributors
 * \file let_list.h
 * \brief LetList records let bindings and inserts let expressions implicitly.
 *  Using it, one can treat an AST as a value instead of an expression,
 *  and pass it around freely without fear of AST explosion (or effect duplication).
 *  For example, if one writes 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 *  If one instead writes 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 *  the AST will contain only 2 'a', as b and c are now variables.
 */
#ifndef TVM_RELAY_PASS_LET_LIST_H_
#define TVM_RELAY_PASS_LET_LIST_H_

#include <tvm/relay/expr.h>
#include <utility>
#include <vector>
#include <tuple>
#include "tvm/relay/type.h"

namespace tvm {
namespace relay {

/*! \brief LetList allow you to transform expression into variables, so you can copy them around.
 *  one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
 *  additionally, there is the 'With' function, which automatically call Get.
 */
class LetList {
 public:
29 30
  /*!
   * \brief insert a binding.
31
   *
32
   * \param pv the var of the binding.
33
   *
34
   * \param expr the value of the binding.
35
   *
36
   * \return a Var that hold the inserted expr.
37
   */
38
  Var Push(Var pv, Expr expr) {
39
    CHECK(!used_);
40
    lets_.emplace_back(std::make_pair(pv, expr));
41 42 43
    return pv;
  }

44 45
  /*!
   * \brief insert a binding.
46
   *
47
   * \param ty the type of the binding.
48
   *
49
   * \param expr the value of the binding.
50
   *
51
   * \return a Var that hold the inserted expr.
52
   */
53 54
  Var Push(Type ty, Expr expr) {
    return Push(VarNode::make("x", ty), expr);
55 56
  }

57 58
  /*!
   * \brief insert a binding.
59 60 61 62 63
   *
   *  \param expr the value of the binding.
   *
   *  \return a Var that hold the inserted expr.
   */
64
  Var Push(Expr expr) {
65
    return Push(IncompleteTypeNode::make(Kind::kType), expr);
66 67
  }

68 69
  /*!
   * \brief wrap an expr around the LetList.
70 71 72 73 74
   *
   *  \param body the Expression to be wrapped around.
   *
   *  \return the wrapped expr.
   */
75 76
  Expr Get(const Expr& body) {
    CHECK(!used_);
77 78
    Expr ret = body;
    for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
79
      ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
80
    }
81
    used_ = true;
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
    return ret;
  }

  /*! \brief generate an LetList and wrap the result automatically.
   *
   *  \param f a function that generate the unwrapped Expr.
   *
   *  \code
   *  // Example code that generate `16 * a` using 4 plus instead of 15 plus.
   *  Expr mult_sixteen(const Var& a) {
   *    Op plus = Op::Get("plus");
   *    // Automatically call Get with LetList::With
   *    return LetList::With([&](LetList* ll) {
   *      // Turn a call to plus into a variable to avoid duplication of code
   *      Var b = ll->Push(CallNode::make(plus, {a, a}));
   *      Var c = ll->Push(CallNode::make(plus, {b, b}));
   *      Var d = ll->Push(CallNode::make(plus, {c, c}));
   *      return CallNode::make(plus, {d, d});
   *    });
   *  }
   *  \endcode
   *
   *  \return the wrapped Expr.
   */
  template<typename F>
  static Expr With(F&& f) {
    LetList ll;
    return ll.Get(f(&ll));
  }

 private:
113
  std::vector<std::pair<Var, Expr> > lets_;
114
  bool used_ = false;
115 116 117 118 119
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_PASS_LET_LIST_H_