/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 * \file  module.cc
 * \brief The global module in Relay.
 */
24 25
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
26 27 28 29 30
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface level,
// and are only used in minimal cases where they are clearly marked.
//
// Rationale: We call into relay's analysis module to verify correctness.
31 32
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
33

34
#include <sstream>
35 36
#include <fstream>
#include <unordered_set>
37 38 39

namespace tvm {

40 41 42 43 44 45
// Build a module from a function table, a type-definition table and a set
// of already-imported paths.
//
// The constructor also populates the name -> GlobalVar and
// name -> GlobalTypeVar lookup maps and registers the constructors of every
// type definition. Two entries sharing the same name hint fail a CHECK.
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
                   tvm::Map<GlobalTypeVar, TypeData> type_definitions,
                   std::unordered_set<std::string> import_set) {
  auto node = make_object<IRModuleNode>();
  node->functions = std::move(functions);
  node->type_definitions = std::move(type_definitions);
  node->import_set_ = std::move(import_set);
  node->global_var_map_ = {};
  node->global_type_var_map_ = {};
  node->constructor_tag_map_ = {};

  // Index each function's global variable by its name hint.
  for (const auto& entry : node->functions) {
    const auto& gvar = entry.first;
    CHECK(node->global_var_map_.count(gvar->name_hint) == 0)
      << "Duplicate global function name " << gvar->name_hint;
    node->global_var_map_.Set(gvar->name_hint, gvar);
  }

  // Index each global type variable by its name hint and register the
  // constructors of its definition.
  for (const auto& entry : node->type_definitions) {
    const auto& type_var = entry.first;
    CHECK(node->global_type_var_map_.count(type_var->name_hint) == 0)
      << "Duplicate global type definition name " << type_var->name_hint;
    node->global_type_var_map_.Set(type_var->name_hint, type_var);
    node->RegisterConstructors(type_var, entry.second);
  }
  data_ = std::move(node);
}

68
bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
69 70 71
  return global_var_map_.find(name) != global_var_map_.end();
}

72
bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
73 74 75
  return global_type_var_map_.find(name) != global_type_var_map_.end();
}

76
// Look up the global variable registered under `name`.
// Fails a CHECK when no such variable exists.
GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
  auto found = global_var_map_.find(name);
  CHECK(found != global_var_map_.end())
    << "Cannot find global var " << name << " in the Module";
  return (*found).second;
}

83
// Collect every global variable held in the name -> GlobalVar map,
// in the map's iteration order.
tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
  std::vector<GlobalVar> collected;
  for (const auto& entry : global_var_map_) {
    collected.emplace_back(entry.second);
  }
  return tvm::Array<GlobalVar>(collected);
}

91
// Look up the global type variable registered under `name`.
// Fails a CHECK when the lookup map is undefined or has no such entry.
GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
  CHECK(global_type_var_map_.defined());
  auto found = global_type_var_map_.find(name);
  CHECK(found != global_type_var_map_.end())
    << "Cannot find global type var " << name << " in the Module";
  return (*found).second;
}

99
// Collect every global type variable held in the name -> GlobalTypeVar
// map, in the map's iteration order.
tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
  std::vector<GlobalTypeVar> collected;
  for (const auto& entry : global_type_var_map_) {
    collected.emplace_back(entry.second);
  }
  return tvm::Array<GlobalTypeVar>(collected);
}

雾雨魔理沙 committed
107 108 109 110 111 112 113 114 115
// Return a new array holding the elements of `l` followed by the
// elements of `r`.
template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
  tvm::Array<T> joined(l);
  for (size_t i = 0; i < r.size(); ++i) {
    joined.push_back(r[i]);
  }
  return joined;
}

116
// helper function to run type check
117
relay::Function RunTypeCheck(const IRModule& mod,
118 119 120
                             const GlobalVar& var,
                             relay::Function f) {
  auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
121
  // Type check the item before we add it to the module.
122 123
  auto fv = relay::FreeVars(func);
  auto ftv = relay::FreeTypeVars(func, mod);
雾雨魔理沙 committed
124 125
  if (fv.size() != 0) {
    LOG(WARNING)
126 127 128 129 130
        << "There are free variables: "
        << fv
        << " in function: "
        << AsText(func, false)
        << std::endl;
雾雨魔理沙 committed
131 132 133
  }
  if (ftv.size() != 0) {
    LOG(WARNING)
134 135 136 137 138
        << "There are free type variables: "
        << ftv
        << " in function: "
        << AsText(func, false)
        << std::endl;
雾雨魔理沙 committed
139 140
  }
  func =
141 142 143 144 145
      relay::FunctionNode::make(concat(func->params, fv),
                                func->body,
                                func->ret_type,
                                concat(func->type_params, ftv),
                                func->attrs);
雾雨魔理沙 committed
146
  // Type check the item before we add it to the module.
147 148 149 150
  relay::Function checked_func = InferType(func, mod, var);
  return checked_func;
}

151 152 153
void IRModuleNode::Add(const GlobalVar& var,
                       const BaseFunc& f,
                       bool update) {
154 155
  BaseFunc checked_func = f;
  if (auto* ptr = f.as<relay::FunctionNode>()) {
156
    checked_func = RunTypeCheck(GetRef<IRModule>(this),
157 158 159 160
                                var,
                                GetRef<relay::Function>(ptr));
  }

161
  auto type = checked_func->checked_type();
162 163
  CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);

164 165 166
  if (functions.find(var) != functions.end()) {
    CHECK(update)
        << "Already have definition for " << var->name_hint;
167 168
    auto old_type = functions[var].as<relay::FunctionNode>()->checked_type();
    CHECK(relay::AlphaEqual(type, old_type))
169
        << "Module#update changes type, not possible in this mode.";
170
  }
171
  var->checked_type_ = type;
172
  AddUnchecked(var, checked_func);
173 174
}

175 176
void IRModuleNode::AddUnchecked(const GlobalVar& var,
                                const BaseFunc& func) {
177 178 179 180 181 182 183 184 185 186 187 188 189
  this->functions.Set(var, func);

  auto it = global_var_map_.find(var->name_hint);
  if (it != global_var_map_.end()) {
    CHECK_EQ((*it).second, var);
  } else {
    CHECK(global_var_map_.count(var->name_hint) == 0)
        << "Duplicate global function name " << var->name_hint;
  }

  global_var_map_.Set(var->name_hint, var);
}

190
// Assign a tag to every constructor of `type` and record it in
// constructor_tag_map_.
//
// We hash the global type var name to use as a prefix for tags.
// The low 8 bits of the hash become the most significant byte of the tag,
// and the index of the constructor is OR-ed into the less significant bytes.
void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
  const size_t hash = std::hash<std::string>()(var->name_hint);
  // Shift in unsigned arithmetic: shifting a signed int32_t so that a bit
  // lands in the sign position is not portable across language revisions.
  const uint32_t prefix_bits = static_cast<uint32_t>(hash & 0xff) << 24;
  const int32_t prefix = static_cast<int32_t>(prefix_bits);
  for (size_t i = 0; i < type->constructors.size(); ++i) {
    type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
    constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
  }
}

202 203 204
void IRModuleNode::AddTypeDef(const GlobalTypeVar& var,
                              const TypeData& type,
                              bool update) {
205
  AddTypeDefUnchecked(var, type, update);
206 207
  // need to kind check at the end because the check can look up
  // a definition potentially
208
  CHECK(relay::KindCheck(type, GetRef<IRModule>(this)) == TypeKind::kTypeData)
209 210 211
    << "Invalid or malformed typedata given to module: " << type;
}

212 213 214
void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
                                       const TypeData& type,
                                       bool update) {
215 216 217
  this->type_definitions.Set(var, type);
  if (!update) {
    // set global type var map
218 219
    CHECK(global_type_var_map_.count(var->name_hint) == 0)
      << "Duplicate global type definition name " << var->name_hint;
220
  }
221
  global_type_var_map_.Set(var->name_hint, var);
222 223 224
  RegisterConstructors(var, type);
}

225 226
void IRModuleNode::Update(const GlobalVar& var,
                          const BaseFunc& func) {
227 228 229
  this->Add(var, func, true);
}

230 231
void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
                                 const TypeData& type) {
232
  this->AddTypeDef(var, type, true);
233 234
}

235
void IRModuleNode::Remove(const GlobalVar& var) {
236
  auto functions_node = this->functions.CopyOnWrite();
237
  functions_node->data.erase(var);
238 239
  auto gvar_node = global_var_map_.CopyOnWrite();
  gvar_node->data.erase(var->name_hint);
240 241
}

242
// Return the function defined for `var`.
// Fails a CHECK when the module has no definition for it.
BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
  auto entry = functions.find(var);
  CHECK(entry != functions.end())
      << "There is no definition of " << var->name_hint;
  return (*entry).second;
}

249
// Return the function registered under `name`: resolve the name to its
// global variable, then look that variable up.
BaseFunc IRModuleNode::Lookup(const std::string& name) const {
  return this->Lookup(this->GetGlobalVar(name));
}

254
// Return the type definition stored for `var`.
// Fails a CHECK when the module has no definition for it.
TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
  auto entry = type_definitions.find(var);
  CHECK(entry != type_definitions.end())
    << "There is no definition of " << var->name_hint;
  return (*entry).second;
}

261
// Return the type definition registered under `name`: resolve the name to
// its global type variable, then look that variable up.
TypeData IRModuleNode::LookupTypeDef(const std::string& name) const {
  return this->LookupTypeDef(this->GetGlobalTypeVar(name));
}

266
// Return the constructor recorded under `tag` in constructor_tag_map_.
// Fails a CHECK when the tag is unknown.
Constructor IRModuleNode::LookupTag(const int32_t tag) {
  auto entry = constructor_tag_map_.find(tag);
  CHECK(entry != constructor_tag_map_.end())
    << "There is no constructor with the tag " << tag;
  return (*entry).second;
}

273
void IRModuleNode::Update(const IRModule& mod) {
274 275 276 277 278 279
  // add functions and type defs. we add them unchecked first, so all definitions
  // can reference each other, independent of the order in which they were defined.
  for (auto pair : mod->functions) {
    this->AddUnchecked(pair.first, pair.second);
  }
  for (auto pair : mod->type_definitions) {
280
    this->AddTypeDefUnchecked(pair.first, pair.second);
281
  }
282
  for (auto pair : mod->functions) {
283
    this->Update(pair.first, pair.second);
284
  }
285
  for (auto pair : mod->type_definitions) {
286
    this->UpdateTypeDef(pair.first, pair.second);
287
  }
288 289
}

290
// Build a module whose "main" function is derived from `expr`.
//
// A relay function is used as is. Any other expression is wrapped in a
// function whose parameters are the expression's free variables and whose
// type parameters are its free type variables.
IRModule IRModule::FromExpr(
  const RelayExpr& expr,
  const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
  const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
  IRModule mod(global_funcs, type_definitions);

  BaseFunc main_func;
  const auto* as_function = expr.as<relay::FunctionNode>();
  if (as_function != nullptr) {
    main_func = GetRef<relay::Function>(as_function);
  } else {
    main_func = relay::FunctionNode::make(
        relay::FreeVars(expr), expr, Type(),
        relay::FreeTypeVars(expr, mod), {});
  }

  mod->Add(GlobalVar("main"), main_func);
  return mod;
}

308
void IRModuleNode::Import(const std::string& path) {
309 310
  if (this->import_set_.count(path) == 0) {
    this->import_set_.insert(path);
311
    DLOG(INFO) << "Importing: " << path;
312 313 314 315
    std::fstream src_file(path, std::fstream::in);
    std::string file_contents {
      std::istreambuf_iterator<char>(src_file),
      std::istreambuf_iterator<char>() };
316
    auto mod_to_import = IRModule::FromText(file_contents, path);
317
    Update(mod_to_import);
318 319 320
  }
}

321
void IRModuleNode::ImportFromStd(const std::string& path) {
322 323 324 325 326 327
  auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
  CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
  std::string std_path = (*f)();
  return this->Import(std_path + "/" + path);
}

328
std::unordered_set<std::string> IRModuleNode::Imports() const {
329 330 331
  return this->import_set_;
}

332
// Parse `text` into a module by calling the registered "relay.fromtext"
// function with the text and `source_path`.
// Fails a CHECK when that function is not registered.
IRModule IRModule::FromText(const std::string& text, const std::string& source_path) {
  auto* f = tvm::runtime::Registry::Get("relay.fromtext");
  // The message names the function that is actually being looked up.
  CHECK(f != nullptr)
    << "The Relay text parser is not available, please register relay.fromtext.";
  IRModule mod = (*f)(text, source_path);
  return mod;
}

339
TVM_REGISTER_NODE_TYPE(IRModuleNode);

// Global function "ir.IRModule": construct a module from function and
// type-definition tables with an empty import set.
TVM_REGISTER_GLOBAL("ir.IRModule")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
                   tvm::Map<GlobalTypeVar, TypeData> types) {
  std::unordered_set<std::string> no_imports;
  return IRModule(funcs, types, no_imports);
});
346

347
// Global function "ir.Module_Add": add `val` to the module under `var`.
//
// Arguments: (module, global var, value, update flag). The value must be a
// RelayExpr and is turned into a function as follows:
//  - a relay function is added as is;
//  - a global var is resolved by eta-expanding a copy of the module and
//    looking the variable's name up in that copy;
//  - any other expression is wrapped in a function without parameters.
// The (mutated) module is returned.
TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  IRModule mod = args[0];
  GlobalVar var = args[1];
  ObjectRef val = args[2];
  bool update = args[3];
  CHECK(val->IsInstance<RelayExprNode>());

  relay::Function to_add;
  if (val->IsInstance<relay::FunctionNode>()) {
    to_add = Downcast<relay::Function>(val);
  } else if (val->IsInstance<GlobalVarNode>()) {
    GlobalVar gv = Downcast<GlobalVar>(val);
    // Work on a copy so that eta expansion does not touch `mod` itself.
    auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
    mod_copy = relay::transform::EtaExpand(
        /* expand_constructor */ false,
        /* expand_global_var */ true)(mod_copy);
    to_add = Downcast<relay::Function>(mod_copy->Lookup(gv->name_hint));
  } else {
    to_add = relay::FunctionNode::make({}, Downcast<RelayExpr>(val), Type(nullptr), {});
  }
  mod->Add(var, to_add, update);
  *ret = mod;
});
371

372
// Global functions that forward directly to IRModuleNode member functions.

TVM_REGISTER_GLOBAL("ir.Module_AddDef")
    .set_body_method<IRModule>(&IRModuleNode::AddTypeDef);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);

TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
    .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
389

390
// Global lookup functions. The "_str" variants take a name instead of a
// global (type) variable.

TVM_REGISTER_GLOBAL("ir.Module_Lookup")
.set_body_typed([](IRModule mod, GlobalVar gvar) {
  return mod->Lookup(gvar);
});

TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
.set_body_typed([](IRModule mod, std::string name) {
  return mod->Lookup(name);
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
.set_body_typed([](IRModule mod, GlobalTypeVar type_var) {
  return mod->LookupTypeDef(type_var);
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
.set_body_typed([](IRModule mod, std::string name) {
  return mod->LookupTypeDef(name);
});

TVM_REGISTER_GLOBAL("ir.Module_LookupTag")
.set_body_typed([](IRModule mod, int32_t tag) {
  return mod->LookupTag(tag);
});

415
// Global function "ir.Module_FromExpr": forwards to IRModule::FromExpr.
TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
.set_body_typed([](RelayExpr expr,
                   tvm::Map<GlobalVar, BaseFunc> global_funcs,
                   tvm::Map<GlobalTypeVar, TypeData> type_definitions) {
  return IRModule::FromExpr(expr, global_funcs, type_definitions);
});

// Global function "ir.Module_Update": merge `from` into `mod`.
TVM_REGISTER_GLOBAL("ir.Module_Update")
.set_body_typed([](IRModule mod, IRModule from) {
  mod->Update(from);
});

// Global function "ir.Module_Import": import the file at `path`.
TVM_REGISTER_GLOBAL("ir.Module_Import")
.set_body_typed([](IRModule mod, std::string path) {
  mod->Import(path);
});

// Global function "ir.Module_ImportFromStd": import `path` through
// IRModuleNode::ImportFromStd.
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
.set_body_typed([](IRModule mod, std::string path) {
  mod->ImportFromStd(path);
});

437 438
// Printer for IRModuleNode: writes the function table wrapped in
// "IRModuleNode( ... )" to the printer's stream.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
  const auto* module_node = static_cast<const IRModuleNode*>(ref.get());
  p->stream << "IRModuleNode( " << module_node->functions << ")";
});
442 443

}  // namespace tvm