Commit 92ffa062 by Wei Chen Committed by Zhi

[Relay][VM] Add more passes to VMCompiler (#4058)

* [Relay][VM] Add more passes to VMCompiler

* Check build config

* Add todo
parent 70840818
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
...@@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref, ...@@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref,
// Run some optimizations first, this code should // Run some optimizations first, this code should
// be moved to pass manager. // be moved to pass manager.
context_.module = OptimizeModule(mod_ref); context_.module = OptimizeModule(mod_ref, targets_);
// Populate the global map. // Populate the global map.
// //
...@@ -844,18 +845,63 @@ void VMCompiler::Compile(const Module& mod_ref, ...@@ -844,18 +845,63 @@ void VMCompiler::Compile(const Module& mod_ref,
} }
} }
Module VMCompiler::OptimizeModule(const Module& mod) { Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
// TODO(@icemelon9): check number of targets and build config, add more optimization pass Array<Pass> pass_seqs;
transform::Sequential seq({transform::SimplifyInference(), // Run all dialect legalization passes.
transform::InlinePrimitives(), pass_seqs.push_back(relay::qnn::transform::Legalize());
// TODO(@wweic): FuseOps pass currently don't handle Let
// For now, we put FuseOps before ToANormalForm to enable it // Legalize pass is restricted to homogeneous execution for now.
transform::FuseOps(), if (targets.size() == 1) {
transform::ToANormalForm(), pass_seqs.push_back(transform::Legalize());
transform::LambdaLift(), }
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create(); pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == Int(32)) {
*rv = true;
}
}
}
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::InlinePrimitives());
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FuseOps());
pass_seqs.push_back(transform::ToANormalForm());
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
transform::Sequential seq(pass_seqs);
transform::PassContext pass_ctx = PassContext::Current();
// TODO(wweic): Support heterogenous execution
tvm::With<relay::transform::PassContext> ctx(pass_ctx); tvm::With<relay::transform::PassContext> ctx(pass_ctx);
if (targets.size() == 1) {
for (const auto& kv : targets) {
With<Target> tctx(kv.second);
return seq(mod);
}
}
return seq(mod); return seq(mod);
} }
......
...@@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode {
const tvm::Target& target_host); const tvm::Target& target_host);
protected: protected:
Module OptimizeModule(const Module& mod); Module OptimizeModule(const Module& mod, const TargetsMap& targets);
void PopulateGlobalMap(); void PopulateGlobalMap();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment