/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/ir/env_func.h
 * \brief Serializable global function used in IR.
 */
#ifndef TVM_IR_ENV_FUNC_H_
#define TVM_IR_ENV_FUNC_H_

#include <tvm/node/reflection.h>

#include <string>
#include <utility>

namespace tvm {
/*!
34 35 36 37 38
 * \brief A serializable function backed by TVM's global environment.
 *
 * This is a wrapper to enable serializable global PackedFunc.
 * An EnvFunc is saved by its name in the global registry
 * under the assumption that the same function is registered during load.
39 40
 * \sa EnvFunc
 */
41
class EnvFuncNode : public Object {
42 43 44 45
 public:
  /*! \brief Unique name of the global function */
  std::string name;
  /*! \brief The internal packed function */
46
  runtime::PackedFunc func;
47 48 49
  /*! \brief constructor */
  EnvFuncNode() {}

50
  void VisitAttrs(AttrVisitor* v) {
51 52 53
    v->Visit("name", &name);
  }

54
  bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
55 56 57 58 59 60 61
    // name uniquely identifies the env function.
    return name == other->name;
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    // Name uniquely identifies the env function.
    hash_reduce(name);
62 63
  }

64
  static constexpr const char* _type_key = "EnvFunc";
65
  static constexpr bool _type_has_method_sequal_reduce = true;
66
  static constexpr bool _type_has_method_shash_reduce = true;
67
  TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
68 69 70
};

/*!
71 72
 * \brief Managed reference to EnvFuncNode.
 * \sa EnvFuncNode
73
 */
74
class EnvFunc : public ObjectRef {
75 76
 public:
  EnvFunc() {}
77
  explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
78 79
  /*! \return The internal global function pointer */
  const EnvFuncNode* operator->() const {
80
    return static_cast<const EnvFuncNode*>(get());
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
  }
  /*!
   * \brief Invoke the function.
   * \param args The arguments
   * \returns The return value.
   */
  template<typename... Args>
  runtime::TVMRetValue operator()(Args&&... args) const {
    const EnvFuncNode* n = operator->();
    CHECK(n != nullptr);
    return n->func(std::forward<Args>(args)...);
  }
  /*!
   * \brief Get a global function based on the name.
   * \param name The name of the global function.
   * \return The created global function.
   * \note The function can be unique
   */
  TVM_DLL static EnvFunc Get(const std::string& name);
  /*! \brief specify container node */
  using ContainerType = EnvFuncNode;
};

/*!
 * \brief Primary template of TypedEnvFunc; declared here without a definition.
 *
 * Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>",
 * the specialization for function types.
 *
 * \tparam FType The function signature type.
 */
template<typename FType>
class TypedEnvFunc;

/*!
 * \anchor TypedEnvFuncAnchor
 * \brief A typed version of EnvFunc.
 * It is backed by a GlobalFuncNode internally.
 *
 * \tparam R The return value of the function.
 * \tparam Args The argument signature of the function.
 * \sa EnvFunc
 */
template<typename R, typename... Args>
120
class TypedEnvFunc<R(Args...)> : public ObjectRef {
121 122 123 124
 public:
  /*! \brief short hand for this function type */
  using TSelf = TypedEnvFunc<R(Args...)>;
  TypedEnvFunc() {}
125
  explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
126 127 128 129 130 131
  /*!
   * \brief Assign global function to a TypedEnvFunc
   * \param other Another global function.
   * \return reference to self.
   */
  TSelf& operator=(const EnvFunc& other) {
132
    ObjectRef::operator=(other);
133 134 135 136
    return *this;
  }
  /*! \return The internal global function pointer */
  const EnvFuncNode* operator->() const {
137
    return static_cast<const EnvFuncNode*>(get());
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
  }
  /*!
   * \brief Invoke the function.
   * \param args The arguments
   * \returns The return value.
   */
  R operator()(Args... args) const {
    const EnvFuncNode* n = operator->();
    CHECK(n != nullptr);
    return runtime::detail::typed_packed_call_dispatcher<R>
        ::run(n->func, std::forward<Args>(args)...);
  }
  /*! \brief specify container node */
  using ContainerType = EnvFuncNode;
};

}  // namespace tvm
#endif  // TVM_IR_ENV_FUNC_H_