Unverified Commit 78243e98 by Tianqi Chen Committed by GitHub

[REFACTOR] relay::Module Def -> TypeDef (#4665)

* [REFACTOR] relay::Module Def -> TypeDef

The term Def was not very clear about what is the object of interest(could be function def or type def).
Changes the term to TypeDef to be more explicit.

* Update include/tvm/relay/module.h

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

Co-authored-by: Wei Chen <ipondering.weic@gmail.com>
parent 8a98a2e7
......@@ -104,18 +104,20 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
/*!
* \brief Add a type definition to the global environment.
* \param var The name of the global function.
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The ADT.
* \param update Controls whether you can replace a definition in the
* environment.
*
* It does not do type inference as AddDef does.
* It does not do type checking as AddTypeDef does.
*/
TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false);
TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update = false);
/*!
* \brief Update a function in the global environment.
......@@ -129,7 +131,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global type definition to update.
* \param type The new ADT.
*/
TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type);
TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);
/*!
* \brief Remove a function from the global environment.
......@@ -162,7 +164,7 @@ class ModuleNode : public RelayNode {
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
tvm::Array<GlobalVar> GetGlobalVars() const;
TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const;
/*!
* \brief Look up a global function by its name.
......@@ -175,7 +177,7 @@ class ModuleNode : public RelayNode {
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
/*!
* \brief Look up a global function by its variable.
......@@ -196,14 +198,14 @@ class ModuleNode : public RelayNode {
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const;
/*!
* \brief Look up a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;
TVM_DLL TypeData LookupTypeDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag.
......
......@@ -70,7 +70,7 @@ class AlphaEqualHandler:
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->name_hint))) {
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
......
......@@ -185,15 +185,15 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
}
}
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddDefUnchecked(var, type, update);
void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
......@@ -208,8 +208,8 @@ void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true);
}
void ModuleNode::UpdateDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddDef(var, type, true);
void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddTypeDef(var, type, true);
}
void ModuleNode::Remove(const GlobalVar& var) {
......@@ -231,16 +231,16 @@ Function ModuleNode::Lookup(const std::string& name) const {
return this->Lookup(id);
}
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
TypeData ModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
TypeData ModuleNode::LookupDef(const std::string& name) const {
TypeData ModuleNode::LookupTypeDef(const std::string& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id);
return this->LookupTypeDef(id);
}
Constructor ModuleNode::LookupTag(const int32_t tag) {
......@@ -257,13 +257,13 @@ void ModuleNode::Update(const Module& mod) {
this->AddUnchecked(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
this->AddDefUnchecked(pair.first, pair.second);
this->AddTypeDefUnchecked(pair.first, pair.second);
}
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
this->UpdateDef(pair.first, pair.second);
this->UpdateTypeDef(pair.first, pair.second);
}
}
......@@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
});
TVM_REGISTER_GLOBAL("relay._module.Module_AddDef")
.set_body_method<Module>(&ModuleNode::AddDef);
.set_body_method<Module>(&ModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar);
......@@ -376,12 +376,12 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
.set_body_typed([](Module mod, GlobalTypeVar var) {
return mod->LookupDef(var);
return mod->LookupTypeDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
.set_body_typed([](Module mod, std::string var) {
return mod->LookupDef(var);
return mod->LookupTypeDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
......
......@@ -101,7 +101,7 @@ class EtaExpander : public ExprMutator {
params.push_back(VarNode::make("eta_expand_param", param_type));
}
tvm::Array<Type> type_params;
TypeData adt_def = mod_->LookupDef(cons->belong_to);
TypeData adt_def = mod_->LookupTypeDef(cons->belong_to);
for (const auto& type_var : adt_def->type_vars) {
type_params.push_back(type_var_replacer_.VisitType(type_var));
}
......
......@@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
// finally we need to check the module to check the number of type params
auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupDef(var);
auto data = mod->LookupTypeDef(var);
if (data->type_vars.size() != op->args.size()) {
ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc
<< "; got " << op->args.size()));
......
......@@ -183,7 +183,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
TypeData td = mod->LookupDef(gtv);
TypeData td = mod->LookupTypeDef(gtv);
// for each constructor add a candidate.
Array<Pattern> ret;
for (auto constructor : td->constructors) {
......
......@@ -574,7 +574,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
CHECK(mod_.defined())
<< "Cannot do type inference without a environment:"
<< c->name_hint;
TypeData td = mod_->LookupDef(c->belong_to);
TypeData td = mod_->LookupTypeDef(c->belong_to);
std::vector<Type> types;
for (const auto & t : td->type_vars) {
types.push_back(t);
......
......@@ -140,7 +140,7 @@ class TypeVarEVisitor : private ExprVisitor {
void VisitExpr_(const ConstructorNode* cn) final {
// for constructors, type vars will be bound in the module
auto data = mod_->LookupDef(cn->belong_to);
auto data = mod_->LookupTypeDef(cn->belong_to);
for (const auto& tv : data->type_vars) {
type_vars_.Insert(tv);
bound_type_vars_.Insert(tv);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment