Unverified Commit 3d52a99c by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Allow Module to store BaseFunc. (#4678)

Under the unified IR. We will allow a single IRModule
to store different function variants, such as relay::Function,
ExternFunc, and low-level function.

This PR changes relay::Function -> BaseFunc in the module file
to support multiple function variants.
parent a684bd6f
......@@ -62,7 +62,7 @@ struct Module;
class ModuleNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
......@@ -75,7 +75,7 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_);
}
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});
......@@ -86,7 +86,7 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
/*!
* \brief Add a function to the global environment.
......@@ -95,7 +95,7 @@ class ModuleNode : public RelayNode {
*
* It does not do type inference as Add does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
/*!
* \brief Add a type-level definition to the global environment.
......@@ -124,7 +124,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global function to update.
* \param func The new function.
*/
TVM_DLL void Update(const GlobalVar& var, const Function& func);
TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
/*!
* \brief Update a type definition in the global environment.
......@@ -184,14 +184,14 @@ class ModuleNode : public RelayNode {
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var) const;
TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name) const;
TVM_DLL BaseFunc Lookup(const std::string& name) const;
/*!
* \brief Look up a global type definition by its variable.
......@@ -256,7 +256,7 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {},
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module";
......
......@@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode {
// Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params);
// Get the updated function.
func = relay_module->Lookup("main");
func = Downcast<Function>(relay_module->Lookup("main"));
// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
......
......@@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
auto func = context_->module->Lookup(global);
// TODO(tvm-team):
// Think about mixed call into global that is not a relay::Function
// perhaps establish as an invariance(all functions in mod must be relay::Function)
auto func = Downcast<Function>(context_->module->Lookup(global));
if (IsClosure(func)) {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
......@@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod,
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
......@@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod,
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
if (auto* n = named_func.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
}
}
#if USE_RELAY_DEBUG
......
......@@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator {
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
auto func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
}
return module_;
}
......
......@@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator {
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
if (auto* n = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
}
return module_;
}
......
......@@ -34,10 +34,9 @@ namespace relay {
using tvm::NodePrinter;
using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
Module ModuleNode::make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports
) {
std::unordered_set<std::string> imports) {
auto n = make_object<ModuleNode>();
n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
......@@ -112,40 +111,54 @@ tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
return ret;
}
void ModuleNode::Add(const GlobalVar& var,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// helper function to run type check
relay::Function RunTypeCheck(const Module& mod,
const GlobalVar& var,
relay::Function f) {
auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
auto fv = FreeVars(func);
auto ftv = FreeTypeVars(func, mod);
auto fv = relay::FreeVars(func);
auto ftv = relay::FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING)
<< "There are free variables: "
<< fv
<< " in function: "
<< AsText(func, false)
<< std::endl;
<< "There are free variables: "
<< fv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING)
<< "There are free type variables: "
<< ftv
<< " in function: "
<< AsText(func, false)
<< std::endl;
<< "There are free type variables: "
<< ftv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
func =
FunctionNode::make(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
relay::FunctionNode::make(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
Function checked_func = InferType(func, mod, var);
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
}
void ModuleNode::Add(const GlobalVar& var,
const BaseFunc& f,
bool update) {
BaseFunc checked_func = f;
if (auto* ptr = f.as<relay::FunctionNode>()) {
checked_func = RunTypeCheck(GetRef<Module>(this),
var,
GetRef<relay::Function>(ptr));
}
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) {
CHECK(update)
<< "Already have definition for " << var->name_hint;
......@@ -158,8 +171,7 @@ void ModuleNode::Add(const GlobalVar& var,
}
void ModuleNode::AddUnchecked(const GlobalVar& var,
const Function& func) {
auto mod = GetRef<Module>(this);
const BaseFunc& func) {
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
......@@ -185,15 +197,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
}
}
void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
void ModuleNode::AddTypeDef(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
CHECK(relay::KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
......@@ -204,11 +220,13 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& t
RegisterConstructors(var, type);
}
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
void ModuleNode::Update(const GlobalVar& var,
const BaseFunc& func) {
this->Add(var, func, true);
}
void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
const TypeData& type) {
this->AddTypeDef(var, type, true);
}
......@@ -219,14 +237,14 @@ void ModuleNode::Remove(const GlobalVar& var) {
gvar_node->data.erase(var->name_hint);
}
Function ModuleNode::Lookup(const GlobalVar& var) const {
BaseFunc ModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
Function ModuleNode::Lookup(const std::string& name) const {
BaseFunc ModuleNode::Lookup(const std::string& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
......@@ -268,16 +286,17 @@ void ModuleNode::Update(const Module& mod) {
}
Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs,
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = ModuleNode::make(global_funcs, type_definitions);
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
func = GetRef<Function>(func_node);
BaseFunc func;
if (auto* func_node = expr.as<relay::FunctionNode>()) {
func = GetRef<relay::Function>(func_node);
} else {
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
func = relay::FunctionNode::make(
relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
......@@ -318,8 +337,8 @@ Module FromText(const std::string& source, const std::string& source_name) {
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module")
.set_body_typed(
[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {});
});
......@@ -330,17 +349,19 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
ObjectRef val = args[2];
bool update = args[3];
CHECK(val->IsInstance<ExprNode>());
if (val->IsInstance<FunctionNode>()) {
mod->Add(var, Downcast<Function>(val), update);
if (val->IsInstance<relay::FunctionNode>()) {
mod->Add(var, Downcast<relay::Function>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_object<ModuleNode>(*mod.operator->()));
mod_copy = transform::EtaExpand(
/* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
mod_copy = relay::transform::EtaExpand(
/* expand_constructor */ false,
/* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<Function>(func), update);
mod->Add(var, Downcast<relay::Function>(func), update);
} else {
auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {});
auto func = FunctionNode::make({}, Downcast<relay::Expr>(val), Type(nullptr), {});
mod->Add(var, func, update);
}
*ret = mod;
......@@ -390,8 +411,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
});
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
.set_body_typed([](Expr e,
tvm::Map<GlobalVar, Function> funcs,
.set_body_typed([](RelayExpr e,
tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
......
......@@ -486,7 +486,7 @@ class PrettyPrinter :
return doc;
}
Doc PrintFunc(const Doc& prefix, const Function& fn) {
Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
Doc doc;
doc << prefix;
if (fn->type_params.size() > 0) {
......@@ -514,6 +514,17 @@ class PrettyPrinter :
return doc;
}
Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
if (auto* n = base_func.as<relay::FunctionNode>()) {
return PrintFunc(prefix, GetRef<relay::Function>(n));
} else {
// def @xyz = meta['ExternalFunc'][id]
Doc doc;
doc << prefix << " = " << meta_.GetMetaNode(base_func);
return doc;
}
}
Doc PrintMod(const Module& mod) {
Doc doc;
int counter = 0;
......
......@@ -68,9 +68,12 @@ class EtaExpander : public ExprMutator {
Module Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) {
const Function func = mod_->Lookup(global_var);
const Function new_func = Downcast<Function>(VisitExpr(func));
mod_->Update(global_var, new_func);
const BaseFunc base_func = mod_->Lookup(global_var);
if (auto* n = base_func.as<FunctionNode>()) {
const Function new_func = Downcast<Function>(
VisitExpr(GetRef<Function>(n)));
mod_->Update(global_var, new_func);
}
}
return mod_;
}
......@@ -120,21 +123,26 @@ class EtaExpander : public ExprMutator {
if (!expand_global_var_) {
return std::move(gvar);
}
const auto func = mod_->Lookup(gvar);
tvm::Array<Expr> params;
tvm::Array<Var> args;
for (size_t i = 0; i < func->params.size(); ++i) {
auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
params.push_back(var);
args.push_back(var);
}
const auto base_func = mod_->Lookup(gvar);
if (auto *ptr = base_func.as<FunctionNode>()) {
// handle relay function, skip external functions.
auto func = GetRef<Function>(ptr);
tvm::Array<Expr> params;
tvm::Array<Var> args;
for (size_t i = 0; i < func->params.size(); ++i) {
auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
params.push_back(var);
args.push_back(var);
}
return FunctionNode::make(
args,
CallNode::make(gvar, params),
func->ret_type,
func->type_params);
args,
CallNode::make(gvar, params),
func->ret_type,
func->type_params);
} else {
return std::move(gvar);
}
}
private:
......
......@@ -217,7 +217,7 @@ class ConstantFolder : public ExprMutator {
mod->Add(global, func);
auto seq = transform::Sequential(passes);
mod = seq(mod);
auto entry_func = mod->Lookup("main");
auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ObjectToExpr(executor_(expr));
}
......
......@@ -82,7 +82,12 @@ Type WithGradientType(const Type& t) {
//! \brief if the expression is a GlobalVar, transform to it's expression.
Expr DeGlobal(const Module& mod, const Expr& e) {
if (const auto* x = e.as<GlobalVarNode>()) {
return mod->Lookup(GetRef<GlobalVar>(x))->body;
BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
} else {
return e;
}
} else {
return e;
}
......
......@@ -676,12 +676,18 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitGlobalVar(const GlobalVar& gv) {
CHECK(mod_.defined());
if (gv_map_.count(gv) == 0) {
Function func = mod_->Lookup(gv);
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
mod_->Update(gv, func);
BaseFunc base_func = mod_->Lookup(gv);
if (auto* n = base_func.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
mod_->Update(gv, func);
return gv_map_.at(gv);
} else {
return NoStatic(gv);
}
}
return gv_map_.at(gv);
}
......@@ -951,7 +957,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
auto mod = ModuleNode::FromExpr(expr);
auto seq = transform::Sequential(passes);
mod = seq(mod);
auto entry_func = mod->Lookup("main");
auto entry_func = Downcast<Function>(mod->Lookup("main"));
auto fused_infered =
expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return Reify(executor_(fused_infered), ll);
......
......@@ -323,10 +323,14 @@ Module FunctionPassNode::operator()(const Module& mod,
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
auto updated_func = SkipFunction(it.second)
? it.second
: pass_func(it.second, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
// only picks up relay::Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
auto updated_func = SkipFunction(func)
? func
: pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
}
for (const auto& pair : updates) {
......
......@@ -192,7 +192,7 @@ Expr QuantizeRealize(const Call& ref_call,
Expr FoldConstantOpt(const Expr& expr) {
auto mod = ModuleNode::FromExpr(expr);
mod = transform::FoldConstant()(mod);
auto entry_func = mod->Lookup("main");
auto entry_func = Downcast<Function>(mod->Lookup("main"));
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
}
......
......@@ -155,9 +155,17 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
auto gv = GetRef<GlobalVar>(op);
if (cm->count(gv) == 0) {
auto cps_gv = GlobalVar(gv->name_hint + "_cps");
cm->insert({gv, cps_gv});
m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm));
// only look unfold non-external calls.
BaseFunc base_func = m->Lookup(gv);
if (auto* n = base_func.as<FunctionNode>()) {
auto cps_gv = GlobalVar(gv->name_hint + "_cps");
cm->insert({gv, cps_gv});
m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
} else {
// return the original global var if it is
// an external call to non-relay function.
return GetRef<GlobalVar>(op);
}
}
return k(cm->at(gv));
}
......
......@@ -86,7 +86,7 @@ TEST(Relay, Sequential) {
CHECK(mod.defined());
auto entry_func = mod->GetGlobalVar("main");
CHECK(entry_func.defined());
relay::Function f = mod->Lookup("main");
relay::Function f = Downcast<relay::Function>(mod->Lookup("main"));
CHECK(f.defined());
// Expected function
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment