Commit 25bad440 by 雾雨魔理沙 Committed by Tianqi Chen

fix (#3417)

parent 311434e8
...@@ -123,42 +123,42 @@ class ModuleNode : public RelayNode { ...@@ -123,42 +123,42 @@ class ModuleNode : public RelayNode {
* \param str The unique string specifying the global variable. * \param str The unique string specifying the global variable.
* \returns The global variable. * \returns The global variable.
*/ */
TVM_DLL GlobalVar GetGlobalVar(const std::string& str); TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
/*! /*!
* \brief Look up a global function by its name. * \brief Look up a global function by its name.
* \param str The unique string specifying the global variable. * \param str The unique string specifying the global variable.
* \returns The global variable. * \returns The global variable.
*/ */
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
/*! /*!
* \brief Lookup a global function by its variable. * \brief Lookup a global function by its variable.
* \param var The global var to lookup. * \param var The global var to lookup.
* \returns The function named by the variable argument. * \returns The function named by the variable argument.
*/ */
TVM_DLL Function Lookup(const GlobalVar& var); TVM_DLL Function Lookup(const GlobalVar& var) const;
/*! /*!
* \brief Lookup a global function by its string name * \brief Lookup a global function by its string name
* \param name The name of the function. * \param name The name of the function.
* \returns The function named by the argument. * \returns The function named by the argument.
*/ */
TVM_DLL Function Lookup(const std::string& name); TVM_DLL Function Lookup(const std::string& name) const;
/*! /*!
* \brief Lookup a global type definition by its variable. * \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition. * \param var The var of the global type definition.
* \return The type definition. * \return The type definition.
*/ */
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
/*! /*!
* \brief Lookup a global type definition by its name. * \brief Lookup a global type definition by its name.
* \param var The name of the global type definition. * \param var The name of the global type definition.
* \return The type definition. * \return The type definition.
*/ */
TVM_DLL TypeData LookupDef(const std::string& var); TVM_DLL TypeData LookupDef(const std::string& var) const;
/*! /*!
* \brief Update the functions inside this environment by * \brief Update the functions inside this environment by
......
...@@ -112,7 +112,7 @@ struct LambdaLifter : ExprMutator { ...@@ -112,7 +112,7 @@ struct LambdaLifter : ExprMutator {
CHECK(lifted_func.defined()); CHECK(lifted_func.defined());
auto name = GenerateName(lifted_func); auto name = GenerateName(lifted_func);
auto global = module_->GetGlobalVar(name); auto global = GlobalVarNode::make(name);
// Add the lifted function to the module. // Add the lifted function to the module.
module_->Add(global, lifted_func); module_->Add(global, lifted_func);
......
...@@ -57,15 +57,11 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, ...@@ -57,15 +57,11 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
return Module(n); return Module(n);
} }
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) { GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name); auto it = global_var_map_.find(name);
if (it == global_var_map_.end()) { CHECK(it != global_var_map_.end())
auto gvar = GlobalVarNode::make(name); << "Cannot find global var " << name << " in the Module";
global_var_map_.Set(name, gvar); return (*it).second;
return gvar;
} else {
return (*it).second;
}
} }
void ModuleNode::AddUnchecked(const GlobalVar& var, void ModuleNode::AddUnchecked(const GlobalVar& var,
...@@ -84,7 +80,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, ...@@ -84,7 +80,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var); global_var_map_.Set(var->name_hint, var);
} }
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) { GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
auto it = global_type_var_map_.find(name); auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end()) CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module"; << "Cannot find global type var " << name << " in the Module";
...@@ -137,26 +133,26 @@ void ModuleNode::Remove(const GlobalVar& var) { ...@@ -137,26 +133,26 @@ void ModuleNode::Remove(const GlobalVar& var) {
gvar_node->data.erase(var->name_hint); gvar_node->data.erase(var->name_hint);
} }
Function ModuleNode::Lookup(const GlobalVar& var) { Function ModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var); auto it = functions.find(var);
CHECK(it != functions.end()) CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint; << "There is no definition of " << var->name_hint;
return (*it).second; return (*it).second;
} }
Function ModuleNode::Lookup(const std::string& name) { Function ModuleNode::Lookup(const std::string& name) const {
GlobalVar id = this->GetGlobalVar(name); GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id); return this->Lookup(id);
} }
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) { TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var); auto it = type_definitions.find(var);
CHECK(it != type_definitions.end()) CHECK(it != type_definitions.end())
<< "There is no definition of " << var->var->name_hint; << "There is no definition of " << var->var->name_hint;
return (*it).second; return (*it).second;
} }
TypeData ModuleNode::LookupDef(const std::string& name) { TypeData ModuleNode::LookupDef(const std::string& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name); GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id); return this->LookupDef(id);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment