/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/runtime/module.h
 * \brief Runtime container of the functions generated by TVM.
 *  It is used to dynamically link, load, and save functions
 *  of different calling conventions under a unified API.
 */
#ifndef TVM_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_

#include <dmlc/io.h>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>

#include <memory>
#include <vector>
#include <string>
#include <unordered_map>

namespace tvm {
namespace runtime {

class ModuleNode;
class PackedFunc;

/*!
 * \brief Module container of TVM.
 */
49
class Module : public ObjectRef {
50 51 52
 public:
  Module() {}
  // constructor from container.
53 54
  explicit Module(ObjectPtr<Object> n)
      : ObjectRef(n) {}
55 56 57 58 59 60 61
  /*!
   * \brief Get packed function from current module by name.
   *
   * \param name The name of the function.
   * \param query_imports Whether also query dependency modules.
   * \return The result function.
   *  This function will return PackedFunc(nullptr) if function do not exist.
62
   * \note Implemented in packed_func.cc
63
   */
64 65
  inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
  // The following functions requires link with runtime.
66 67 68 69 70 71 72
  /*!
   * \brief Import another module into this module.
   * \param other The module to be imported.
   *
   * \note Cyclic dependency is not allowed among modules,
   *  An error will be thrown when cyclic dependency is detected.
   */
73 74 75 76 77
  inline void Import(Module other);
  /*! \return internal container */
  inline ModuleNode* operator->();
  /*! \return internal container */
  inline const ModuleNode* operator->() const;
78 79 80 81 82 83 84
  /*!
   * \brief Load a module from file.
   * \param file_name The name of the host function module.
   * \param format The format of the file.
   * \note This function won't load the import relationship.
   *  Re-create import relationship by calling Import.
   */
85 86
  TVM_DLL static Module LoadFromFile(const std::string& file_name,
                                     const std::string& format = "");
87 88 89
  // refer to the corresponding container.
  using ContainerType = ModuleNode;
  friend class ModuleNode;
90 91 92
};

/*!
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
 * \brief Base container of module.
 *
 * Please subclass ModuleNode to create a specific runtime module.
 *
 * \code
 *
 *  class MyModuleNode : public ModuleNode {
 *   public:
 *    // implement the interface
 *  };
 *
 *  // use make_object to create a specific
 *  // instace of MyModuleNode.
 *  Module CreateMyModule() {
 *    ObjectPtr<MyModuleNode> n =
 *      tvm::runtime::make_object<MyModuleNode>();
 *    return Module(n);
 *  }
 *
 * \endcode
113
 */
114
class TVM_DLL ModuleNode : public Object {
115 116
 public:
  /*! \brief virtual destructor */
117
  virtual ~ModuleNode() {}
118 119 120 121
  /*!
   * \return The per module type key.
   * \note This key is used to for serializing custom modules.
   */
122
  virtual const char* type_key() const = 0;
123 124 125 126 127 128 129 130 131
  /*!
   * \brief Get a PackedFunc from module.
   *
   *  The PackedFunc may not be fully initialized,
   *  there might still be first time running overhead when
   *  executing the function on certain devices.
   *  For benchmarking, use prepare to eliminate
   *
   * \param name the name of the function.
132
   * \param sptr_to_self The ObjectPtr that points to this module node.
133 134 135 136 137 138 139
   *
   * \return PackedFunc(nullptr) when it is not available.
   *
   * \note The function will always remain valid.
   *   If the function need resource from the module(e.g. late linking),
   *   it should capture sptr_to_self.
   */
140
  virtual PackedFunc GetFunction(
141
      const std::string& name,
142
      const ObjectPtr<Object>& sptr_to_self) = 0;
143 144 145 146 147
  /*!
   * \brief Save the module to file.
   * \param file_name The file to be saved to.
   * \param format The format of the file.
   */
148 149
  virtual void SaveToFile(const std::string& file_name,
                          const std::string& format);
150
  /*!
151 152 153 154 155 156
   * \brief Save the module to binary stream.
   * \param stream The binary stream to save to.
   * \note It is recommended to implement this for device modules,
   *   but not necessarily host modules.
   *   We can use this to do AOT loading of bundled device functions.
   */
157
  virtual void SaveToBinary(dmlc::Stream* stream);
158
  /*!
159 160 161 162
   * \brief Get the source code of module, when available.
   * \param format Format of the source code, can be empty by default.
   * \return Possible source code when available.
   */
163
  virtual std::string GetSource(const std::string& format = "");
164
  /*!
165 166 167 168 169 170 171 172
   * \brief Get packed function from current module by name.
   *
   * \param name The name of the function.
   * \param query_imports Whether also query dependency modules.
   * \return The result function.
   *  This function will return PackedFunc(nullptr) if function do not exist.
   * \note Implemented in packed_func.cc
   */
173
  PackedFunc GetFunction(const std::string& name, bool query_imports = false);
174 175 176 177 178 179 180
  /*!
   * \brief Import another module into this module.
   * \param other The module to be imported.
   *
   * \note Cyclic dependency is not allowed among modules,
   *  An error will be thrown when cyclic dependency is detected.
   */
181
  void Import(Module other);
182
  /*!
183 184 185 186 187 188
   * \brief Get a function from current environment
   *  The environment includes all the imports as well as Global functions.
   *
   * \param name name of the function.
   * \return The corresponding function.
   */
189
  const PackedFunc* GetFuncFromEnv(const std::string& name);
190 191 192 193 194
  /*! \return The module it imports from */
  const std::vector<Module>& imports() const {
    return imports_;
  }

195 196 197 198 199 200 201
  // integration with the existing components.
  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
  static constexpr const char* _type_key = "runtime.Module";
  // NOTE: ModuleNode can still be sub-classed
  //
  TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);

202
 protected:
203
  friend class Module;
204
  friend class ModuleInternal;
205 206
  /*! \brief The modules this module depend on */
  std::vector<Module> imports_;
207 208

 private:
209 210
  /*! \brief Cache used by GetImport */
  std::unordered_map<std::string,
211
                     std::shared_ptr<PackedFunc> > import_cache_;
212 213
};

/*!
 * \brief Check whether the runtime module for a given target is enabled.
 * \param target The target module name.
 * \return true if the runtime for \p target is enabled, false otherwise.
 * \note Only declared here (exported via TVM_DLL); the definition lives elsewhere.
 */
TVM_DLL bool RuntimeEnabled(const std::string& target);

/*!
 * \brief namespace for constant symbols
 *
 *  Each constant is the literal name of a symbol; the comment on each
 *  constant describes what the named symbol holds or does.
 */
namespace symbol {
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob. */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief Global function to set device. */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter for the global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that use the global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
}  // namespace symbol

// implementations of inline functions.

/*!
 * Delegates to ModuleNode::Import on the referenced node, which is
 * documented to reject cyclic dependencies.
 */
inline void Module::Import(Module other) {
  ModuleNode* node = operator->();
  node->Import(other);
}

245
inline ModuleNode* Module::operator->() {
246
  return static_cast<ModuleNode*>(get_mutable());
247 248
}

249
inline const ModuleNode* Module::operator->() const {
250
  return static_cast<const ModuleNode*>(get());
251 252
}

}  // namespace runtime
}  // namespace tvm

#include <tvm/runtime/packed_func.h>  // NOLINT(*)
#endif  // TVM_RUNTIME_MODULE_H_