Unverified Commit 3db8880d by Zhi Committed by GitHub

fix fuse over functions that are handled by external codegen (#5365)

parent c936a81d
......@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
......@@ -938,6 +931,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen.
pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
......@@ -199,6 +199,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Post order tree
void VisitExpr_(const FunctionNode* op) final {
// Skip the function that should be handled by external codegen.
if (op->GetAttr<String>(attr::kCompiler).defined()) return;
for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
......
......@@ -457,7 +457,6 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
mod["main"] = bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment