ir_functor.h 8.86 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file tvm/node/ir_functor.h
 * \brief Defines the IRFunctor data structures.
 */
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_

#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <type_traits>
#include <utility>
#include <functional>
#include "node.h"

namespace tvm {
/*!
37
 * \brief A dynamically dispatched functor on ObjectRef in the first argument.
38 39
 *
 * \code
40
 *   IRFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
 *   tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
 *     return prefix + "Add";
 *   });
 *   tostr.set_dispatch<IntImm>([](const IntImm* op) {
 *     return prefix + "IntImm"
 *   });
 *
 *   Expr x = make_const(1);
 *   Expr y = x + x;
 *   // dispatch to IntImm, outputs "MyIntImm"
 *   LOG(INFO) << tostr(x, "My");
 *   // dispatch to IntImm, outputs "MyAdd"
 *   LOG(INFO) << tostr(y, "My");
 * \endcode
 *
 * \tparam FType function signiture
 *  This type if only defined for FType with function signature
 */
template<typename FType>
class IRFunctor;

template<typename R, typename ...Args>
63
class IRFunctor<R(const ObjectRef& n, Args...)> {
64
 private:
65 66
  using Function = std::function<R (const ObjectRef&n, Args...)>;
  using TSelf = IRFunctor<R (const ObjectRef& n, Args...)>;
67 68 69 70 71 72 73 74 75 76 77
  /*! \brief internal function table */
  std::vector<Function> func_;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*!
   * \brief Whether the functor can dispatch the corresponding Node
   * \param n The node to be dispatched
   * \return Whether dispatching function is registered for n's type.
   */
78 79
  inline bool can_dispatch(const ObjectRef& n) const {
    uint32_t type_index = n->type_index();
80 81 82 83 84 85 86 87
    return type_index < func_.size() && func_[type_index] != nullptr;
  }
  /*!
   * \brief invoke the functor , dispatch on type of n
   * \param n The Node argument
   * \param args The additional arguments
   * \return The result.
   */
88 89
  inline R operator()(const ObjectRef& n, Args... args) const {
    uint32_t type_index = n->type_index();
90 91 92
    CHECK(type_index < func_.size() &&
          func_[type_index] != nullptr)
        << "IRFunctor calls un-registered function on type "
93
        << n->GetTypeKey();
94 95 96 97 98 99 100 101 102 103
    return func_[type_index](n, std::forward<Args>(args)...);
  }
  /*!
   * \brief set the dispacher for type TNode
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& set_dispatch(Function f) {  // NOLINT(*)
104
    uint32_t tindex = TNode::RuntimeTypeIndex();
105 106 107 108
    if (func_.size() <= tindex) {
      func_.resize(tindex + 1, nullptr);
    }
    CHECK(func_[tindex] == nullptr)
109
        << "Dispatch for " << TNode::_type_key
110 111 112 113 114 115
        << " is already set";
    func_[tindex] = f;
    return *this;
  }
  /*!
   * \brief set the dispacher for type TNode
116
   *  This allows f to used detailed const Node pointer to replace ObjectRef
117 118 119 120 121 122 123
   *
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
124 125
    Function fun = [f](const ObjectRef& n, Args... args) {
      return f(static_cast<const TNode*>(n.get()),
126 127 128 129 130 131 132 133 134 135 136 137
               std::forward<Args>(args)...);
    };
    return this->set_dispatch<TNode>(fun);
  }
  /*!
  * \brief unset the dispacher for type TNode
  *
  * \tparam TNode the type of Node to be dispatched.
  * \return reference to self.
  */
  template<typename TNode>
  inline TSelf& clear_dispatch() {  // NOLINT(*)
138
    uint32_t tindex = TNode::RuntimeTypeIndex();
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
    CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
    func_[tindex] = nullptr;
    return *this;
  }
};

/*! \brief Marks a declaration as possibly unused (GCC-compatible compilers only). */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif

/*!
 * \brief helper macro to paste two tokens together.
 *  The two-level indirection makes the preprocessor macro-expand both
 *  arguments (e.g. __COUNTER__) before the ## paste is performed.
 */
#define TVM_STR_CONCAT_(x, y) x##y
#define TVM_STR_CONCAT(x, y) TVM_STR_CONCAT_(x, y)

/*!
 * \brief Declares a static reference variable named tvm_make_functor_<ClsName>.
 *  Meant to be passed through TVM_STR_CONCAT so a counter is pasted onto the
 *  variable name, making each registration's name unique.
 */
#define TVM_REGISTER_VAR_DEF(ClsName)                                 \
  static TVM_ATTRIBUTE_UNUSED auto & tvm_make_functor_ ## ClsName

/*!
 * \brief Useful macro to set IRFunctor dispatch in a global static field.
 *
 * \code
 *  // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
 *  // vtable allows easy patch in of new Node types, without changing
 *  // interface of IRPrinter.
 *
 *  class IRPrinter {
 *   public:
 *    std::ostream& stream;
 *    // the dispatch function.
 *    void print(Expr e) {
 *      const static FType& f = vtable();
 *      f(e, this);
 *    }
 *
 *    using FType = IRFunctor<void (const ObjectRef&, IRPrinter *)>;
 *    // function to return global function table
 *    static FType& vtable();
 *  };
 *
 *  // in cpp/cc file
 *  IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
 *    static FType inst; return inst;
 *  }
 *
 *  TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 *  .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
 *    p->print(n->a);
 *    p->stream << '+';
 *    p->print(n->b);
 *  });
 * \endcode
 *
 * The macro declares a `static auto&` variable, whose name is made unique with
 * __COUNTER__, bound to ClsName::FField(); any chained .set_dispatch calls are
 * evaluated as part of that variable's initializer.
 *
 * \param ClsName The name of the class
 * \param FField The static function that returns a singleton of IRFunctor.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                       \
  TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__)  =      \
                              ClsName::FField()

 /*!
 * \brief A container for a list of callbacks. All callbacks are invoked when
 * the object is destructed.
 */
class IRFunctorCleanList {
 public:
  ~IRFunctorCleanList() {
    for (auto &f : clean_items) {
      f();
    }
  }

  void append(std::function<void()> func) {
    clean_items.push_back(func);
  }

 private:
  std::vector< std::function<void()> > clean_items;
};

/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;

template<typename R, typename ...Args>
235
class IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)> {
236
 private:
237
  IRFunctor<R(const ObjectRef& n, Args...)> *irf_;
238 239
  std::shared_ptr<IRFunctorCleanList> free_list;

240
  using TSelf = IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)>;
241 242

 public:
243
  IRFunctorStaticRegistry(IRFunctor<R(const ObjectRef& n, Args...)> *irf) {
244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
    irf_ = irf;
    free_list = std::make_shared<IRFunctorCleanList>();
  }

  template<typename TNode>
  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) {  // NOLINT(*)
    irf_->template set_dispatch<TNode>(f);
    auto irf_copy = irf_;
    free_list.get()->append([irf_copy] {
      irf_copy->template clear_dispatch<TNode>();
      });
    return *this;
  }
};

/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
264 265 266
IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)> MakeIRFunctorStaticRegistry(
  IRFunctor<R(const ObjectRef& n, Args...)> *irf) {
  return IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)>(irf);
267 268
}

/*!
 * \brief Declares a static by-value variable named tvm_make_functor_<ClsName>.
 *  By value (not by reference) so the variable owns a registry copy, keeping the
 *  shared clean-up list alive until the variable is destroyed. Meant to be
 *  passed through TVM_STR_CONCAT so a counter is pasted onto the name.
 */
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName)                        \
  static TVM_ATTRIBUTE_UNUSED auto tvm_make_functor_ ## ClsName

/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage mirrors TVM_STATIC_IR_FUNCTOR: chain .set_dispatch<TNode>(...) calls after it.
* Libraries should use this instead of TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField)                  \
  TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__)  = \
                        ::tvm::MakeIRFunctorStaticRegistry(&ClsName::FField())

}  // namespace tvm
#endif  // TVM_NODE_IR_FUNCTOR_H_