Unverified Commit ba876046 by Cody Yu Committed by GitHub

[BYOC] Use Non-Recursive Visitor/Mutator (#5410)

* Non-Recursive AnnotatedTarget and MergeAnnotation

* Non-Recursive AnnotatedRegionSet and RegionMerger
parent 1f6c498b
...@@ -86,32 +86,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) { ...@@ -86,32 +86,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
return *ret.first; return *ret.first;
} }
class AnnotatedRegionSet::Creator : public ExprVisitor { class AnnotatedRegionSet::Creator : protected MixedModeVisitor {
public: public:
Creator(const Op& region_begin_op, const Op& region_end_op) Creator(const Op& region_begin_op, const Op& region_end_op)
: begin_op_(region_begin_op), end_op_(region_end_op) {} : begin_op_(region_begin_op), end_op_(region_end_op) {}
void VisitExpr_(const CallNode* call) { AnnotatedRegionSet Create(const Expr& expr) {
auto op_node = call->op.as<OpNode>(); VisitExpr(expr);
return std::move(region_set_);
}
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { void AddToArgRegion(Expr expr, Array<Expr> args) {
// Propagate region to arguments // Merge argument regions and add itself to the region.
auto region = region_set_->GetRegion(GetRef<Call>(call));
// Find the first open region.
AnnotatedRegion region;
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
continue;
}
region = region_set_->GetRegion(arg);
if (region.defined()) { if (region.defined()) {
for (auto arg : call->args) { break;
region_set_->AddToRegion(region, arg); }
}
// Try to merge open regions.
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
continue;
}
auto arg_region = region_set_->GetRegion(arg);
CHECK_EQ(region.defined(), arg_region.defined())
<< "Arg regions are inconsistent: " << AsText(expr);
if (region.defined() && region != arg_region) {
region_set_->MergeRegions(arg_region, region);
} }
} }
if (region.defined()) {
region_set_->AddToRegion(region, expr);
}
}
void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
AddToArgRegion(GetRef<Call>(call), call->args);
} else if (call->op == begin_op_) { } else if (call->op == begin_op_) {
// The annotation node is inserted on edge so it must have only one argument. // The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(GetRef<Call>(call)); auto region = region_set_->GetRegion(GetRef<Call>(call));
if (!region.defined()) { CHECK(!region.defined());
throw Error(ErrorBuilder()
<< "Cannot find the corresponding region for start annotation:\n" // Create a new region.
<< AsText(GetRef<Call>(call), false)); region = region_set_->MakeRegion(target);
} region->nodes_.insert(GetRef<Call>(call));
region->ins_.push_back(GetRef<Call>(call)); region->ins_.push_back(GetRef<Call>(call));
} else { } else {
CHECK_EQ(call->op, end_op_); CHECK_EQ(call->op, end_op_);
...@@ -122,9 +159,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { ...@@ -122,9 +159,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
// Check if the argument already belongs to a region // Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]); auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) { if (!region.defined()) {
// Create a new region if the argument is not belonged to any regions yet. throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n"
region = region_set_->MakeRegion(target); << AsText(GetRef<Call>(call), false));
region->nodes_.insert(call->args[0]);
} else { } else {
// If the argument is belonged to a region, it must have the same target. // If the argument is belonged to a region, it must have the same target.
// Otherwise we should see a region_begin op. // Otherwise we should see a region_begin op.
...@@ -133,83 +169,44 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { ...@@ -133,83 +169,44 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
region->nodes_.insert(GetRef<Call>(call)); region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call)); region->outs_.push_back(GetRef<Call>(call));
} }
ExprVisitor::VisitExpr_(call);
}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
} }
void VisitExpr_(const TupleNode* op) { void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op)); AddToArgRegion(GetRef<Tuple>(op), op->fields);
if (region.defined()) {
for (auto field : op->fields) {
region_set_->AddToRegion(region, field);
}
}
ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const TupleGetItemNode* g) { void VisitExpr_(const TupleGetItemNode* g) {
auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g)); Array<Expr> args = {g->tuple};
if (region.defined()) { AddToArgRegion(GetRef<TupleGetItem>(g), args);
region_set_->AddToRegion(region, g->tuple);
}
ExprVisitor::VisitExpr_(g);
}
void VisitExpr_(const FunctionNode* op) {
auto region = region_set_->GetRegion(GetRef<Function>(op));
if (region.defined()) {
for (auto param : op->params) {
region_set_->AddToRegion(region, param);
}
}
ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const LetNode* op) { void VisitExpr_(const LetNode* op) {
auto region = region_set_->GetRegion(GetRef<Let>(op)); Array<Expr> args = {op->var, op->value, op->body};
if (region.defined()) { AddToArgRegion(GetRef<Let>(op), args);
region_set_->AddToRegion(region, op->var);
region_set_->AddToRegion(region, op->value);
region_set_->AddToRegion(region, op->body);
}
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const IfNode* op) { void VisitExpr_(const IfNode* op) {
auto region = region_set_->GetRegion(GetRef<If>(op)); Array<Expr> args = {op->cond, op->true_branch, op->false_branch};
if (region.defined()) { AddToArgRegion(GetRef<If>(op), args);
region_set_->AddToRegion(region, op->cond);
region_set_->AddToRegion(region, op->true_branch);
region_set_->AddToRegion(region, op->false_branch);
}
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const RefCreateNode* op) { void VisitExpr_(const RefCreateNode* op) {
auto region = region_set_->GetRegion(GetRef<RefCreate>(op)); Array<Expr> args = {op->value};
if (region.defined()) { AddToArgRegion(GetRef<RefCreate>(op), args);
region_set_->AddToRegion(region, op->value);
}
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const RefReadNode* op) { void VisitExpr_(const RefReadNode* op) {
auto region = region_set_->GetRegion(GetRef<RefRead>(op)); Array<Expr> args = {op->ref};
if (region.defined()) { AddToArgRegion(GetRef<RefRead>(op), args);
region_set_->AddToRegion(region, op->ref);
}
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const RefWriteNode* op) { void VisitExpr_(const RefWriteNode* op) {
auto region = region_set_->GetRegion(GetRef<RefWrite>(op)); Array<Expr> args = {op->ref};
if (region.defined()) { AddToArgRegion(GetRef<RefWrite>(op), args);
region_set_->AddToRegion(region, op->ref);
}
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
......
...@@ -42,9 +42,9 @@ const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._mak ...@@ -42,9 +42,9 @@ const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._mak
// A helper class to insert annotation boundaries for a program region that will // A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler. // be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator { class AnnotateTargetRewriter : public ExprRewriter {
public: public:
explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {} explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}
/*! /*!
* \brief This function annotates a compiler end and a compiler begin to all arguments. * \brief This function annotates a compiler end and a compiler begin to all arguments.
...@@ -108,29 +108,29 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -108,29 +108,29 @@ class AnnotateTargetWrapper : public ExprMutator {
return new_op; return new_op;
} }
Expr VisitExpr_(const CallNode* cn) final { Expr Rewrite_(const CallNode* pre, const Expr& post) final {
// Supported targets for this node. The order implies the priority. // Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets; std::vector<std::string> supported_targets;
auto op_node = cn->op.as<OpNode>(); auto op_node = pre->op.as<OpNode>();
// This graph has annotations, meaning that this is not the first time running this pass. // This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && cn->op == compiler_begin_op) { if (op_node && pre->op == compiler_begin_op) {
// Bypass compiler begin due to lack of target information. It will be processed // Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments. // when the following op handling arguments.
CHECK_EQ(cn->args.size(), 1U); CHECK_EQ(pre->args.size(), 1U);
return VisitExpr(cn->args[0]); return post.as<CallNode>()->args[0];
} else if (op_node && cn->op == compiler_end_op) { } else if (op_node && pre->op == compiler_end_op) {
// Override compiler end with the new target. // Override compiler end with the new target.
CHECK_EQ(cn->args.size(), 1U); CHECK_EQ(pre->args.size(), 1U);
auto input_expr = VisitExpr(cn->args[0]); auto input_expr = post.as<CallNode>()->args[0];
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
} }
// Peek the first argument. If it is compiler begin then this node had annotated by // Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target. // another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = cn->args[0].as<CallNode>(); const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == compiler_begin_op) { if (first_arg_call && first_arg_call->op == compiler_begin_op) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler; std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") { if (arg_target != "default") {
...@@ -142,21 +142,21 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -142,21 +142,21 @@ class AnnotateTargetWrapper : public ExprMutator {
if (op_node) { if (op_node) {
// TVM operators: Check target specific op checking function and add to supported_targets // TVM operators: Check target specific op checking function and add to supported_targets
// if it is supported. // if it is supported.
Op op = Downcast<Op>(cn->op); Op op = Downcast<Op>(pre->op);
CHECK(op.defined()); CHECK(op.defined());
for (const auto& target : this->targets_) { for (const auto& target : this->targets_) {
if (!Op::HasAttr("target." + std::string(target))) { if (!Op::HasAttr("target." + std::string(target))) {
continue; continue;
} }
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target)); auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) {
supported_targets.push_back(target); supported_targets.push_back(target);
} }
} }
} else if (cn->op->IsInstance<FunctionNode>()) { } else if (pre->op->IsInstance<FunctionNode>()) {
// Composite function: Add the target of a composite function to supported_targets // Composite function: Add the target of a composite function to supported_targets
// if it is in the target list. // if it is in the target list.
Function func = Downcast<Function>(cn->op); Function func = Downcast<Function>(pre->op);
CHECK(func.defined()); CHECK(func.defined());
if (auto comp_name = func->GetAttr<String>(attr::kComposite)) { if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
...@@ -181,23 +181,22 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -181,23 +181,22 @@ class AnnotateTargetWrapper : public ExprMutator {
std::string target = supported_targets[0]; std::string target = supported_targets[0];
// Visit and mutate arguments after the target of this op has been determined. // Visit and mutate arguments after the target of this op has been determined.
auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn)); Call post_call = Downcast<Call>(post);
// Add annotations to each arg. // Add annotations to each arg.
auto target_n_args = AnnotateArgs(new_call->args, target); auto target_n_args = AnnotateArgs(post_call->args, target);
Array<Expr> compiler_begins = std::get<1>(target_n_args); Array<Expr> compiler_begins = std::get<1>(target_n_args);
Call call = Call(new_call->op, compiler_begins, new_call->attrs); Call new_call = Call(post_call->op, compiler_begins, post_call->attrs);
call->checked_type_ = cn->checked_type_; new_call->checked_type_ = pre->checked_type_;
// Update the target map. // Update the target map.
op_expr_to_target_[call] = target; op_expr_to_target_[new_call] = target;
return std::move(call); return std::move(new_call);
} }
Expr VisitExpr_(const TupleNode* op) final { Expr Rewrite_(const TupleNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast<Tuple>(post);
auto expr = Downcast<Tuple>(new_e);
auto target_n_args = AnnotateArgs(expr->fields); auto target_n_args = AnnotateArgs(expr->fields);
auto new_expr = Tuple(std::get<1>(target_n_args)); auto new_expr = Tuple(std::get<1>(target_n_args));
...@@ -205,9 +204,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -205,9 +204,8 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr); return std::move(new_expr);
} }
Expr VisitExpr_(const TupleGetItemNode* op) final { Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast<TupleGetItem>(post);
auto expr = Downcast<TupleGetItem>(new_e);
auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple})); auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
...@@ -215,7 +213,7 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -215,7 +213,7 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr); return std::move(new_expr);
} }
Expr VisitExpr_(const FunctionNode* fn) final { Expr Rewrite_(const FunctionNode* fn, const Expr& post) final {
Function func; Function func;
Expr new_body; Expr new_body;
// don't step into composite functions // don't step into composite functions
...@@ -223,8 +221,7 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -223,8 +221,7 @@ class AnnotateTargetWrapper : public ExprMutator {
func = GetRef<Function>(fn); func = GetRef<Function>(fn);
new_body = func->body; new_body = func->body;
} else { } else {
auto new_e = ExprMutator::VisitExpr_(fn); func = Downcast<Function>(post);
func = Downcast<Function>(new_e);
new_body = func->body; new_body = func->body;
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
...@@ -234,9 +231,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -234,9 +231,8 @@ class AnnotateTargetWrapper : public ExprMutator {
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
} }
Expr VisitExpr_(const LetNode* op) final { Expr Rewrite_(const LetNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto let = Downcast<Let>(post);
auto let = Downcast<Let>(new_e);
auto target_n_args = AnnotateArgs({let->value, let->body}); auto target_n_args = AnnotateArgs({let->value, let->body});
auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
...@@ -244,9 +240,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -244,9 +240,8 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr); return std::move(new_expr);
} }
Expr VisitExpr_(const IfNode* op) final { Expr Rewrite_(const IfNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast<If>(post);
auto expr = Downcast<If>(new_e);
auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
CHECK_EQ(std::get<1>(target_n_args).size(), 3U); CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
...@@ -256,9 +251,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -256,9 +251,8 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr); return std::move(new_expr);
} }
Expr VisitExpr_(const RefCreateNode* op) final { Expr Rewrite_(const RefCreateNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast<RefCreate>(post);
auto expr = Downcast<RefCreate>(new_e);
auto target_n_args = AnnotateArgs(Array<Expr>({expr->value})); auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
...@@ -266,9 +260,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -266,9 +260,8 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr); return std::move(new_expr);
} }
Expr VisitExpr_(const RefReadNode* op) final { Expr Rewrite_(const RefReadNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast<RefRead>(post);
auto expr = Downcast<RefRead>(new_e);
auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref})); auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
auto new_expr = RefRead(std::get<1>(target_n_args)[0]); auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
...@@ -276,9 +269,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -276,9 +269,8 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr); return std::move(new_expr);
} }
Expr VisitExpr_(const RefWriteNode* op) final { Expr Rewrite_(const RefWriteNode* op, const Expr& post) final {
auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast<RefWrite>(post);
auto expr = Downcast<RefWrite>(new_e);
auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value})); auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
...@@ -294,7 +286,8 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -294,7 +286,8 @@ class AnnotateTargetWrapper : public ExprMutator {
}; };
Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) { Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
return AnnotateTargetWrapper(targets).Mutate(expr); auto rewriter = AnnotateTargetRewriter(targets);
return PostOrderRewrite(expr, &rewriter);
} }
} // namespace annotate_target } // namespace annotate_target
......
...@@ -53,7 +53,7 @@ namespace merge_compiler_region { ...@@ -53,7 +53,7 @@ namespace merge_compiler_region {
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
class RegionMerger : public ExprVisitor { class RegionMerger : public MixedModeVisitor {
public: public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
...@@ -131,7 +131,6 @@ class RegionMerger : public ExprVisitor { ...@@ -131,7 +131,6 @@ class RegionMerger : public ExprVisitor {
} }
merged_regions_.insert(region->GetID()); merged_regions_.insert(region->GetID());
} }
ExprVisitor::VisitExpr_(call);
} }
private: private:
...@@ -140,11 +139,11 @@ class RegionMerger : public ExprVisitor { ...@@ -140,11 +139,11 @@ class RegionMerger : public ExprVisitor {
std::unordered_map<int, std::unordered_set<int>> region_restrictions_; std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
}; };
class MergeAnnotations : public ExprMutator { class MergeAnnotations : public ExprRewriter {
public: public:
explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
Expr VisitExpr_(const CallNode* call) final { Expr Rewrite_(const CallNode* call, const Expr& post) final {
// Merge annotations which are now internal to a region. // Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a // This happens if we see a compiler begin next to a
// compiler end and they're both in the same region. // compiler end and they're both in the same region.
...@@ -154,11 +153,12 @@ class MergeAnnotations : public ExprMutator { ...@@ -154,11 +153,12 @@ class MergeAnnotations : public ExprMutator {
auto region1 = regions_->GetRegion(GetRef<Call>(call)); auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg); auto region2 = regions_->GetRegion(arg);
if (region1 == region2) { if (region1 == region2) {
return VisitExpr(arg->args[0]); auto post_arg = post.as<CallNode>()->args[0];
return post_arg.as<CallNode>()->args[0];
} }
} }
} }
return ExprMutator::VisitExpr_(call); return post;
} }
private: private:
...@@ -175,7 +175,7 @@ Expr MergeCompilerRegions(const Expr& expr) { ...@@ -175,7 +175,7 @@ Expr MergeCompilerRegions(const Expr& expr) {
// Remove annotations that are not in the region boundaries. // Remove annotations that are not in the region boundaries.
MergeAnnotations merge_anno(regions); MergeAnnotations merge_anno(regions);
return merge_anno.Mutate(expr); return PostOrderRewrite(expr, &merge_anno);
} }
} // namespace merge_compiler_region } // namespace merge_compiler_region
......
...@@ -522,8 +522,8 @@ def test_function_lifting(): ...@@ -522,8 +522,8 @@ def test_function_lifting():
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple()) bn.astuple())
func0 = set_func_attr(func0, "test_compiler", "test_compiler_0") func0 = set_func_attr(func0, "test_compiler", "test_compiler_2")
gv0 = relay.GlobalVar("test_compiler_0") gv0 = relay.GlobalVar("test_compiler_2")
mod[gv0] = func0 mod[gv0] = func0
# function for conv2d # function for conv2d
...@@ -536,8 +536,8 @@ def test_function_lifting(): ...@@ -536,8 +536,8 @@ def test_function_lifting():
channels=16, channels=16,
padding=(1, 1)) padding=(1, 1))
func1 = relay.Function([data1, weight1], conv) func1 = relay.Function([data1, weight1], conv)
func1 = set_func_attr(func1, "test_compiler", "test_compiler_1") func1 = set_func_attr(func1, "test_compiler", "test_compiler_0")
gv1 = relay.GlobalVar("test_compiler_1") gv1 = relay.GlobalVar("test_compiler_0")
mod[gv1] = func1 mod[gv1] = func1
# main function # main function
...@@ -630,7 +630,6 @@ def test_constant_propagation(): ...@@ -630,7 +630,6 @@ def test_constant_propagation():
def expected(): def expected():
mod = tvm.IRModule() mod = tvm.IRModule()
x = relay.const(ones)
y = relay.var("y", shape=(8, 8)) y = relay.var("y", shape=(8, 8))
x0 = relay.const(ones) x0 = relay.const(ones)
y0 = relay.var("y0", shape=(8, 8)) y0 = relay.var("y0", shape=(8, 8))
...@@ -712,12 +711,12 @@ def test_multiple_outputs(): ...@@ -712,12 +711,12 @@ def test_multiple_outputs():
mod = tvm.IRModule() mod = tvm.IRModule()
# function 0 # function 0
data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32")) data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32")) weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32")) bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32")) bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32")) bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32")) bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
conv_o = relay.nn.conv2d( conv_o = relay.nn.conv2d(
data=data, data=data,
...@@ -730,12 +729,12 @@ def test_multiple_outputs(): ...@@ -730,12 +729,12 @@ def test_multiple_outputs():
bn_var) bn_var)
relu_o = relay.nn.relu(bn_o[0]) relu_o = relay.nn.relu(bn_o[0])
tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o)) tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
func0 = relay.Function([data, weight, bn_gamma, bn_beta, func0 = relay.Function([data, weight, bn_gamma, bn_beta,
bn_mean, bn_var], tuple_o) bn_mean, bn_var], tuple_o)
func0 = set_func_attr(func0, "test_target", "test_target_2") func0 = set_func_attr(func0, "test_target", "test_target_0")
gv0 = relay.GlobalVar("test_target_2") gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0 mod[gv0] = func0
# body # body
...@@ -747,9 +746,9 @@ def test_multiple_outputs(): ...@@ -747,9 +746,9 @@ def test_multiple_outputs():
bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
f0_relu_o = relay.TupleGetItem(f0_o, 2) f0_relu_o = relay.TupleGetItem(f0_o, 0)
f0_mean_o = relay.TupleGetItem(f0_o, 1) f0_mean_o = relay.TupleGetItem(f0_o, 1)
f0_var_o = relay.TupleGetItem(f0_o, 0) f0_var_o = relay.TupleGetItem(f0_o, 2)
f0_mean_abs = relay.abs(f0_mean_o) f0_mean_abs = relay.abs(f0_mean_o)
f0_var_abs = relay.abs(f0_var_o) f0_var_abs = relay.abs(f0_var_o)
...@@ -791,22 +790,22 @@ def test_mixed_single_multiple_outputs(): ...@@ -791,22 +790,22 @@ def test_mixed_single_multiple_outputs():
mod = tvm.IRModule() mod = tvm.IRModule()
# function 1 # function 1
f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10)) f1_cb1 = relay.var('test_target_0_i0', shape=(10, 10))
f1_O_1 = relay.abs(f1_cb1) f1_O_1 = relay.abs(f1_cb1)
f1_O_2 = relay.nn.relu(f1_O_1) f1_O_2 = relay.nn.relu(f1_O_1)
f1_out = relay.Tuple((f1_O_2, f1_O_1)) f1_out = relay.Tuple((f1_O_2, f1_O_1))
func1 = relay.Function([f1_cb1], f1_out) func1 = relay.Function([f1_cb1], f1_out)
func1 = set_func_attr(func1, "test_target", "test_target_1") func1 = set_func_attr(func1, "test_target", "test_target_0")
gv1 = relay.GlobalVar("test_target_1") gv1 = relay.GlobalVar("test_target_0")
mod[gv1] = func1 mod[gv1] = func1
# function 0 # function 0
f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10)) f2_cb3 = relay.var('test_target_1_i0', shape=(10, 10))
f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10)) f2_cb4 = relay.var('test_target_1_i1', shape=(10, 10))
f2_O_3 = relay.add(f2_cb3, f2_cb4) f2_O_3 = relay.add(f2_cb3, f2_cb4)
func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
func0 = set_func_attr(func0, "test_target", "test_target_0") func0 = set_func_attr(func0, "test_target", "test_target_1")
gv0 = relay.GlobalVar("test_target_0") gv0 = relay.GlobalVar("test_target_1")
mod[gv0] = func0 mod[gv0] = func0
# body # body
...@@ -1109,22 +1108,22 @@ def test_duplicate_merge_and_tuplegetitem(): ...@@ -1109,22 +1108,22 @@ def test_duplicate_merge_and_tuplegetitem():
mod = tvm.IRModule() mod = tvm.IRModule()
# function 0 # function 0
f0_i0 = relay.var(target+"_1_i0", shape=(10, 10)) f0_i0 = relay.var(target + "_0_i0", shape=(10, 10))
f0_i1 = relay.var(target+"_1_i1") f0_i1 = relay.var(target + "_0_i1")
f0_i2 = relay.var(target+"_1_i2") f0_i2 = relay.var(target + "_0_i2")
f0_i3 = relay.var(target+"_1_i3") f0_i3 = relay.var(target + "_0_i3")
f0_i4 = relay.var(target+"_1_i4") f0_i4 = relay.var(target + "_0_i4")
f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4) f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
f0_n1 = f0_n0[1] f0_n1 = f0_n0[1]
f0_n2 = relay.nn.relu(f0_n0[0]) f0_n2 = relay.nn.relu(f0_n0[0])
f0_o0 = relay.Tuple([f0_n1, f0_n2]) f0_o0 = relay.Tuple([f0_n2, f0_n1])
func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0) func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", target) func0 = func0.with_attr("Compiler", target)
func0 = func0.with_attr("global_symbol", target+"_1") func0 = func0.with_attr("global_symbol", target + "_0")
gv0 = relay.GlobalVar(target+"_1") gv0 = relay.GlobalVar(target + "_0")
mod[gv0] = func0 mod[gv0] = func0
# body # body
...@@ -1136,9 +1135,9 @@ def test_duplicate_merge_and_tuplegetitem(): ...@@ -1136,9 +1135,9 @@ def test_duplicate_merge_and_tuplegetitem():
function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar) function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
get_out0 = relay.TupleGetItem(function_out, 0) get_out0 = relay.TupleGetItem(function_out, 0)
get_out1 = relay.TupleGetItem(function_out, 1) get_out1 = relay.TupleGetItem(function_out, 1)
out_2 = relay.tanh(get_out0) out_2 = relay.tanh(get_out1)
out_3 = relay.log(get_out0) out_3 = relay.log(get_out1)
out = relay.Tuple([get_out1, out_2, out_3]) out = relay.Tuple([get_out0, out_2, out_3])
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out) func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
mod["main"] = func mod["main"] = func
return mod return mod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment