/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file  module.cc
 * \brief The global module in Relay.
 */
#include <tvm/relay/module.h>

#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>

#include <cstdint>
#include <sstream>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace runtime;

36 37
/*!
 * \brief Build a module from a set of global functions and type definitions.
 *
 * Builds the name -> GlobalVar and name -> GlobalTypeVar indices and assigns
 * tags to the constructors of every type definition. Fails a CHECK when two
 * functions, or two type definitions, share the same name.
 */
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
                        tvm::Map<GlobalTypeVar, TypeData> global_type_defs) {
  auto node = make_node<ModuleNode>();
  node->functions = std::move(global_funcs);
  node->type_definitions = std::move(global_type_defs);

  // Index every function by its name hint.
  for (const auto& entry : node->functions) {
    const GlobalVar& gvar = entry.first;
    CHECK(!node->global_var_map_.count(gvar->name_hint))
      << "Duplicate global function name " << gvar->name_hint;
    node->global_var_map_.Set(gvar->name_hint, gvar);
  }

  // Index every type definition by name and tag its constructors.
  for (const auto& entry : node->type_definitions) {
    const GlobalTypeVar& gtv = entry.first;
    CHECK(!node->global_type_var_map_.count(gtv->var->name_hint))
      << "Duplicate global type definition name " << gtv->var->name_hint;
    node->global_type_var_map_.Set(gtv->var->name_hint, gtv);
    node->RegisterConstructors(gtv, entry.second);
  }

  return Module(node);
}

/*!
 * \brief Whether a global function named \p name is registered in the module.
 */
bool ModuleNode::ContainGlobalVar(const std::string& name) const {
  return global_var_map_.count(name) != 0;
}

/*!
 * \brief Return the GlobalVar registered under \p name; fails a CHECK if absent.
 */
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
  auto found = global_var_map_.find(name);
  CHECK(found != global_var_map_.end())
    << "Cannot find global var " << name << " in the Module";
  return (*found).second;
}

/*!
 * \brief Bind \p var to \p func without type checking.
 *
 * If a GlobalVar with the same name hint is already registered, it must be
 * exactly \p var; otherwise one name would refer to two different globals.
 * The check runs before any state is modified, so a failed CHECK leaves the
 * module unchanged.
 *
 * \param var The global variable to bind.
 * \param func The (already checked) function definition.
 */
void ModuleNode::AddUnchecked(const GlobalVar& var,
                              const Function& func) {
  auto it = global_var_map_.find(var->name_hint);
  if (it != global_var_map_.end()) {
    CHECK_EQ((*it).second, var)
        << "Duplicate global function name " << var->name_hint;
  }
  this->functions.Set(var, func);
  global_var_map_.Set(var->name_hint, var);
}

/*!
 * \brief Return the GlobalTypeVar registered under \p name; fails a CHECK if absent.
 */
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
  auto found = global_type_var_map_.find(name);
  CHECK(found != global_type_var_map_.end())
    << "Cannot find global type var " << name << " in the Module";
  return (*found).second;
}

/*!
 * \brief Return a new array holding the elements of \p l followed by those of \p r.
 */
template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
  tvm::Array<T> joined = l;
  for (size_t idx = 0; idx < r.size(); ++idx) {
    joined.push_back(r[idx]);
  }
  return joined;
}

void ModuleNode::Add(const GlobalVar& var,
104
                     const Function& f,
105
                     bool update) {
106
  Function func = Downcast<Function>(DeDup(f));
107
  // Type check the item before we add it to the module.
108
  auto mod = GetRef<Module>(this);
雾雨魔理沙 committed
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
  auto fv = FreeVars(func);
  auto ftv = FreeTypeVars(func, mod);
  if (fv.size() != 0) {
    LOG(WARNING)
      << "There are free variables: "
      << fv
      << " in function: "
      << AsText(func, false)
      << std::endl;
  }
  if (ftv.size() != 0) {
    LOG(WARNING)
      << "There are free type variables: "
      << ftv
      << " in function: "
      << AsText(func, false)
      << std::endl;
  }
  func =
    FunctionNode::make(concat(func->params, fv),
                       func->body,
                       func->ret_type,
                       concat(func->type_params, ftv),
                       func->attrs);
  // Type check the item before we add it to the module.
134
  Function checked_func = InferType(func, mod, var);
135 136 137 138 139 140 141
  auto type = checked_func->checked_type();
  CHECK(type.as<IncompleteTypeNode>() == nullptr);
  if (functions.find(var) != functions.end()) {
    CHECK(update)
        << "Already have definition for " << var->name_hint;
    auto old_type = functions[var].as<FunctionNode>()->checked_type();
    CHECK(AlphaEqual(type, old_type))
142
        << "Module#update changes type, not possible in this mode.";
143
  }
144
  var->checked_type_ = type;
145
  AddUnchecked(var, checked_func);
146 147
}

148 149 150 151 152 153 154 155 156 157 158 159
void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
  // We hash the global type var name to use as a globally unique prefix for tags.
  // The hash will be used as the most significant byte of the tag, with the index of
  // the constructor in the less significant bytes
  size_t hash = std::hash<std::string>()(var->var->name_hint);
  int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
  for (size_t i = 0; i < type->constructors.size(); ++i) {
    type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
    constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
  }
}

160 161 162 163 164 165
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
  this->type_definitions.Set(var, type);
  // set global type var map
  CHECK(!global_type_var_map_.count(var->var->name_hint))
    << "Duplicate global type definition name " << var->var->name_hint;
  global_type_var_map_.Set(var->var->name_hint, var);
166
  RegisterConstructors(var, type);
167 168 169 170 171 172 173

  // need to kind check at the end because the check can look up
  // a definition potentially
  CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
    << "Invalid or malformed typedata given to module: " << type;
}

174
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
175 176 177
  this->Add(var, func, true);
}

178
void ModuleNode::Remove(const GlobalVar& var) {
179 180
  auto functions_node = this->functions.CopyOnWrite();
  functions_node->data.erase(var.node_);
181 182
  auto gvar_node = global_var_map_.CopyOnWrite();
  gvar_node->data.erase(var->name_hint);
183 184
}

雾雨魔理沙 committed
185
/*!
 * \brief Return the function bound to \p var; fails a CHECK if there is none.
 */
Function ModuleNode::Lookup(const GlobalVar& var) const {
  auto entry = functions.find(var);
  CHECK(entry != functions.end())
      << "There is no definition of " << var->name_hint;
  return (*entry).second;
}

/*!
 * \brief Return the function bound to the global named \p name.
 */
Function ModuleNode::Lookup(const std::string& name) const {
  return Lookup(GetGlobalVar(name));
}

/*!
 * \brief Return the type definition bound to \p var; fails a CHECK if there is none.
 */
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
  auto entry = type_definitions.find(var);
  CHECK(entry != type_definitions.end())
    << "There is no definition of " << var->var->name_hint;
  return (*entry).second;
}

/*!
 * \brief Return the type definition bound to the global type named \p name.
 */
TypeData ModuleNode::LookupDef(const std::string& name) const {
  return LookupDef(GetGlobalTypeVar(name));
}

/*!
 * \brief Whether a global type definition named \p name is registered.
 */
bool ModuleNode::HasDef(const std::string& name) const {
  return global_type_var_map_.count(name) != 0;
}

/*!
 * \brief Return the constructor registered under \p tag; fails a CHECK if absent.
 */
Constructor ModuleNode::LookupTag(const int32_t tag) {
  auto found = constructor_tag_map_.find(tag);
  CHECK(found != constructor_tag_map_.end())
    << "There is no constructor with the tag " << tag;
  return (*found).second;
}

/*!
 * \brief Merge every function of \p mod into this module, replacing existing definitions.
 */
void ModuleNode::Update(const Module& mod) {
  for (const auto& kv : mod->functions) {
    Update(kv.first, kv.second);
  }
}

/*!
 * \brief Build a module whose "main" is \p expr.
 *
 * A Function is used as is. Any other expression is wrapped in a function
 * whose parameters are its free variables and whose type parameters are its
 * free type variables.
 */
Module ModuleNode::FromExpr(
  const Expr& expr,
  const tvm::Map<GlobalVar, Function>& global_funcs,
  const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
  auto mod = ModuleNode::make(global_funcs, type_definitions);
  Function main_func;
  if (const auto* fn = expr.as<FunctionNode>()) {
    main_func = GetRef<Function>(fn);
  } else {
    main_func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
  }
  mod->Add(GlobalVarNode::make("main"), main_func);
  return mod;
}

TVM_REGISTER_NODE_TYPE(ModuleNode);

TVM_REGISTER_API("relay._make.Module")
.set_body_typed(ModuleNode::make);

// Module_Add(mod, var, val, update): bind `var` in `mod`, then return `mod`.
// `val` may be:
//  - a Function, which is added as is;
//  - a GlobalVar, whose definition is taken from an eta-expanded copy of `mod`;
//  - any other Relay expression, which is wrapped in a nullary function.
TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Module mod = args[0];
  GlobalVar var = args[1];
  NodeRef val = args[2];
  bool update = args[3];
  CHECK(val->derived_from<ExprNode>())
    << "Module_Add expects a Relay expression, but got " << val->type_key();
  if (val->derived_from<FunctionNode>()) {
    mod->Add(var, Downcast<Function>(val), update);
  } else if (val->derived_from<GlobalVarNode>()) {
    GlobalVar gv = Downcast<GlobalVar>(val);
    // Run EtaExpand on a copy so that `mod` itself is not rewritten by the pass.
    auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
    mod_copy = transform::EtaExpand()(mod_copy);
    Function func = mod_copy->Lookup(gv->name_hint);
    mod->Add(var, func, update);
  } else {
    // Add() lifts any free variables of the expression into parameters.
    auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {});
    mod->Add(var, func, update);
  }
  *ret = mod;
});

// FFI entry points forwarding directly to ModuleNode methods.
TVM_REGISTER_API("relay._module.Module_AddDef")
.set_body_method<Module>(&ModuleNode::AddDef);

TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar);

TVM_REGISTER_API("relay._module.Module_ContainGlobalVar")
.set_body_method<Module>(&ModuleNode::ContainGlobalVar);

TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);

// Lookup and LookupDef are overloaded (by var and by name), so each overload
// gets its own explicitly typed lambda.
TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar gvar) {
  return mod->Lookup(gvar);
});

TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body_typed<Function(Module, std::string)>([](Module mod, std::string name) {
  return mod->Lookup(name);
});

TVM_REGISTER_API("relay._module.Module_LookupDef")
.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar gtv) {
  return mod->LookupDef(gtv);
});

TVM_REGISTER_API("relay._module.Module_LookupDef_str")
.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string name) {
  return mod->LookupDef(name);
});

TVM_REGISTER_API("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
  return mod->LookupTag(tag);
});

TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr,
                       tvm::Map<GlobalVar, Function>,
                       tvm::Map<GlobalTypeVar, TypeData>)>(
  [](Expr expr,
     tvm::Map<GlobalVar, Function> funcs,
     tvm::Map<GlobalTypeVar, TypeData> type_defs) {
    return ModuleNode::FromExpr(expr, funcs, type_defs);
  });

TVM_REGISTER_API("relay._module.Module_Update")
.set_body_typed<void(Module, Module)>([](Module mod, Module other) {
  mod->Update(other);
});

// IRPrinter hook: prints the module's function map wrapped in "ModuleNode( ... )".
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModuleNode>([](const ModuleNode* module, tvm::IRPrinter* printer) {
  printer->stream << "ModuleNode( " << module->functions << ")";
});

}  // namespace relay
}  // namespace tvm