/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file let_list.h
 * \brief LetList records let bindings and inserts let expressions implicitly.
 *  Using it, one can treat an AST as a value instead of an expression,
 *  and pass it around freely without fear of AST explosion (or effect duplication).
 *  For example, if one writes 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 *  If one instead writes 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 *  the AST will contain 2 'a', as b and c are now variables.
 */

#ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_
#define TVM_RELAY_TRANSFORMS_LET_LIST_H_
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>

#include <functional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "tvm/relay/type.h"

namespace tvm {
namespace relay {

43 44
/*!
 * \brief LetList allow you to transform expression into variables, so you can copy them around.
45 46 47 48 49
 *  one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
 *  additionally, there is the 'With' function, which automatically call Get.
 */
class LetList {
 public:
50 51
  ~LetList() {
    if (lets_.size() > 0 && !used_) {
52
      LOG(WARNING) << "letlist not used";
53 54
    }
  }
55 56
  /*!
   * \brief insert a binding.
57
   *
58
   * \param pv the var of the binding.
59
   *
60
   * \param expr the value of the binding.
61
   *
62
   * \return a Var that hold the inserted expr.
63
   */
64
  Var Push(Var pv, Expr expr) {
65
    CHECK(!used_);
66
    CHECK(WellFormed(expr));
67
    lets_.emplace_back(std::make_pair(pv, expr));
68 69 70
    return pv;
  }

71 72
  /*!
   * \brief insert a binding.
73
   *
74
   * \param expr the value of the binding.
75
   *
76 77
   * \param ty the type of the binding.
   *
78
   * \return a Var that hold the inserted expr.
79
   */
80
  Var Push(Expr expr, Type ty) {
81
    return Push(Var("x", ty), expr);
82 83
  }

84 85
  /*!
   * \brief insert a binding.
86 87 88 89 90
   *
   *  \param expr the value of the binding.
   *
   *  \return a Var that hold the inserted expr.
   */
91
  Var Push(Expr expr) {
92
    return Push(expr, Type());
93 94
  }

95 96
  /*!
   * \brief wrap an expr around the LetList.
97 98 99 100 101
   *
   *  \param body the Expression to be wrapped around.
   *
   *  \return the wrapped expr.
   */
102 103
  Expr Get(const Expr& body) {
    CHECK(!used_);
104 105
    Expr ret = body;
    for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
106
      ret = Let(std::get<0>(*rit), std::get<1>(*rit), ret);
107
    }
108
    used_ = true;
109 110 111 112 113 114 115 116 117 118 119 120 121 122
    return ret;
  }

  /*! \brief generate an LetList and wrap the result automatically.
   *
   *  \param f a function that generate the unwrapped Expr.
   *
   *  \code
   *  // Example code that generate `16 * a` using 4 plus instead of 15 plus.
   *  Expr mult_sixteen(const Var& a) {
   *    Op plus = Op::Get("plus");
   *    // Automatically call Get with LetList::With
   *    return LetList::With([&](LetList* ll) {
   *      // Turn a call to plus into a variable to avoid duplication of code
123 124 125 126
   *      Var b = ll->Push(Call(plus, {a, a}));
   *      Var c = ll->Push(Call(plus, {b, b}));
   *      Var d = ll->Push(Callplus, {c, c}));
   *      return Call(plus, {d, d});
127 128 129 130 131 132 133 134 135 136 137 138
   *    });
   *  }
   *  \endcode
   *
   *  \return the wrapped Expr.
   */
  template<typename F>
  static Expr With(F&& f) {
    LetList ll;
    return ll.Get(f(&ll));
  }

139
  static Expr LetBind(const Expr& e, const std::function<Expr(const Var&)>& f) {
140 141 142 143 144
    return With([&](LetList* ll) {
      return f(ll->Push(e));
    });
  }

145
 private:
146
  std::vector<std::pair<Var, Expr> > lets_;
147
  bool used_ = false;
148 149 150 151
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_TRANSFORMS_LET_LIST_H_