Unverified Commit 78243e98 by Tianqi Chen Committed by GitHub

[REFACTOR] relay::Module Def -> TypeDef (#4665)

* [REFACTOR] relay::Module Def -> TypeDef

The term Def was not very clear about what is the object of interest(could be function def or type def).
Changes the term to TypeDef to be more explicit.

* Update include/tvm/relay/module.h

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

Co-authored-by: Wei Chen <ipondering.weic@gmail.com>
parent 8a98a2e7
...@@ -104,18 +104,20 @@ class ModuleNode : public RelayNode { ...@@ -104,18 +104,20 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the * \param update Controls whether you can replace a definition in the
* environment. * environment.
*/ */
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false); TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
/*! /*!
* \brief Add a type definition to the global environment. * \brief Add a type-level definition to the global environment.
* \param var The name of the global function. * \param var The var of the global type definition.
* \param type The ADT. * \param type The ADT.
* \param update Controls whether you can replace a definition in the * \param update Controls whether you can replace a definition in the
* environment. * environment.
* *
* It does not do type inference as AddDef does. * It does not do type checking as AddTypeDef does.
*/ */
TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false); TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update = false);
/*! /*!
* \brief Update a function in the global environment. * \brief Update a function in the global environment.
...@@ -129,7 +131,7 @@ class ModuleNode : public RelayNode { ...@@ -129,7 +131,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global type definition to update. * \param var The name of the global type definition to update.
* \param type The new ADT. * \param type The new ADT.
*/ */
TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type); TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);
/*! /*!
* \brief Remove a function from the global environment. * \brief Remove a function from the global environment.
...@@ -162,7 +164,7 @@ class ModuleNode : public RelayNode { ...@@ -162,7 +164,7 @@ class ModuleNode : public RelayNode {
* \brief Collect all global vars defined in this module. * \brief Collect all global vars defined in this module.
* \returns An array of global vars * \returns An array of global vars
*/ */
tvm::Array<GlobalVar> GetGlobalVars() const; TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const;
/*! /*!
* \brief Look up a global function by its name. * \brief Look up a global function by its name.
...@@ -175,7 +177,7 @@ class ModuleNode : public RelayNode { ...@@ -175,7 +177,7 @@ class ModuleNode : public RelayNode {
* \brief Collect all global type vars defined in this module. * \brief Collect all global type vars defined in this module.
* \returns An array of global type vars * \returns An array of global type vars
*/ */
tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const; TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
/*! /*!
* \brief Look up a global function by its variable. * \brief Look up a global function by its variable.
...@@ -196,14 +198,14 @@ class ModuleNode : public RelayNode { ...@@ -196,14 +198,14 @@ class ModuleNode : public RelayNode {
* \param var The var of the global type definition. * \param var The var of the global type definition.
* \return The type definition. * \return The type definition.
*/ */
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const; TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const;
/*! /*!
* \brief Look up a global type definition by its name. * \brief Look up a global type definition by its name.
* \param var The name of the global type definition. * \param var The name of the global type definition.
* \return The type definition. * \return The type definition.
*/ */
TVM_DLL TypeData LookupDef(const std::string& var) const; TVM_DLL TypeData LookupTypeDef(const std::string& var) const;
/*! /*!
* \brief Look up a constructor by its tag. * \brief Look up a constructor by its tag.
......
...@@ -70,7 +70,7 @@ class AlphaEqualHandler: ...@@ -70,7 +70,7 @@ class AlphaEqualHandler:
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) { for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) || if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->name_hint))) { !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false; return false;
} }
} }
......
...@@ -185,15 +185,15 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& ...@@ -185,15 +185,15 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
} }
} }
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type, bool update) { void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddDefUnchecked(var, type, update); AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up // need to kind check at the end because the check can look up
// a definition potentially // a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData) CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type; << "Invalid or malformed typedata given to module: " << type;
} }
void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
this->type_definitions.Set(var, type); this->type_definitions.Set(var, type);
if (!update) { if (!update) {
// set global type var map // set global type var map
...@@ -208,8 +208,8 @@ void ModuleNode::Update(const GlobalVar& var, const Function& func) { ...@@ -208,8 +208,8 @@ void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true); this->Add(var, func, true);
} }
void ModuleNode::UpdateDef(const GlobalTypeVar& var, const TypeData& type) { void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddDef(var, type, true); this->AddTypeDef(var, type, true);
} }
void ModuleNode::Remove(const GlobalVar& var) { void ModuleNode::Remove(const GlobalVar& var) {
...@@ -231,16 +231,16 @@ Function ModuleNode::Lookup(const std::string& name) const { ...@@ -231,16 +231,16 @@ Function ModuleNode::Lookup(const std::string& name) const {
return this->Lookup(id); return this->Lookup(id);
} }
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { TypeData ModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var); auto it = type_definitions.find(var);
CHECK(it != type_definitions.end()) CHECK(it != type_definitions.end())
<< "There is no definition of " << var->name_hint; << "There is no definition of " << var->name_hint;
return (*it).second; return (*it).second;
} }
TypeData ModuleNode::LookupDef(const std::string& name) const { TypeData ModuleNode::LookupTypeDef(const std::string& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name); GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id); return this->LookupTypeDef(id);
} }
Constructor ModuleNode::LookupTag(const int32_t tag) { Constructor ModuleNode::LookupTag(const int32_t tag) {
...@@ -257,13 +257,13 @@ void ModuleNode::Update(const Module& mod) { ...@@ -257,13 +257,13 @@ void ModuleNode::Update(const Module& mod) {
this->AddUnchecked(pair.first, pair.second); this->AddUnchecked(pair.first, pair.second);
} }
for (auto pair : mod->type_definitions) { for (auto pair : mod->type_definitions) {
this->AddDefUnchecked(pair.first, pair.second); this->AddTypeDefUnchecked(pair.first, pair.second);
} }
for (auto pair : mod->functions) { for (auto pair : mod->functions) {
this->Update(pair.first, pair.second); this->Update(pair.first, pair.second);
} }
for (auto pair : mod->type_definitions) { for (auto pair : mod->type_definitions) {
this->UpdateDef(pair.first, pair.second); this->UpdateTypeDef(pair.first, pair.second);
} }
} }
...@@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add") ...@@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_AddDef") TVM_REGISTER_GLOBAL("relay._module.Module_AddDef")
.set_body_method<Module>(&ModuleNode::AddDef); .set_body_method<Module>(&ModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar") TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar); .set_body_method<Module>(&ModuleNode::GetGlobalVar);
...@@ -376,12 +376,12 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") ...@@ -376,12 +376,12 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
.set_body_typed([](Module mod, GlobalTypeVar var) { .set_body_typed([](Module mod, GlobalTypeVar var) {
return mod->LookupDef(var); return mod->LookupTypeDef(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
.set_body_typed([](Module mod, std::string var) { .set_body_typed([](Module mod, std::string var) {
return mod->LookupDef(var); return mod->LookupTypeDef(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
......
...@@ -101,7 +101,7 @@ class EtaExpander : public ExprMutator { ...@@ -101,7 +101,7 @@ class EtaExpander : public ExprMutator {
params.push_back(VarNode::make("eta_expand_param", param_type)); params.push_back(VarNode::make("eta_expand_param", param_type));
} }
tvm::Array<Type> type_params; tvm::Array<Type> type_params;
TypeData adt_def = mod_->LookupDef(cons->belong_to); TypeData adt_def = mod_->LookupTypeDef(cons->belong_to);
for (const auto& type_var : adt_def->type_vars) { for (const auto& type_var : adt_def->type_vars) {
type_params.push_back(type_var_replacer_.VisitType(type_var)); type_params.push_back(type_var_replacer_.VisitType(type_var));
} }
......
...@@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> { ...@@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
// finally we need to check the module to check the number of type params // finally we need to check the module to check the number of type params
auto var = GetRef<GlobalTypeVar>(gtv); auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupDef(var); auto data = mod->LookupTypeDef(var);
if (data->type_vars.size() != op->args.size()) { if (data->type_vars.size() != op->args.size()) {
ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc
<< "; got " << op->args.size())); << "; got " << op->args.size()));
......
...@@ -183,7 +183,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, ...@@ -183,7 +183,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// for a wildcard node, create constructor nodes with wildcards for all args. // for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) { if (cand.as<PatternWildcardNode>()) {
TypeData td = mod->LookupDef(gtv); TypeData td = mod->LookupTypeDef(gtv);
// for each constructor add a candidate. // for each constructor add a candidate.
Array<Pattern> ret; Array<Pattern> ret;
for (auto constructor : td->constructors) { for (auto constructor : td->constructors) {
......
...@@ -574,7 +574,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -574,7 +574,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
CHECK(mod_.defined()) CHECK(mod_.defined())
<< "Cannot do type inference without a environment:" << "Cannot do type inference without a environment:"
<< c->name_hint; << c->name_hint;
TypeData td = mod_->LookupDef(c->belong_to); TypeData td = mod_->LookupTypeDef(c->belong_to);
std::vector<Type> types; std::vector<Type> types;
for (const auto & t : td->type_vars) { for (const auto & t : td->type_vars) {
types.push_back(t); types.push_back(t);
......
...@@ -140,7 +140,7 @@ class TypeVarEVisitor : private ExprVisitor { ...@@ -140,7 +140,7 @@ class TypeVarEVisitor : private ExprVisitor {
void VisitExpr_(const ConstructorNode* cn) final { void VisitExpr_(const ConstructorNode* cn) final {
// for constructors, type vars will be bound in the module // for constructors, type vars will be bound in the module
auto data = mod_->LookupDef(cn->belong_to); auto data = mod_->LookupTypeDef(cn->belong_to);
for (const auto& tv : data->type_vars) { for (const auto& tv : data->type_vars) {
type_vars_.Insert(tv); type_vars_.Insert(tv);
bound_type_vars_.Insert(tv); bound_type_vars_.Insert(tv);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment