Commit 70041c48 by Zhi Committed by Jared Roesch

[relay][vm] move vm opt passes to pass manager (#3323)

parent 8f219b95
...@@ -20,24 +20,45 @@ The Relay Virtual Vachine. ...@@ -20,24 +20,45 @@ The Relay Virtual Vachine.
Implements a Python interface to compiling and executing on the Relay VM. Implements a Python interface to compiling and executing on the Relay VM.
""" """
import numpy as np
import tvm import tvm
from tvm._ffi.function import Object from tvm._ffi.function import Object
import numpy as np from .. import transform
from .. import ir_pass
from ..backend.interpreter import Executor from ..backend.interpreter import Executor
from ..expr import GlobalVar, Function, Expr from ..expr import GlobalVar, Expr
from . import _vm from . import _vm
Object = Object Object = Object
def optimize(expr, mod=None): def optimize(mod):
# TODO: We need to move this optimization code into the optimizer/pass manager """Perform several optimizations on a module before executing it in the
ck_expr = ir_pass.infer_type(expr, mod=mod) Relay virtual machine.
simplified_expr = ir_pass.simplify_inference(ck_expr)
simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod) Parameters
fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod) ----------
ck_fused = ir_pass.infer_type(fused_expr, mod=mod) mod : tvm.relay.Module
return ck_fused The module to optimize.
Returns
-------
ret : tvm.relay.Module
The optimized module.
"""
main_func = mod[mod.entry_func]
opt_passes = []
if not main_func.params and isinstance(main_func.body, GlobalVar):
opt_passes.append(transform.EtaExpand())
opt_passes = opt_passes + [
transform.SimplifyInference(),
transform.FuseOps(),
transform.InferType()
]
seq = transform.Sequential(opt_passes)
return seq(mod)
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, np.ndarray): if isinstance(arg, np.ndarray):
...@@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args): ...@@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray] args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate. The arguments to evaluate.
""" """
main_func = mod[mod.entry_func]
if not main_func.params and isinstance(main_func.body, GlobalVar):
main_func = ir_pass.eta_expand(main_func.body, mod)
assert isinstance(main_func, Function)
main_func = optimize(mod[mod.entry_func], mod)
mod[mod.entry_func] = main_func
mod = optimize(mod)
args = list(args) args = list(args)
assert isinstance(args, list) assert isinstance(args, list)
cargs = convert(args) cargs = convert(args)
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/relay/pass.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <iostream> #include <iostream>
#include <unordered_map> #include <unordered_map>
...@@ -38,15 +38,22 @@ ...@@ -38,15 +38,22 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace transform {
Pass LambdaLift();
Pass InlinePrimitives();
} // namespace transform
namespace vm { namespace vm {
using namespace tvm::runtime; using namespace tvm::runtime;
using namespace tvm::runtime::vm; using namespace tvm::runtime::vm;
using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes. // (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func); bool IsClosure(const Function& func);
Module LambdaLift(const Module& module);
Module InlinePrimitives(const Module& module);
template <typename T, typename U> template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>; using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
...@@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F ...@@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
} }
Module OptimizeModule(const Module& mod) { Module OptimizeModule(const Module& mod) {
ToANormalForm(mod->entry_func, mod); transform::Sequential seq({transform::ToANormalForm(),
InlinePrimitives(mod); transform::InlinePrimitives(),
LambdaLift(mod); transform::LambdaLift(),
return InlinePrimitives(mod); transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
} }
void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) { void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/relay/pass.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -37,6 +37,21 @@ namespace tvm { ...@@ -37,6 +37,21 @@ namespace tvm {
namespace relay { namespace relay {
namespace vm { namespace vm {
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
struct PrimitiveInliner : ExprMutator { struct PrimitiveInliner : ExprMutator {
Module module_; Module module_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map; std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
...@@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator { ...@@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
} }
} }
Function Inline(const Function& func) { Module Inline() {
DLOG(INFO) << "Before inlining primitives: " << std::endl auto gvar_funcs = module_->functions;
<< "func= " << AsText(func, false) << std::endl; for (auto pair : gvar_funcs) {
auto global = pair.first;
auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, auto func = pair.second;
func->type_params, func->attrs); DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);
inlined = Downcast<Function>(DeadCodeElimination(inlined));
func = FunctionNode::make(func->params,
DLOG(INFO) << "After inlining primitives" << std::endl VisitExpr(func->body),
<< "after_func= " << AsText(inlined, false) << std::endl; func->ret_type,
return inlined; func->type_params,
func->attrs);
module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
return module_;
} }
}; };
// TODO(@jroesch): write verifier } // namespace vm
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module InlinePrimitives(const Module& module) {
PrimitiveInliner inliner(module);
tvm::Map<GlobalVar, Function> updates; namespace transform {
// There is an ordering bug here. Pass InlinePrimitives() {
for (auto pair : module->functions) { runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
auto global = pair.first; [=](Module m, PassContext pc) {
auto func = pair.second; return relay::vm::PrimitiveInliner(m).Inline();
updates.Set(global, inliner.Inline(func)); };
} auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {});
// Eliminate dead code for each function after inlining.
return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives");
}
for (auto pair : updates) { TVM_REGISTER_API("relay._transform.InlinePrimitives")
module->Add(pair.first, pair.second, true); .set_body_typed(InlinePrimitives);
}
return module; } // namespace transform
}
} // namespace vm
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) { ...@@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
} }
/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct LambdaLifter : ExprMutator { struct LambdaLifter : ExprMutator {
Module module_; Module module_;
std::vector<std::pair<GlobalVar, Function>> lifted_;
explicit LambdaLifter(const Module& module) : module_(module) {} explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const FunctionNode* func_node) final { Expr VisitExpr_(const FunctionNode* func_node) final {
...@@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator { ...@@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator {
auto free_type_vars = FreeTypeVars(func, module_); auto free_type_vars = FreeTypeVars(func, module_);
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node)); auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
// When performing this optimization there are two // When performing this optimization there are two cases.
// cases.
// //
// The first case in which we have no free variables // The first case in which we have no free variables
// we can just lift the function into the global // we can just lift the function into the global
...@@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator { ...@@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator {
// //
// //
// The second case requires that we generate a special // The second case requires that we generate a special
// function with makes a distinction between allocating // function which makes a distinction between allocating
// a closure, and then the code for the closure. // a closure, and then the code for the closure.
// //
// We represent a closure allocation by lifting the // We represent a closure allocation by lifting the
...@@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator { ...@@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
// function marked as a closure is used to emit allocation // function marked as a closure is used to emit allocation
// code for the closure's environment. // code for the closure's environment.
// //
// The "inner" function is should be used to generate the // The "inner" function should be used to generate the
// code for the closure. // code for the closure.
Function lifted_func; Function lifted_func;
if (free_vars.size() == 0) { if (free_vars.size() == 0) {
...@@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator { ...@@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
CHECK(lifted_func.defined()); CHECK(lifted_func.defined());
auto name = GenerateName(lifted_func); auto name = GenerateName(lifted_func);
auto global = this->module_->GetGlobalVar(name); auto global = module_->GetGlobalVar(name);
lifted_.push_back({global, lifted_func}); // Add the lifted function to the module.
module_->Add(global, lifted_func);
if (free_vars.size() == 0) { if (free_vars.size() == 0) {
return std::move(global); return std::move(global);
} else { } else {
// If we need to allocate a closure // If we need to allocate a closure,
// we pass the variables in its environment // we pass the variables in its environment here.
// here.
Array<Expr> fvs; Array<Expr> fvs;
for (auto fv : free_vars) { for (auto fv : free_vars) {
fvs.push_back(fv); fvs.push_back(fv);
...@@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator { ...@@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
} }
} }
Function Lift(const Function& func) { Module Lift() {
DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
}
};
/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
* and then return a function whcih has b
*/
Module LambdaLift(const Module& module) {
LambdaLifter lifter(module);
tvm::Map<GlobalVar, Function> updates;
// There is an ordering bug here. // There is an ordering bug here.
for (auto pair : module->functions) { auto glob_funcs = module_->functions;
auto global = pair.first; for (auto pair : glob_funcs) {
auto func = pair.second; auto func = pair.second;
updates.Set(global, lifter.Lift(func)); DLOG(INFO) << "Lifting " << AsText(func, false);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
} }
return module_;
for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
module->Add(i->first, i->second);
} }
};
for (auto pair : updates) { } // namespace vm
module->Add(pair.first, pair.second, true);
}
return module; namespace transform {
Pass LambdaLift() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return relay::vm::LambdaLifter(m).Lift();
};
return CreateModulePass(pass_func, 1, "LambdaLift", {});
} }
} // namespace vm TVM_REGISTER_API("relay._transform.LambdaLift")
.set_body_typed(LambdaLift);
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info(); const PassInfo& pass_info = Info();
CHECK(mod.defined()); CHECK(mod.defined());
DLOG(INFO) << "Executing module pass : " DLOG(INFO) << "Executing function pass : "
<< pass_info->name << pass_info->name
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
Module updated_mod = mod; Module updated_mod = mod;
Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
auto updated_func = SkipFunction(it.second) auto updated_func = SkipFunction(it.second)
? it.second ? it.second
: pass_func(it.second, updated_mod, pass_ctx); : pass_func(it.second, updated_mod, pass_ctx);
new_mod->Add(it.first, updated_func); updates.push_back({it.first, updated_func});
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
} }
return new_mod; return updated_mod;
} }
// TODO(zhiics) Create an enum attribute for FunctionNode // TODO(zhiics) Create an enum attribute for FunctionNode
...@@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
p->stream << "Pass context information: " << "\n"; p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level) p->stream << "\tfallback device: "
<< runtime::DeviceName(node->fallback_device)
<< "\n"; << "\n";
p->stream << "\trequired passes: [" << node->opt_level; p->stream << "\trequired passes: [" << node->opt_level;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment