/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/module.h
 * \brief The global environment: contains information needed to
 * compile & optimize Relay programs.
 */
#ifndef TVM_RELAY_MODULE_H_
#define TVM_RELAY_MODULE_H_

#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

struct Module;

/*! \brief The global environment of Relay programs.
 *
 *  The global environment contains the global
 *  information needed to compile a Relay program.
 *
 *  It contains all global functions, and configuration
 *  options.
 *
 *  Many operations require access to the global
50
 *  Module. We pass the Module by value
51
 *  in a functional style as an explicit argument,
52
 *  but we mutate the Module while optimizing
53 54 55 56
 *  Relay programs.
 *
 *  The functional style allows users to construct custom
 *  environments easily, for example each thread can store
57
 *  a Module while auto-tuning.
58 59
 * */

60
class ModuleNode : public RelayNode {
61 62 63
 public:
  /*! \brief A map from ids to all global functions. */
  tvm::Map<GlobalVar, Function> functions;
64 65
  /*! \brief A map from global type vars to ADT type data. */
  tvm::Map<GlobalTypeVar, TypeData> type_definitions;
66

67 68 69
  /*! \brief The entry function (i.e. "main"). */
  GlobalVar entry_func;

70
  ModuleNode() {}
71 72 73

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("functions", &functions);
74
    v->Visit("type_definitions", &type_definitions);
75
    v->Visit("global_var_map_", &global_var_map_);
76
    v->Visit("entry_func", &entry_func);
77
    v->Visit("global_type_var_map_", &global_type_var_map_);
78 79
  }

80 81
  TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
                             tvm::Map<GlobalTypeVar, TypeData> global_type_defs);
82

83 84
  /*!
   * \brief Add a function to the global environment.
85
   * \param var The var of the global function.
86 87 88 89
   * \param func The function.
   * \param update Controls whether you can replace a definition in the
   * environment.
   */
90
  TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
91

92
  /*!
93 94 95 96
   * \brief Add a type-level definition to the global environment.
   * \param var The var of the global type definition.
   * \param type The type definition.
   */
97
  TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
98 99

  /*!
100 101 102 103 104 105
   * \brief Add a function to the global environment.
   * \param var The name of the global function.
   * \param func The function.
   *
   * It does not do type inference as Add does.
   */
106
  TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
107 108

  /*!
109
   * \brief Update a function in the global environment.
110 111 112
   * \param var The name of the global function to update.
   * \param func The new function.
   */
113
  TVM_DLL void Update(const GlobalVar& var, const Function& func);
114

115 116
  /*!
   * \brief Remove a function from the global environment.
117 118
   * \param var The name of the global function to update.
   */
119
  TVM_DLL void Remove(const GlobalVar& var);
120

121 122
  /*!
   * \brief Lookup a global function by its variable.
123 124 125
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
126
  TVM_DLL GlobalVar GetGlobalVar(const std::string& str);
127

128
  /*!
129 130 131 132
   * \brief Look up a global function by its name.
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
133
  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);
134 135

  /*!
136
   * \brief Lookup a global function by its variable.
137 138 139
   * \param var The global var to lookup.
   * \returns The function named by the variable argument.
   */
140
  TVM_DLL Function Lookup(const GlobalVar& var);
141

142 143
  /*!
   * \brief Lookup a global function by its string name
144 145 146
   * \param name The name of the function.
   * \returns The function named by the argument.
   */
147
  TVM_DLL Function Lookup(const std::string& name);
148

149
  /*!
150 151 152 153
   * \brief Lookup a global type definition by its variable.
   * \param var The var of the global type definition.
   * \return The type definition.
   */
154
  TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);
155 156 157 158 159 160

  /*!
   * \brief Lookup a global type definition by its name.
   * \param var The name of the global type definition.
   * \return The type definition.
   */
161
  TVM_DLL TypeData LookupDef(const std::string& var);
162 163

  /*!
164 165
   * \brief Update the functions inside this environment by
   *        functions in another environment.
166 167
   * \param other The other environment.
   */
168
  TVM_DLL void Update(const Module& other);
169

170 171 172 173 174 175 176 177 178 179
  /*! \brief Construct a module from a standalone expression.
   *
   * Allows one to optionally pass a global function map as
   * well.
   *
   * \param expr The expression to set as the entry point to the module.
   * \param global_funcs The global function map.
   *
   * \returns A module with expr set as the entry point.
   */
180
  TVM_DLL static Module FromExpr(
181 182 183
    const Expr& expr,
    const tvm::Map<GlobalVar, Function>& global_funcs = {});

184 185
  static constexpr const char* _type_key = "relay.Module";
  TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
186 187

 private:
188 189
  /*! \brief A map from string names to global variables that
   * ensures global uniqueness.
190
   */
191
  tvm::Map<std::string, GlobalVar> global_var_map_;
192 193 194 195 196

  /*! \brief A map from string names to global type variables (ADT names)
   * that ensures global uniqueness.
   */
  tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
197 198
};

199 200 201
struct Module : public NodeRef {
  Module() {}
  explicit Module(NodePtr<tvm::Node> p) : NodeRef(p) {}
202

203 204
  inline ModuleNode* operator->() const {
    return static_cast<ModuleNode*>(node_.get());
205 206
  }

207
  using ContainerType = ModuleNode;
208 209
};

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_MODULE_H_