module.h 7.29 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 * \file tvm/relay/module.h
 * \brief The global environment: contains information needed to
 * compile & optimize Relay programs.
 */
25 26
#ifndef TVM_RELAY_MODULE_H_
#define TVM_RELAY_MODULE_H_
27 28 29

#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
30
#include <tvm/relay/adt.h>
31 32 33 34
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <string>
#include <vector>
35
#include <unordered_map>
36 37 38 39

namespace tvm {
namespace relay {

/*!
 * \brief Forward declaration of the Module reference type, which is
 * defined after ModuleNode but already used in ModuleNode's
 * signatures (make, Update, FromExpr).
 */
struct Module;

/*! \brief The global environment of Relay programs.
 *
 *  The global environment contains the global
 *  information needed to compile a Relay program.
 *
 *  It contains all global functions, and configuration
 *  options.
 *
 *  Many operations require access to the global
51
 *  Module. We pass the Module by value
52
 *  in a functional style as an explicit argument,
53
 *  but we mutate the Module while optimizing
54 55 56 57
 *  Relay programs.
 *
 *  The functional style allows users to construct custom
 *  environments easily, for example each thread can store
58
 *  a Module while auto-tuning.
59
 */
60

61
class ModuleNode : public RelayNode {
62 63 64
 public:
  /*! \brief A map from ids to all global functions. */
  tvm::Map<GlobalVar, Function> functions;
65 66
  /*! \brief A map from global type vars to ADT type data. */
  tvm::Map<GlobalTypeVar, TypeData> type_definitions;
67

68
  ModuleNode() {}
69 70 71

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("functions", &functions);
72
    v->Visit("type_definitions", &type_definitions);
73
    v->Visit("global_var_map_", &global_var_map_);
74
    v->Visit("global_type_var_map_", &global_type_var_map_);
75 76
  }

77 78
  TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
                             tvm::Map<GlobalTypeVar, TypeData> global_type_defs);
79

80 81
  /*!
   * \brief Add a function to the global environment.
82
   * \param var The var of the global function.
83 84 85 86
   * \param func The function.
   * \param update Controls whether you can replace a definition in the
   * environment.
   */
87
  TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
88

89
  /*!
90 91 92 93
   * \brief Add a type-level definition to the global environment.
   * \param var The var of the global type definition.
   * \param type The type definition.
   */
94
  TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
95 96

  /*!
97 98 99 100 101 102
   * \brief Add a function to the global environment.
   * \param var The name of the global function.
   * \param func The function.
   *
   * It does not do type inference as Add does.
   */
103
  TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
104 105

  /*!
106
   * \brief Update a function in the global environment.
107 108 109
   * \param var The name of the global function to update.
   * \param func The new function.
   */
110
  TVM_DLL void Update(const GlobalVar& var, const Function& func);
111

112 113
  /*!
   * \brief Remove a function from the global environment.
114 115
   * \param var The name of the global function to update.
   */
116
  TVM_DLL void Remove(const GlobalVar& var);
117

118
  /*!
119 120 121 122 123 124 125
   * \brief Check if the global_var_map_ contains a global variable.
   * \param name The variable name.
   * \returns true if contains, otherise false.
   */
  TVM_DLL bool ContainGlobalVar(const std::string& name) const;

  /*!
126
   * \brief Lookup a global function by its variable.
127 128 129
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
雾雨魔理沙 committed
130
  TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
131

132
  /*!
133 134 135 136
   * \brief Look up a global function by its name.
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
雾雨魔理沙 committed
137
  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
138 139

  /*!
140
   * \brief Look up a global function by its variable.
141 142 143
   * \param var The global var to lookup.
   * \returns The function named by the variable argument.
   */
雾雨魔理沙 committed
144
  TVM_DLL Function Lookup(const GlobalVar& var) const;
145

146
  /*!
147
   * \brief Look up a global function by its string name
148 149 150
   * \param name The name of the function.
   * \returns The function named by the argument.
   */
雾雨魔理沙 committed
151
  TVM_DLL Function Lookup(const std::string& name) const;
152

153
  /*!
154
   * \brief Look up a global type definition by its variable.
155 156 157
   * \param var The var of the global type definition.
   * \return The type definition.
   */
雾雨魔理沙 committed
158
  TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
159 160

  /*!
161
   * \brief Look up a global type definition by its name.
162 163 164
   * \param var The name of the global type definition.
   * \return The type definition.
   */
雾雨魔理沙 committed
165
  TVM_DLL TypeData LookupDef(const std::string& var) const;
166 167

  /*!
168 169 170 171 172 173 174
   * \brief Look up a constructor by its tag.
   * \param tag The tag for the constructor.
   * \return The constructor object.
   */
  TVM_DLL Constructor LookupTag(const int32_t tag);

  /*!
175 176
   * \brief Update the functions inside this environment by
   *        functions in another environment.
177 178
   * \param other The other environment.
   */
179
  TVM_DLL void Update(const Module& other);
180

181 182
  /*! \brief Construct a module from a standalone expression.
   *
183 184
   * Allows one to optionally pass a global function map and
   * map of type definitions as well.
185
   *
186
   * \param expr The expression to set as the main function to the module.
187
   * \param global_funcs The global function map.
188
   * \param type_definitions Map of global type definitions
189
   *
190
   * \returns A module with expr set as the main function.
191
   */
192
  TVM_DLL static Module FromExpr(
193
    const Expr& expr,
194 195
    const tvm::Map<GlobalVar, Function>& global_funcs = {},
    const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
196

197 198
  static constexpr const char* _type_key = "relay.Module";
  TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
199 200

 private:
201 202 203
  /*! \brief Helper function for registering a typedef's constructors */
  void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

204 205
  /*! \brief A map from string names to global variables that
   * ensures global uniqueness.
206
   */
207
  tvm::Map<std::string, GlobalVar> global_var_map_;
208 209 210 211 212

  /*! \brief A map from string names to global type variables (ADT names)
   * that ensures global uniqueness.
   */
  tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
213 214 215 216 217

  /*! \brief A map from constructor tags to constructor objects
   * for convenient access
   */
  std::unordered_map<int32_t, Constructor> constructor_tag_map_;
218 219
};

220 221 222
struct Module : public NodeRef {
  Module() {}
  explicit Module(NodePtr<tvm::Node> p) : NodeRef(p) {}
223

224 225
  inline ModuleNode* operator->() const {
    return static_cast<ModuleNode*>(node_.get());
226 227
  }

228
  using ContainerType = ModuleNode;
229 230
};

231

232 233 234
}  // namespace relay
}  // namespace tvm

235
#endif  // TVM_RELAY_MODULE_H_