Unverified Commit 3d52a99c by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Allow Module to store BaseFunc. (#4678)

Under the unified IR. We will allow a single IRModule
to store different function variants, such as relay::Function,
ExternFunc, and low-level function.

This PR changes relay::Function -> BaseFunc in the module file
to support multiple function variants.
parent a684bd6f
...@@ -62,7 +62,7 @@ struct Module; ...@@ -62,7 +62,7 @@ struct Module;
class ModuleNode : public RelayNode { class ModuleNode : public RelayNode {
public: public:
/*! \brief A map from ids to all global functions. */ /*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions; tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */ /*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions; tvm::Map<GlobalTypeVar, TypeData> type_definitions;
...@@ -75,7 +75,7 @@ class ModuleNode : public RelayNode { ...@@ -75,7 +75,7 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_); v->Visit("global_type_var_map_", &global_type_var_map_);
} }
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs, TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs, tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {}); std::unordered_set<std::string> imports = {});
...@@ -86,7 +86,7 @@ class ModuleNode : public RelayNode { ...@@ -86,7 +86,7 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the * \param update Controls whether you can replace a definition in the
* environment. * environment.
*/ */
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
...@@ -95,7 +95,7 @@ class ModuleNode : public RelayNode { ...@@ -95,7 +95,7 @@ class ModuleNode : public RelayNode {
* *
* It does not do type inference as Add does. * It does not do type inference as Add does.
*/ */
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
/*! /*!
* \brief Add a type-level definition to the global environment. * \brief Add a type-level definition to the global environment.
...@@ -124,7 +124,7 @@ class ModuleNode : public RelayNode { ...@@ -124,7 +124,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global function to update. * \param var The name of the global function to update.
* \param func The new function. * \param func The new function.
*/ */
TVM_DLL void Update(const GlobalVar& var, const Function& func); TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
/*! /*!
* \brief Update a type definition in the global environment. * \brief Update a type definition in the global environment.
...@@ -184,14 +184,14 @@ class ModuleNode : public RelayNode { ...@@ -184,14 +184,14 @@ class ModuleNode : public RelayNode {
* \param var The global var to lookup. * \param var The global var to lookup.
* \returns The function named by the variable argument. * \returns The function named by the variable argument.
*/ */
TVM_DLL Function Lookup(const GlobalVar& var) const; TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
/*! /*!
* \brief Look up a global function by its string name * \brief Look up a global function by its string name
* \param name The name of the function. * \param name The name of the function.
* \returns The function named by the argument. * \returns The function named by the argument.
*/ */
TVM_DLL Function Lookup(const std::string& name) const; TVM_DLL BaseFunc Lookup(const std::string& name) const;
/*! /*!
* \brief Look up a global type definition by its variable. * \brief Look up a global type definition by its variable.
...@@ -256,7 +256,7 @@ class ModuleNode : public RelayNode { ...@@ -256,7 +256,7 @@ class ModuleNode : public RelayNode {
*/ */
TVM_DLL static Module FromExpr( TVM_DLL static Module FromExpr(
const Expr& expr, const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {}, const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {}); const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module"; static constexpr const char* _type_key = "relay.Module";
......
...@@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode {
// Optimize input Relay Function and returns Relay Module // Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params); relay::Module relay_module = Optimize(func, targets_, params);
// Get the updated function. // Get the updated function.
func = relay_module->Lookup("main"); func = Downcast<Function>(relay_module->Lookup("main"));
// Generate code for the updated function. // Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen()); graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
......
...@@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CHECK(it != context_->global_map.end()); CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second; << " with func_index=" << it->second;
auto func = context_->module->Lookup(global);
// TODO(tvm-team):
// Think about mixed call into global that is not a relay::Function
// perhaps establish as an invariance(all functions in mod must be relay::Function)
auto func = Downcast<Function>(context_->module->Lookup(global));
if (IsClosure(func)) { if (IsClosure(func)) {
auto arity = func->params.size(); auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
...@@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod, ...@@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod,
CHECK_EQ(targets.size(), 1) CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation"; << "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) { if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_); BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main"); auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f); mod->Add(gvar, f);
} }
...@@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod, ...@@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod,
for (auto named_func : context_.module->functions) { for (auto named_func : context_.module->functions) {
auto gvar = named_func.first; auto gvar = named_func.first;
auto func = named_func.second; if (auto* n = named_func.second.as<FunctionNode>()) {
VMFunctionCompiler func_compiler(&context_, targets_, target_host_); auto func = GetRef<Function>(n);
auto vm_func = func_compiler.Compile(gvar, func); VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size()); size_t func_index = context_.global_map.at(gvar);
exec_->functions[func_index] = vm_func; CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
}
} }
#if USE_RELAY_DEBUG #if USE_RELAY_DEBUG
......
...@@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator { ...@@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator {
auto gvar_funcs = module_->functions; auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) { for (auto pair : gvar_funcs) {
auto global = pair.first; auto global = pair.first;
auto func = pair.second; auto base_func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global if (auto* n = base_func.as<FunctionNode>()) {
<< std::endl << AsText(func, false); auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params, DLOG(INFO) << "Before inlining primitives: " << global
VisitExpr(func->body), << std::endl << AsText(func, false);
func->ret_type,
func->type_params, func = FunctionNode::make(func->params,
func->attrs); VisitExpr(func->body),
module_->Add(global, func, true); func->ret_type,
func->type_params,
DLOG(INFO) << "After inlining primitives: " << global func->attrs);
<< std::endl << AsText(func, false); module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
} }
return module_; return module_;
} }
......
...@@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator { ...@@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator {
// There is an ordering bug here. // There is an ordering bug here.
auto glob_funcs = module_->functions; auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) { for (auto pair : glob_funcs) {
auto func = pair.second; if (auto* n = pair.second.as<FunctionNode>()) {
func = FunctionNode::make(func->params, auto func = GetRef<Function>(n);
VisitExpr(func->body), func = FunctionNode::make(func->params,
func->ret_type, VisitExpr(func->body),
func->type_params, func->ret_type,
func->attrs); func->type_params,
module_->Add(pair.first, func, true); func->attrs);
module_->Add(pair.first, func, true);
}
} }
return module_; return module_;
} }
......
...@@ -34,10 +34,9 @@ namespace relay { ...@@ -34,10 +34,9 @@ namespace relay {
using tvm::NodePrinter; using tvm::NodePrinter;
using namespace runtime; using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, Module ModuleNode::make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs, tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports std::unordered_set<std::string> imports) {
) {
auto n = make_object<ModuleNode>(); auto n = make_object<ModuleNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs); n->type_definitions = std::move(global_type_defs);
...@@ -112,40 +111,54 @@ tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) { ...@@ -112,40 +111,54 @@ tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
return ret; return ret;
} }
void ModuleNode::Add(const GlobalVar& var, // helper function to run type check
const Function& f, relay::Function RunTypeCheck(const Module& mod,
bool update) { const GlobalVar& var,
Function func = Downcast<Function>(DeDup(f)); relay::Function f) {
auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
// Type check the item before we add it to the module. // Type check the item before we add it to the module.
auto mod = GetRef<Module>(this); auto fv = relay::FreeVars(func);
auto fv = FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod);
auto ftv = FreeTypeVars(func, mod);
if (fv.size() != 0) { if (fv.size() != 0) {
LOG(WARNING) LOG(WARNING)
<< "There are free variables: " << "There are free variables: "
<< fv << fv
<< " in function: " << " in function: "
<< AsText(func, false) << AsText(func, false)
<< std::endl; << std::endl;
} }
if (ftv.size() != 0) { if (ftv.size() != 0) {
LOG(WARNING) LOG(WARNING)
<< "There are free type variables: " << "There are free type variables: "
<< ftv << ftv
<< " in function: " << " in function: "
<< AsText(func, false) << AsText(func, false)
<< std::endl; << std::endl;
} }
func = func =
FunctionNode::make(concat(func->params, fv), relay::FunctionNode::make(concat(func->params, fv),
func->body, func->body,
func->ret_type, func->ret_type,
concat(func->type_params, ftv), concat(func->type_params, ftv),
func->attrs); func->attrs);
// Type check the item before we add it to the module. // Type check the item before we add it to the module.
Function checked_func = InferType(func, mod, var); relay::Function checked_func = InferType(func, mod, var);
return checked_func;
}
void ModuleNode::Add(const GlobalVar& var,
const BaseFunc& f,
bool update) {
BaseFunc checked_func = f;
if (auto* ptr = f.as<relay::FunctionNode>()) {
checked_func = RunTypeCheck(GetRef<Module>(this),
var,
GetRef<relay::Function>(ptr));
}
auto type = checked_func->checked_type(); auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr); CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) { if (functions.find(var) != functions.end()) {
CHECK(update) CHECK(update)
<< "Already have definition for " << var->name_hint; << "Already have definition for " << var->name_hint;
...@@ -158,8 +171,7 @@ void ModuleNode::Add(const GlobalVar& var, ...@@ -158,8 +171,7 @@ void ModuleNode::Add(const GlobalVar& var,
} }
void ModuleNode::AddUnchecked(const GlobalVar& var, void ModuleNode::AddUnchecked(const GlobalVar& var,
const Function& func) { const BaseFunc& func) {
auto mod = GetRef<Module>(this);
this->functions.Set(var, func); this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint); auto it = global_var_map_.find(var->name_hint);
...@@ -185,15 +197,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& ...@@ -185,15 +197,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
} }
} }
void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { void ModuleNode::AddTypeDef(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
AddTypeDefUnchecked(var, type, update); AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up // need to kind check at the end because the check can look up
// a definition potentially // a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData) CHECK(relay::KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type; << "Invalid or malformed typedata given to module: " << type;
} }
void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
this->type_definitions.Set(var, type); this->type_definitions.Set(var, type);
if (!update) { if (!update) {
// set global type var map // set global type var map
...@@ -204,11 +220,13 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& t ...@@ -204,11 +220,13 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& t
RegisterConstructors(var, type); RegisterConstructors(var, type);
} }
void ModuleNode::Update(const GlobalVar& var, const Function& func) { void ModuleNode::Update(const GlobalVar& var,
const BaseFunc& func) {
this->Add(var, func, true); this->Add(var, func, true);
} }
void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
const TypeData& type) {
this->AddTypeDef(var, type, true); this->AddTypeDef(var, type, true);
} }
...@@ -219,14 +237,14 @@ void ModuleNode::Remove(const GlobalVar& var) { ...@@ -219,14 +237,14 @@ void ModuleNode::Remove(const GlobalVar& var) {
gvar_node->data.erase(var->name_hint); gvar_node->data.erase(var->name_hint);
} }
Function ModuleNode::Lookup(const GlobalVar& var) const { BaseFunc ModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var); auto it = functions.find(var);
CHECK(it != functions.end()) CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint; << "There is no definition of " << var->name_hint;
return (*it).second; return (*it).second;
} }
Function ModuleNode::Lookup(const std::string& name) const { BaseFunc ModuleNode::Lookup(const std::string& name) const {
GlobalVar id = this->GetGlobalVar(name); GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id); return this->Lookup(id);
} }
...@@ -268,16 +286,17 @@ void ModuleNode::Update(const Module& mod) { ...@@ -268,16 +286,17 @@ void ModuleNode::Update(const Module& mod) {
} }
Module ModuleNode::FromExpr( Module ModuleNode::FromExpr(
const Expr& expr, const RelayExpr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) { const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = ModuleNode::make(global_funcs, type_definitions); auto mod = ModuleNode::make(global_funcs, type_definitions);
auto func_node = expr.as<FunctionNode>(); BaseFunc func;
Function func; if (auto* func_node = expr.as<relay::FunctionNode>()) {
if (func_node) { func = GetRef<relay::Function>(func_node);
func = GetRef<Function>(func_node);
} else { } else {
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {}); func = relay::FunctionNode::make(
relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
} }
auto main_gv = GlobalVar("main"); auto main_gv = GlobalVar("main");
mod->Add(main_gv, func); mod->Add(main_gv, func);
...@@ -318,8 +337,8 @@ Module FromText(const std::string& source, const std::string& source_name) { ...@@ -318,8 +337,8 @@ Module FromText(const std::string& source, const std::string& source_name) {
TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module") TVM_REGISTER_GLOBAL("relay._make.Module")
.set_body_typed( .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) { tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {}); return ModuleNode::make(funcs, types, {});
}); });
...@@ -330,17 +349,19 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add") ...@@ -330,17 +349,19 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
ObjectRef val = args[2]; ObjectRef val = args[2];
bool update = args[3]; bool update = args[3];
CHECK(val->IsInstance<ExprNode>()); CHECK(val->IsInstance<ExprNode>());
if (val->IsInstance<FunctionNode>()) {
mod->Add(var, Downcast<Function>(val), update); if (val->IsInstance<relay::FunctionNode>()) {
mod->Add(var, Downcast<relay::Function>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) { } else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val); GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_object<ModuleNode>(*mod.operator->())); auto mod_copy = Module(make_object<ModuleNode>(*mod.operator->()));
mod_copy = transform::EtaExpand( mod_copy = relay::transform::EtaExpand(
/* expand_constructor */ false, /* expand_global_var */ true)(mod_copy); /* expand_constructor */ false,
/* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint); auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<Function>(func), update); mod->Add(var, Downcast<relay::Function>(func), update);
} else { } else {
auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {}); auto func = FunctionNode::make({}, Downcast<relay::Expr>(val), Type(nullptr), {});
mod->Add(var, func, update); mod->Add(var, func, update);
} }
*ret = mod; *ret = mod;
...@@ -390,8 +411,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") ...@@ -390,8 +411,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
.set_body_typed([](Expr e, .set_body_typed([](RelayExpr e,
tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) { tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs); return ModuleNode::FromExpr(e, funcs, type_defs);
}); });
......
...@@ -486,7 +486,7 @@ class PrettyPrinter : ...@@ -486,7 +486,7 @@ class PrettyPrinter :
return doc; return doc;
} }
Doc PrintFunc(const Doc& prefix, const Function& fn) { Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
Doc doc; Doc doc;
doc << prefix; doc << prefix;
if (fn->type_params.size() > 0) { if (fn->type_params.size() > 0) {
...@@ -514,6 +514,17 @@ class PrettyPrinter : ...@@ -514,6 +514,17 @@ class PrettyPrinter :
return doc; return doc;
} }
Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
if (auto* n = base_func.as<relay::FunctionNode>()) {
return PrintFunc(prefix, GetRef<relay::Function>(n));
} else {
// def @xyz = meta['ExternalFunc'][id]
Doc doc;
doc << prefix << " = " << meta_.GetMetaNode(base_func);
return doc;
}
}
Doc PrintMod(const Module& mod) { Doc PrintMod(const Module& mod) {
Doc doc; Doc doc;
int counter = 0; int counter = 0;
......
...@@ -68,9 +68,12 @@ class EtaExpander : public ExprMutator { ...@@ -68,9 +68,12 @@ class EtaExpander : public ExprMutator {
Module Expand() { Module Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) { for (GlobalVar global_var : mod_->GetGlobalVars()) {
const Function func = mod_->Lookup(global_var); const BaseFunc base_func = mod_->Lookup(global_var);
const Function new_func = Downcast<Function>(VisitExpr(func)); if (auto* n = base_func.as<FunctionNode>()) {
mod_->Update(global_var, new_func); const Function new_func = Downcast<Function>(
VisitExpr(GetRef<Function>(n)));
mod_->Update(global_var, new_func);
}
} }
return mod_; return mod_;
} }
...@@ -120,21 +123,26 @@ class EtaExpander : public ExprMutator { ...@@ -120,21 +123,26 @@ class EtaExpander : public ExprMutator {
if (!expand_global_var_) { if (!expand_global_var_) {
return std::move(gvar); return std::move(gvar);
} }
const auto base_func = mod_->Lookup(gvar);
const auto func = mod_->Lookup(gvar); if (auto *ptr = base_func.as<FunctionNode>()) {
tvm::Array<Expr> params; // handle relay function, skip external functions.
tvm::Array<Var> args; auto func = GetRef<Function>(ptr);
for (size_t i = 0; i < func->params.size(); ++i) { tvm::Array<Expr> params;
auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation); tvm::Array<Var> args;
params.push_back(var); for (size_t i = 0; i < func->params.size(); ++i) {
args.push_back(var); auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
} params.push_back(var);
args.push_back(var);
}
return FunctionNode::make( return FunctionNode::make(
args, args,
CallNode::make(gvar, params), CallNode::make(gvar, params),
func->ret_type, func->ret_type,
func->type_params); func->type_params);
} else {
return std::move(gvar);
}
} }
private: private:
......
...@@ -217,7 +217,7 @@ class ConstantFolder : public ExprMutator { ...@@ -217,7 +217,7 @@ class ConstantFolder : public ExprMutator {
mod->Add(global, func); mod->Add(global, func);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup("main"); auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ObjectToExpr(executor_(expr)); return ObjectToExpr(executor_(expr));
} }
......
...@@ -82,7 +82,12 @@ Type WithGradientType(const Type& t) { ...@@ -82,7 +82,12 @@ Type WithGradientType(const Type& t) {
//! \brief if the expression is a GlobalVar, transform to it's expression. //! \brief if the expression is a GlobalVar, transform to it's expression.
Expr DeGlobal(const Module& mod, const Expr& e) { Expr DeGlobal(const Module& mod, const Expr& e) {
if (const auto* x = e.as<GlobalVarNode>()) { if (const auto* x = e.as<GlobalVarNode>()) {
return mod->Lookup(GetRef<GlobalVar>(x))->body; BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
} else {
return e;
}
} else { } else {
return e; return e;
} }
......
...@@ -676,12 +676,18 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -676,12 +676,18 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitGlobalVar(const GlobalVar& gv) { PStatic VisitGlobalVar(const GlobalVar& gv) {
CHECK(mod_.defined()); CHECK(mod_.defined());
if (gv_map_.count(gv) == 0) { if (gv_map_.count(gv) == 0) {
Function func = mod_->Lookup(gv); BaseFunc base_func = mod_->Lookup(gv);
InitializeFuncId(func); if (auto* n = base_func.as<FunctionNode>()) {
Func f = VisitFuncStatic(func, gv); Function func = GetRef<Function>(n);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); InitializeFuncId(func);
func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv))); Func f = VisitFuncStatic(func, gv);
mod_->Update(gv, func); gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
mod_->Update(gv, func);
return gv_map_.at(gv);
} else {
return NoStatic(gv);
}
} }
return gv_map_.at(gv); return gv_map_.at(gv);
} }
...@@ -951,7 +957,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -951,7 +957,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
auto mod = ModuleNode::FromExpr(expr); auto mod = ModuleNode::FromExpr(expr);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup("main"); auto entry_func = Downcast<Function>(mod->Lookup("main"));
auto fused_infered = auto fused_infered =
expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return Reify(executor_(fused_infered), ll); return Reify(executor_(fused_infered), ll);
......
...@@ -323,10 +323,14 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -323,10 +323,14 @@ Module FunctionPassNode::operator()(const Module& mod,
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports()); Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) { for (const auto& it : updated_mod->functions) {
auto updated_func = SkipFunction(it.second) // only picks up relay::Function
? it.second if (auto* n = it.second.as<FunctionNode>()) {
: pass_func(it.second, updated_mod, pass_ctx); Function func = GetRef<Function>(n);
updates.push_back({it.first, updated_func}); auto updated_func = SkipFunction(func)
? func
: pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
} }
for (const auto& pair : updates) { for (const auto& pair : updates) {
......
...@@ -192,7 +192,7 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -192,7 +192,7 @@ Expr QuantizeRealize(const Call& ref_call,
Expr FoldConstantOpt(const Expr& expr) { Expr FoldConstantOpt(const Expr& expr) {
auto mod = ModuleNode::FromExpr(expr); auto mod = ModuleNode::FromExpr(expr);
mod = transform::FoldConstant()(mod); mod = transform::FoldConstant()(mod);
auto entry_func = mod->Lookup("main"); auto entry_func = Downcast<Function>(mod->Lookup("main"));
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
} }
......
...@@ -155,9 +155,17 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const ...@@ -155,9 +155,17 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
auto gv = GetRef<GlobalVar>(op); auto gv = GetRef<GlobalVar>(op);
if (cm->count(gv) == 0) { if (cm->count(gv) == 0) {
auto cps_gv = GlobalVar(gv->name_hint + "_cps"); // only look unfold non-external calls.
cm->insert({gv, cps_gv}); BaseFunc base_func = m->Lookup(gv);
m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); if (auto* n = base_func.as<FunctionNode>()) {
auto cps_gv = GlobalVar(gv->name_hint + "_cps");
cm->insert({gv, cps_gv});
m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
} else {
// return the original global var if it is
// an external call to non-relay function.
return GetRef<GlobalVar>(op);
}
} }
return k(cm->at(gv)); return k(cm->at(gv));
} }
......
...@@ -86,7 +86,7 @@ TEST(Relay, Sequential) { ...@@ -86,7 +86,7 @@ TEST(Relay, Sequential) {
CHECK(mod.defined()); CHECK(mod.defined());
auto entry_func = mod->GetGlobalVar("main"); auto entry_func = mod->GetGlobalVar("main");
CHECK(entry_func.defined()); CHECK(entry_func.defined());
relay::Function f = mod->Lookup("main"); relay::Function f = Downcast<relay::Function>(mod->Lookup("main"));
CHECK(f.defined()); CHECK(f.defined());
// Expected function // Expected function
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment