/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file combine_parallel_op.cc
 * \brief Abstract class to combine parallel ops and their successive element-wise ops.
 */

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "expr_subst.h"
#include "pattern_util.h"
#include "combine_parallel_op.h"

namespace tvm {
namespace relay {

/*!
 * \brief Construct a finder for parallel branches rooted at calls to \p op.
 * \param op The operator whose calls start a branch.
 * \param fis_supported_op Predicate deciding whether a call to \p op is treated as a branch root.
 * \param fare_compatible_ops Predicate deciding whether two calls to \p op may share a group.
 */
BranchGroupFinder::BranchGroupFinder(const Op& op,
                                     FIsSupportedOp fis_supported_op,
                                     FAreCompatibleOps fare_compatible_ops)
    : cached_op_(op),
      fis_supported_op_(fis_supported_op),
      fare_compatible_ops_(fare_compatible_ops) {
}

/*!
 * \brief Collect groups of branches that consume the same input expression.
 *
 * The expression is visited first to fill op_roots_ and children_map_. Then, for every
 * recorded root, each consumer that calls cached_op_ is turned into a branch and placed
 * into the first compatible group created for that same root, or into a new group.
 *
 * \param expr The expression to search.
 * \return The groups found; every group holds at least one non-empty branch.
 */
std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
  this->VisitExpr(expr);

  std::vector<Group> groups;
  for (const auto& root : op_roots_) {
    const auto& children = children_map_.at(root);
    // Groups created before this index belong to other roots and are never candidates.
    size_t ngroups = groups.size();
    for (const CallNode* child : children) {
      if (child->op != cached_op_) continue;

      Branch branch = CreateBranch(child);
      // add the branch to a group, or create a new group
      auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
        CHECK(!group.empty() && !group[0].empty());
        // Compatibility is decided against the first op of the group's first branch.
        return fare_compatible_ops_(child, group[0][0]);
      });
      if (it != groups.end()) {
        it->push_back(std::move(branch));
      } else {
        groups.emplace_back();
        // each group has at least one branch
        groups.back().push_back(std::move(branch));
      }
    }
  }
  return groups;
}

/*!
 * \brief Create a branch starting from op.
 *
 * The branch is extended while the current tail has exactly one recorded consumer and that
 * consumer is a call to an operator whose TOpPattern is at most kBroadcast.
 *
 * \param op The call that starts the branch.
 * \return The branch; its first element is always \p op.
 */
Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
  auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
  // each branch has at least one element, the first element is always op
  Branch branch{op};
  auto it = children_map_.find(GetRef<Expr>(branch.back()));
  while (it != children_map_.end() && it->second.size() == 1) {
    const CallNode* call = it->second[0];
    // VisitExpr_ records every consuming call regardless of its callee, so the callee here
    // is not guaranteed to be an operator. Downcast<Op> would fail its check on anything
    // else, so the branch simply ends at such a call.
    if (call->op.as<OpNode>() == nullptr) break;
    auto pattern = fpattern[Downcast<Op>(call->op)];
    if (pattern > kBroadcast) break;
    branch.push_back(call);
    it = children_map_.find(GetRef<Expr>(branch.back()));
  }
  return branch;
}

/*!
 * \brief Record consumer relationships for a call node.
 *
 * Sub-expressions are visited first. A supported call to cached_op_ registers its first
 * argument as a root and itself as a consumer of that argument only. Any other call is
 * registered as a consumer of every one of its arguments.
 *
 * \param n The call node being visited.
 */
void BranchGroupFinder::VisitExpr_(const CallNode* n) {
  ExprVisitor::VisitExpr_(n);
  if (n->op == cached_op_ && fis_supported_op_(n)) {
    op_roots_.insert(n->args[0]);
    children_map_[n->args[0]].push_back(n);
  } else {
    for (size_t i = 0; i < n->args.size(); i++) {
      children_map_[n->args[i]].push_back(n);
    }
  }
}

/*!
 * \brief Construct a combiner for the operator registered under \p op_name.
 * \param op_name Name used to look up the operator via Op::Get.
 * \param min_num_branches Groups with fewer branches than this are left untouched.
 */
ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
    : cached_op_(Op::Get(op_name)),
      min_num_branches_(min_num_branches) {
}

/*!
 * \brief Combine the parallel branches found in \p expr.
 *
 * Groups are discovered with BranchGroupFinder using IsSupportedOp and CanOpsBeCombined.
 * Every group with at least min_num_branches_ branches is passed to CombineBranches, and
 * the accumulated substitutions are then applied with ExprSubst.
 *
 * \param expr The expression to rewrite.
 * \return The result of applying the recorded substitutions to \p expr.
 */
Expr ParallelOpCombiner::Combine(const Expr& expr) {
  auto groups = BranchGroupFinder(
                    cached_op_,
                    [&](const CallNode* n) { return IsSupportedOp(n); },
                    [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); })
                    .Find(expr);
  for (const Group& group : groups) {
    if (group.size() < min_num_branches_) {
      continue;
    }
    CombineBranches(group);
  }
  // subst_map_ is moved from here and is not refilled by this function afterwards.
  return ExprSubst(expr, std::move(subst_map_));
}

/*!
 * \brief Combine one group of branches level by level.
 *
 * The root ops are merged with MakeCombinedOp. The following levels are then merged one at
 * a time, up to the length of the shortest branch, stopping at the first level for which
 * CheckLevel returns false. UpdateGroupOutput is finally called with the last merged depth.
 *
 * \param branches The group to combine; expected to be non-empty.
 */
void ParallelOpCombiner::CombineBranches(const Group& branches) {
  Call combined = MakeCombinedOp(branches);
  auto it = std::min_element(branches.begin(), branches.end(),
                             [](const Branch& branch_a, const Branch& branch_b) {
                               return branch_a.size() < branch_b.size();
                             });
  size_t depth = it->size();
  size_t i;
  // starting from 1 to skip the op
  for (i = 1; i < depth; i++) {
    // Locate which argument of the first branch's call at this level is the previous call.
    size_t parent_index;
    for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
      if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
    }
    CHECK_NE(parent_index, branches[0][i]->args.size());
    if (!CheckLevel(branches, i, parent_index)) break;
    combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index);
  }
  // i is the first level that was not merged, so i - 1 is the last merged one.
  UpdateGroupOutput(combined, branches, i - 1, &subst_map_);
}

/*!
 * \brief Check whether all branches can be combined at the given depth.
 *
 * Every branch after the first is compared against the first: same op, equal attrs, same
 * number of arguments, the previous call of the branch located at \p parent_index, and
 * every other argument accepted by IsArgCompatible.
 *
 * \param branches The group being combined.
 * \param depth The level inside each branch to check; must be at least 1.
 * \param parent_index Index of the argument that holds the previous call of the branch.
 * \return true if every branch matches the first one at this depth.
 */
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
  const CallNode* call = branches[0][depth];
  AttrsEqual attrs_equal;
  // check if all branches in current depth can be combined
  for (auto it = branches.begin() + 1; it != branches.end(); it++) {
    const Branch& branch = *it;
    if (!branch[depth]->op.same_as(call->op) ||
        !attrs_equal(branch[depth]->attrs, call->attrs) ||
        branch[depth]->args.size() != call->args.size()) {
      return false;
    }

    if (branch[depth]->args[parent_index].get() != branch[depth - 1])
      return false;

    // Check args
    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) continue;

      // NOTE(review): attrs were already compared above (with the operands swapped); this
      // re-check looks redundant - confirm AttrsEqual is symmetric before removing it.
      if (!IsArgCompatible(call, branch[depth], i) ||
          !attrs_equal(call->attrs, branch[depth]->attrs)) {
        return false;
      }
    }
  }
  return true;
}

}  // namespace relay
}  // namespace tvm