registry.h 10.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/runtime/registry.h
 * \brief This file defines the TVM global function registry.
 *
 *  The registered functions will be made available to front-end
 *  as well as backend users.
 *
 *  The registry stores type-erased functions.
 *  Each registered function is automatically exposed
 *  to front-end languages (e.g. Python).
 *
 *  Front-end can also pass callbacks as PackedFunc, or register
 *  them into the same global registry in C++.
 *  The goal is to mix the front-end language and the TVM back-end.
 *
 * \code
 *   // register the function as MyAPIFuncName
 *   TVM_REGISTER_GLOBAL(MyAPIFuncName)
 *   .set_body([](TVMArgs args, TVMRetValue* rv) {
 *     // my code.
 *   });
 * \endcode
 */
#ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_

#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "packed_func.h"

namespace tvm {
namespace runtime {

/*! \brief Registry for global function */
class Registry {
 public:
  /*!
   * \brief set the body of the function to be f
   * \param f The body of the function.
   */
60
  TVM_DLL Registry& set_body(PackedFunc f);  // NOLINT(*)
61 62 63 64 65 66 67 68
  /*!
   * \brief set the body of the function to be f
   * \param f The body of the function.
   */
  Registry& set_body(PackedFunc::FType f) {  // NOLINT(*)
    return set_body(PackedFunc(f));
  }
  /*!
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
   * \brief set the body of the function to be TypedPackedFunc.
   *
   * \code
   *
   * TVM_REGISTER_API("addone")
   * .set_body_typed<int(int)>([](int x) { return x + 1; });
   *
   * \endcode
   *
   * \param f The body of the function.
   * \tparam FType the signature of the function.
   * \tparam FLambda The type of f.
   */
  template<typename FType, typename FLambda>
  Registry& set_body_typed(FLambda f) {
    return set_body(TypedPackedFunc<FType>(f).packed());
  }
86 87 88 89 90 91 92 93

  /*!
   * \brief set the body of the function to the given function pointer.
   *        Note that this doesn't work with lambdas, you need to
   *        explicitly give a type for those.
   *        Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
94
   *
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
   * int multiply(int x, int y) {
   *   return x * y;
   * }
   *
   * TVM_REGISTER_API("multiply")
   * .set_body_typed(multiply); // will have type int(int, int)
   *
   * \endcode
   *
   * \param f The function to forward to.
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
  template<typename R, typename ...Args>
  Registry& set_body_typed(R (*f)(Args...)) {
    return set_body(TypedPackedFunc<R(Args...)>(f));
  }

  /*!
   * \brief set the body of the function to be the passed method pointer.
   *        Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
118
   *
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
   * // node subclass:
   * struct Example {
   *    int doThing(int x);
   * }
   * TVM_REGISTER_API("Example_doThing")
   * .set_body_method(&Example::doThing); // will have type int(Example, int)
   *
   * \endcode
   *
   * \param f the method pointer to forward to.
   * \tparam T the type containing the method (inferred).
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
  template<typename T, typename R, typename ...Args>
  Registry& set_body_method(R (T::*f)(Args...)) {
    return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
      // call method pointer
      return (target.*f)(params...);
    });
  }

  /*!
   * \brief set the body of the function to be the passed method pointer.
   *        Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
146
   *
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
   * // node subclass:
   * struct Example {
   *    int doThing(int x);
   * }
   * TVM_REGISTER_API("Example_doThing")
   * .set_body_method(&Example::doThing); // will have type int(Example, int)
   *
   * \endcode
   *
   * \param f the method pointer to forward to.
   * \tparam T the type containing the method (inferred).
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
  template<typename T, typename R, typename ...Args>
  Registry& set_body_method(R (T::*f)(Args...) const) {
    return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
      // call method pointer
      return (target.*f)(params...);
    });
  }

  /*!
   * \brief set the body of the function to be the passed method pointer.
171
   *        Used when calling a method on a Node subclass through a ObjectRef subclass.
172 173 174
   *        Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
175
   *
176 177 178 179
   * // node subclass:
   * struct ExampleNode: BaseNode {
   *    int doThing(int x);
   * }
180
   *
181
   * // noderef subclass
182
   * struct Example;
183 184 185
   *
   * TVM_REGISTER_API("Example_doThing")
   * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
186
   *
187 188 189 190 191 192 193
   * // note that just doing:
   * // .set_body_method(&ExampleNode::doThing);
   * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
   *
   * \endcode
   *
   * \param f the method pointer to forward to.
194
   * \tparam TObjectRef the node reference type to call the method on
195 196 197 198
   * \tparam TNode the node type containing the method (inferred).
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
199 200
  template<typename TObjectRef, typename TNode, typename R, typename ...Args,
    typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
201
  Registry& set_body_method(R (TNode::*f)(Args...)) {
202
    return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
203 204 205 206 207 208 209 210
      TNode* target = ref.operator->();
      // call method pointer
      return (target->*f)(params...);
    });
  }

  /*!
   * \brief set the body of the function to be the passed method pointer.
211
   *        Used when calling a method on a Node subclass through a ObjectRef subclass.
212 213 214
   *        Note that this will ignore default arg values and always require all arguments to be provided.
   *
   * \code
215
   *
216 217 218 219
   * // node subclass:
   * struct ExampleNode: BaseNode {
   *    int doThing(int x);
   * }
220
   *
221
   * // noderef subclass
222
   * struct Example;
223 224 225
   *
   * TVM_REGISTER_API("Example_doThing")
   * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
226
   *
227 228 229 230 231 232 233
   * // note that just doing:
   * // .set_body_method(&ExampleNode::doThing);
   * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
   *
   * \endcode
   *
   * \param f the method pointer to forward to.
234
   * \tparam TObjectRef the node reference type to call the method on
235 236 237 238
   * \tparam TNode the node type containing the method (inferred).
   * \tparam R the return type of the function (inferred).
   * \tparam Args the argument types of the function (inferred).
   */
239 240
  template<typename TObjectRef, typename TNode, typename R, typename ...Args,
    typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
241
  Registry& set_body_method(R (TNode::*f)(Args...) const) {
242
    return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
243 244 245 246 247 248
      const TNode* target = ref.operator->();
      // call method pointer
      return (target->*f)(params...);
    });
  }

249
  /*!
250 251
   * \brief Register a function with given name
   * \param name The name of the function.
252 253
   * \param override Whether allow oveeride existing function.
   * \return Reference to theregistry.
254
   */
255
  TVM_DLL static Registry& Register(const std::string& name, bool override = false);  // NOLINT(*)
256 257 258 259 260
  /*!
   * \brief Erase global function from registry, if exist.
   * \param name The name of the function.
   * \return Whether function exist.
   */
261
  TVM_DLL static bool Remove(const std::string& name);
262 263 264 265 266 267
  /*!
   * \brief Get the global function by name.
   * \param name The name of the function.
   * \return pointer to the registered function,
   *   nullptr if it does not exist.
   */
268
  TVM_DLL static const PackedFunc* Get(const std::string& name);  // NOLINT(*)
269 270 271 272
  /*!
   * \brief Get the names of currently registered global function.
   * \return The names
   */
273
  TVM_DLL static std::vector<std::string> ListNames();
274

275 276 277
  // Internal class.
  struct Manager;

nhynes committed
278
 protected:
279 280 281 282 283 284 285 286 287 288 289 290 291 292
  /*! \brief name of the function */
  std::string name_;
  /*! \brief internal packed function */
  PackedFunc func_;
  friend struct Manager;
};

/*!
 * \brief Helper macro to suppress "unused variable" warnings.
 *  Expands to __attribute__((unused)) on GCC-compatible compilers
 *  (those defining __GNUC__) and to nothing elsewhere.
 */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif

/*!
 * \brief Token-paste two macro arguments.
 *  TVM_STR_CONCAT forwards to TVM_STR_CONCAT_ so that its arguments
 *  (e.g. another macro, or __COUNTER__) are fully macro-expanded before
 *  the ## paste happens; pasting directly would glue the unexpanded names.
 */
#define TVM_STR_CONCAT_(lhs, rhs) lhs##rhs
#define TVM_STR_CONCAT(lhs, rhs) TVM_STR_CONCAT_(lhs, rhs)

/*!
 * \brief Declaration prefix for the static Registry reference created by
 *  TVM_REGISTER_GLOBAL. The identifier ends in `__mk_TVM` and is meant to be
 *  completed with a unique suffix via TVM_STR_CONCAT.
 */
#define TVM_FUNC_REG_VAR_DEF                                            \
  static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM

/*!
 * \brief Declaration prefix for the static ExtTypeVTable pointer created by
 *  TVM_REGISTER_EXT_TYPE. The identifier ends in `__mk_TVMT` and is meant to
 *  be completed with a unique suffix via TVM_STR_CONCAT.
 */
#define TVM_TYPE_REG_VAR_DEF                                            \
  static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT

/*!
 * \brief Register a function globally.
 *
 *  Each use declares a uniquely named `static` Registry reference
 *  (prefix from TVM_FUNC_REG_VAR_DEF, suffix from __COUNTER__) and
 *  initializes it with ::tvm::runtime::Registry::Register(OpName), so
 *  setters such as set_body can be chained directly after the macro.
 *
 * \code
 *   TVM_REGISTER_GLOBAL("MyPrint")
 *   .set_body([](TVMArgs args, TVMRetValue* rv) {
 *   });
 * \endcode
 */
#define TVM_REGISTER_GLOBAL(OpName)                              \
  TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) =            \
      ::tvm::runtime::Registry::Register(OpName)

/*!
 * \brief Macro to register an extension type.
 *  This must be used in a .cc file, after the trait
 *  extension_type_info is defined for T.
 *
 *  Each use declares a uniquely named `static` ExtTypeVTable pointer
 *  (prefix from TVM_TYPE_REG_VAR_DEF, suffix from __COUNTER__) initialized
 *  with ::tvm::runtime::ExtTypeVTable::Register_<T>().
 */
#define TVM_REGISTER_EXT_TYPE(T)                                 \
  TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) =            \
      ::tvm::runtime::ExtTypeVTable::Register_<T>()

}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_REGISTRY_H_