Commit 70041c48 by Zhi Committed by Jared Roesch

[relay][vm] move vm opt passes to pass manager (#3323)

parent 8f219b95
......@@ -20,24 +20,45 @@ The Relay Virtual Machine.
Implements a Python interface to compiling and executing on the Relay VM.
"""
import numpy as np
import tvm
from tvm._ffi.function import Object
import numpy as np
from .. import ir_pass
from .. import transform
from ..backend.interpreter import Executor
from ..expr import GlobalVar, Function, Expr
from ..expr import GlobalVar, Expr
from . import _vm
Object = Object
def optimize(expr, mod=None):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=mod)
simplified_expr = ir_pass.simplify_inference(ck_expr)
simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod)
fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=mod)
return ck_fused
def optimize(mod):
    """Run the pre-execution optimization pipeline used by the Relay
    virtual machine on a module.

    Parameters
    ----------
    mod : tvm.relay.Module
        The module to optimize.

    Returns
    -------
    ret : tvm.relay.Module
        The optimized module.
    """
    entry = mod[mod.entry_func]
    # Eta-expansion is scheduled only when the entry function takes no
    # parameters and its body is a bare reference to another global.
    wraps_global = not entry.params and isinstance(entry.body, GlobalVar)
    pipeline = [transform.EtaExpand()] if wraps_global else []
    pipeline.extend([
        transform.SimplifyInference(),
        transform.FuseOps(),
        transform.InferType(),
    ])
    return transform.Sequential(pipeline)(mod)
def _convert(arg, cargs):
if isinstance(arg, np.ndarray):
......@@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""
main_func = mod[mod.entry_func]
if not main_func.params and isinstance(main_func.body, GlobalVar):
main_func = ir_pass.eta_expand(main_func.body, mod)
assert isinstance(main_func, Function)
main_func = optimize(mod[mod.entry_func], mod)
mod[mod.entry_func] = main_func
mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
cargs = convert(args)
......
......@@ -27,7 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_map>
......@@ -38,15 +38,22 @@
namespace tvm {
namespace relay {
namespace transform {
Pass LambdaLift();
Pass InlinePrimitives();
} // namespace transform
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Module LambdaLift(const Module& module);
Module InlinePrimitives(const Module& module);
template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
......@@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F
}
Module OptimizeModule(const Module& mod) {
ToANormalForm(mod->entry_func, mod);
InlinePrimitives(mod);
LambdaLift(mod);
return InlinePrimitives(mod);
transform::Sequential seq({transform::ToANormalForm(),
transform::InlinePrimitives(),
transform::LambdaLift(),
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
}
void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
......
......@@ -26,7 +26,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
......@@ -37,6 +37,21 @@ namespace tvm {
namespace relay {
namespace vm {
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
struct PrimitiveInliner : ExprMutator {
Module module_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
......@@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator {
}
}
Function Inline(const Function& func) {
DLOG(INFO) << "Before inlining primitives: " << std::endl
<< "func= " << AsText(func, false) << std::endl;
auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
inlined = Downcast<Function>(DeadCodeElimination(inlined));
DLOG(INFO) << "After inlining primitives" << std::endl
<< "after_func= " << AsText(inlined, false) << std::endl;
return inlined;
/* Rewrite every function currently registered in module_ by running this
 * mutator over its body, then store the rewritten function back under the
 * same global.
 *
 * Returns module_ itself (the module is updated in place via Add).
 */
Module Inline() {
// Iterate over a local copy of the function map handle rather than
// module_->functions directly, since module_->Add is called inside the loop.
// NOTE(review): presumably this keeps iteration stable while the module is
// being updated - confirm against tvm::Map copy semantics.
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
auto func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);
// Rebuild the function with the same signature and attributes; only the
// body is replaced by its visited (inlined) form.
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
// The trailing `true` is presumably the "update existing definition" flag,
// since `global` is already present in the module - TODO confirm.
module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
return module_;
}
};
// TODO(@jroesch): write verifier
/* This pass will eliminate primitives which have been lifted by the ANF
* transform inlining them directly into call sites.
*
* This makes VM related code generation easier as the call target is always
* a primitive function.
*
* let prim = fn(...) { ... };
* prim(...)
*
* will become:
*
* (fn(...) { ... })(...)
*/
Module InlinePrimitives(const Module& module) {
PrimitiveInliner inliner(module);
} // namespace vm
tvm::Map<GlobalVar, Function> updates;
namespace transform {
// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, inliner.Inline(func));
}
Pass InlinePrimitives() {
  // Module-level pass body: build a PrimitiveInliner for the incoming module
  // and return whatever its Inline() produces. The pass context is unused.
  runtime::TypedPackedFunc<Module(Module, PassContext)> inline_func =
      [=](Module mod, PassContext ctx) -> Module {
        relay::vm::PrimitiveInliner inliner(mod);
        return inliner.Inline();
      };
  auto inline_pass = CreateModulePass(inline_func, 1, "Inline", {});
  // Eliminate dead code for each function after inlining.
  return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives");
}
for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
TVM_REGISTER_API("relay._transform.InlinePrimitives")
.set_body_typed(InlinePrimitives);
return module;
}
} // namespace transform
} // namespace vm
} // namespace relay
} // namespace tvm
......@@ -27,6 +27,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <vector>
......@@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
}
/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
* We will lift a function out into a global which takes the set of the free
* vars and then returns the newly created function.
*/
struct LambdaLifter : ExprMutator {
Module module_;
std::vector<std::pair<GlobalVar, Function>> lifted_;
explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const FunctionNode* func_node) final {
......@@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);
auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
// When performing this optimization there are two
// cases.
// When performing this optimization there are two cases.
//
// The first case in which we have no free variables
// we can just lift the function into the global
......@@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator {
//
//
// The second case requires that we generate a special
// function with makes a distinction between allocating
// function which makes a distinction between allocating
// a closure, and then the code for the closure.
//
// We represent a closure allocation by lifting the
......@@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator {
// function marked as a closure is used to emit allocation
// code for the closure's environment.
//
// The "inner" function is should be used to generate the
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
......@@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator {
CHECK(lifted_func.defined());
auto name = GenerateName(lifted_func);
auto global = this->module_->GetGlobalVar(name);
auto global = module_->GetGlobalVar(name);
lifted_.push_back({global, lifted_func});
// Add the lifted function to the module.
module_->Add(global, lifted_func);
if (free_vars.size() == 0) {
return std::move(global);
} else {
// If we need to allocate a closure
// we pass the variables in its environment
// here.
// If we need to allocate a closure,
// we pass the variables in its environment here.
Array<Expr> fvs;
for (auto fv : free_vars) {
fvs.push_back(fv);
......@@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator {
}
}
Function Lift(const Function& func) {
DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
func->type_params, func->attrs);
/* Run the lifter over the body of every function registered in module_ and
 * store each rewritten function back under its original global.
 *
 * Returns module_ itself (the module is updated in place via Add).
 */
Module Lift() {
// There is an ordering bug here.
// Iterate over a local copy of the function map handle: module_->Add is
// called below, and visiting a nested function also adds the lifted function
// to module_.
// NOTE(review): presumably the copy keeps iteration stable during those
// additions - confirm against tvm::Map copy semantics.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
DLOG(INFO) << "Lifting " << AsText(func, false);
// Only the body is visited; the top-level function is rebuilt here instead of
// being passed through VisitExpr as a whole.
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
// The trailing `true` is presumably the "update existing definition" flag,
// since pair.first is already present in the module - TODO confirm.
module_->Add(pair.first, func, true);
}
return module_;
}
};
/* The goal of this pass is to lift out any nested functions into top-level
* functions.
*
* We will lift the functions out into globals which take the set of the free vars
* and then return a function whcih has b
*/
Module LambdaLift(const Module& module) {
LambdaLifter lifter(module);
tvm::Map<GlobalVar, Function> updates;
} // namespace vm
// There is an ordering bug here.
for (auto pair : module->functions) {
auto global = pair.first;
auto func = pair.second;
updates.Set(global, lifter.Lift(func));
}
namespace transform {
for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
module->Add(i->first, i->second);
}
Pass LambdaLift() {
  // Module-level pass body: build a LambdaLifter for the incoming module and
  // return whatever its Lift() produces. The pass context is unused.
  runtime::TypedPackedFunc<Module(Module, PassContext)> lift_func =
      [=](Module mod, PassContext ctx) -> Module {
        relay::vm::LambdaLifter lifter(mod);
        return lifter.Lift();
      };
  return CreateModulePass(lift_func, 1, "LambdaLift", {});
}
for (auto pair : updates) {
module->Add(pair.first, pair.second, true);
}
TVM_REGISTER_API("relay._transform.LambdaLift")
.set_body_typed(LambdaLift);
return module;
}
} // namespace transform
} // namespace vm
} // namespace relay
} // namespace tvm
......@@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
DLOG(INFO) << "Executing module pass : "
DLOG(INFO) << "Executing function pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
Module updated_mod = mod;
Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : mod->functions) {
auto updated_func = SkipFunction(it.second)
? it.second
: pass_func(it.second, updated_mod, pass_ctx);
new_mod->Add(it.first, updated_func);
updates.push_back({it.first, updated_func});
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
return new_mod;
return updated_mod;
}
// TODO(zhiics) Create an enum attribute for FunctionNode
......@@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm::IRPrinter* p) {
p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level)
p->stream << "\tfallback device: "
<< runtime::DeviceName(node->fallback_device)
<< "\n";
p->stream << "\trequired passes: [" << node->opt_level;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment