Unverified Commit 54975a3f by mbaret Committed by GitHub

[RELAY][FIX] Fix hang in MergeCompilerRegions (#5227)

For certain network topologies, MCR could hang.
This patch fixes that case.

Change-Id: I3edd8a8a6b452b2b838b777720adea22a3b995b4
parent b796c13c
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -58,8 +57,8 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, ...@@ -58,8 +57,8 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
std::vector<Expr> ins_to_remove; std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) { for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input); auto call = Downcast<Call>(input);
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]); auto it = src->nodes.find(call->args[0]);
if (it != src->outs.end()) { if (it != src->nodes.end()) {
dest->outs.remove(*it); dest->outs.remove(*it);
ins_to_remove.push_back(input); ins_to_remove.push_back(input);
} }
......
...@@ -263,6 +263,7 @@ class RegionMerger : public ExprVisitor { ...@@ -263,6 +263,7 @@ class RegionMerger : public ExprVisitor {
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) { if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call)); auto region = regions_->GetRegion(GetRef<Call>(call));
if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
// set the region target // set the region target
auto compiler_attrs = call->attrs.as<CompilerAttrs>(); auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler; region_targets_[region->GetID()] = compiler_attrs->compiler;
...@@ -281,13 +282,13 @@ class RegionMerger : public ExprVisitor { ...@@ -281,13 +282,13 @@ class RegionMerger : public ExprVisitor {
} }
} }
// get the mergeable regions now all the parents have been visited // get the mergeable regions now all the parents have been visited
std::vector<AnnotatedRegion> mergeable_regions; std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) { for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg); auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op); CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]); auto parent_region = regions_->GetRegion(begin->args[0]);
if (!parent_region.defined()) continue; if (!parent_region.defined()) continue;
mergeable_regions.push_back(parent_region); mergeable_regions.insert(parent_region);
} }
auto& region_restrictions = region_restrictions_[region->GetID()]; auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) { for (const auto& parent_region : mergeable_regions) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment