Commit 92ffa062 by Wei Chen Committed by Zhi

[Relay][VM] Add more passes to VMCompiler (#4058)

* [Relay][VM] Add more passes to VMCompiler

* Check build config

* Add todo
parent 70840818
......@@ -27,6 +27,7 @@
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
......@@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref,
// Run some optimizations first, this code should
// be moved to pass manager.
context_.module = OptimizeModule(mod_ref);
context_.module = OptimizeModule(mod_ref, targets_);
// Populate the global map.
//
......@@ -844,18 +845,63 @@ void VMCompiler::Compile(const Module& mod_ref,
}
}
Module VMCompiler::OptimizeModule(const Module& mod) {
// TODO(@icemelon9): check number of targets and build config, add more optimization pass
transform::Sequential seq({transform::SimplifyInference(),
transform::InlinePrimitives(),
// TODO(@wweic): FuseOps pass currently don't handle Let
// For now, we put FuseOps before ToANormalForm to enable it
transform::FuseOps(),
transform::ToANormalForm(),
transform::LambdaLift(),
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}
pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == Int(32)) {
*rv = true;
}
}
}
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::InlinePrimitives());
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FuseOps());
pass_seqs.push_back(transform::ToANormalForm());
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
transform::Sequential seq(pass_seqs);
transform::PassContext pass_ctx = PassContext::Current();
// TODO(wweic): Support heterogenous execution
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
if (targets.size() == 1) {
for (const auto& kv : targets) {
With<Target> tctx(kv.second);
return seq(mod);
}
}
return seq(mod);
}
......
......@@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode {
const tvm::Target& target_host);
protected:
Module OptimizeModule(const Module& mod);
Module OptimizeModule(const Module& mod, const TargetsMap& targets);
void PopulateGlobalMap();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment