Commit ed9fdfb0 by Jon Soifer Committed by Jared Roesch

[Relay] Add new IR pass CombineParallelDense (#3862)

* Refactor to create abstract ParallelOpCombiner

* First draft of CombineParallelDense

* Begin to work on tests

* Test

* Refactor to move out more common code

* Clean up

* Fix

* Remove statics

* fix wording

* Start to add combine_parallel_op_batch

* Resolve PR comments

* Resolve PR comments

* dummy change to retrigger CI

* Change special case from bias_add to add

* Revert special case change

* Ignore units check

* dummy change to retrigger CI

* dummy change to re-trigger CI

* Improve docs

* Update docs

* Update docs
parent df6f54ac
......@@ -46,6 +46,8 @@ tvm.relay.transform
.. autofunction:: tvm.relay.transform.CombineParallelConv2D
.. autofunction:: tvm.relay.transform.CombineParallelDense
.. autofunction:: tvm.relay.transform.AlterOpLayout
.. autofunction:: tvm.relay.transform.Legalize
......@@ -483,6 +483,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
* \brief Combine parallel dense ops into a single batch_matmul if the
* number of branches of this dense operator is not less than
* `min_num_branches`.
* \param min_num_branches The minimum number of branches.
* \return The pass.
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
* \brief Backward fold axis scaling into weights of conv/dense operators.
* \return The pass.
......@@ -138,6 +138,7 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
fallback_device : int, str, or tvm.TVMContext, optional
......@@ -400,6 +401,35 @@ def CombineParallelConv2D(min_num_branches=3):
return _transform.CombineParallelConv2D(min_num_branches)
def CombineParallelDense(min_num_branches=3):
"""Combine multiple dense operators into one. For example:
/ \
dense (2,2) dense (2,2)
| |
elemwise/bcast (2,2) elemwise/bcast (2,2)
Would become:
batch_matmul+elemwise/bcast (2,2,2)
min_num_branches : int
The minimum number of required parallel branches for performing this optimization.
ret: tvm.relay.Pass
The registered pass that combines parallel dense operators.
return _transform.CombineParallelDense(min_num_branches)
def AlterOpLayout():
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
......@@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* Copyright (c) 2019 by Contributors
* \file
* \brief Combine parallel dense ops into a single batch_matmul.
* This pass replaces dense ops that share the same input node and have the same
* weight shape and out_dtype with a single batch matrix multiplication.
* The inputs of the new batch_matmul are the stacks of the original inputs.
* Elemwise and broadcast ops following dense are also combined if possible.
* This prevents launching multiple kernels in networks with multiple
* dense branches, such as BERT.
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op_batch.h"
namespace tvm {
namespace relay {
// Combiner specialised for dense ops: it constructs the ParallelOpBatchCombiner base with
// "nn.dense" as the op to look for and "nn.batch_matmul" as the batched replacement.
// NOTE(review): this listing has lost lines (closing braces, access specifiers); the
// comments below describe only the statements that are visible.
class ParallelDenseCombiner : public ParallelOpBatchCombiner {
explicit ParallelDenseCombiner(uint64_t min_num_branches)
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {
// Two dense calls are reported as combinable when their out_dtype attributes are equal
// and their second arguments (the weights) have equal shape[0] and shape[1].
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
// NOTE(review): the member access between `a->` and `<DenseAttrs>` looks truncated in
// this listing (presumably `attrs.as`) -- confirm against the real file.
const auto* attrs_a = a-><DenseAttrs>();
const auto* attrs_b = b-><DenseAttrs>();
// args[1] is read as a tensor type; the names suggest it is the dense weight.
const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
return eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(weight_a->shape[0], weight_b->shape[0]) &&
eq(weight_a->shape[1], weight_b->shape[1]);
/*! \brief Combine parallel dense if number of branches >= min_num_branches */
// Expression-level entry point: builds a ParallelDenseCombiner with the given threshold
// and returns the result of its Combine() on expr.
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
return ParallelDenseCombiner(min_num_branches).Combine(expr);
namespace transform {
// Pass-level wrapper around the expression-level CombineParallelDense.
// \param min_num_branches threshold forwarded to the combiner; captured by value in the lambda.
// \return the pass object produced by CreateFunctionPass.
Pass CombineParallelDense(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
// Only f is used; m and pc are not referenced in the visible body.
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
// NOTE(review): 4 is presumably the opt level (build_config lists "CombineParallelDense": 4);
// the remaining arguments of this call are not visible in this listing.
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
} // namespace transform
} // namespace relay
} // namespace tvm
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* Copyright (c) 2019 by Contributors
* \file
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
namespace tvm {
namespace relay {
// Stores the op name and the pairwise-compatibility callback; both are read later by Find().
// NOTE(review): no initializer for fis_supported_op_ is visible in this listing even though
// VisitExpr_ calls it -- the line appears to be missing here; confirm against the real file.
BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops)
: op_name_(op_name),
fare_compatible_ops_(fare_compatible_ops) {
// Builds groups of branches: for every root in op_roots_, each child that calls op_name_
// becomes a branch, and compatible branches under the same root share a group.
// NOTE(review): several statements are missing from this listing (the traversal that fills
// op_roots_/children_map_, and the bodies of the if/else below); comments cover what is visible.
std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const Op& op = Op::Get(op_name_);
std::vector<Group> groups;
for (const auto& root : op_roots_) {
// NOTE(review): the initializer of `children` is missing in this listing.
const auto& children =;
// Remember where this root's groups start: the search below only looks at
// [begin + ngroups, end), so groups created for earlier roots are never matched.
size_t ngroups = groups.size();
for (const CallNode* child : children) {
// Only children that call the target op start a branch.
if (!child->op.same_as(op)) continue;
auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
// Compatibility is tested against the first call of the group's first branch.
return fare_compatible_ops_(child, group[0][0]);
if (it != groups.end()) {
} else {
// each group has at least one branch
return groups;
// Create a branch starting from op.
// The walk follows children_map_ for as long as the current tail has exactly one recorded
// child call, and continues only when that child's TOpPattern is <= kBroadcast.
Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always op
Branch branch{op};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
while (it != children_map_.end() && it->second.size() == 1) {
const CallNode* call = it->second[0];
// NOTE(review): Downcast<Op> presumes call->op is an Op -- confirm that callers never
// reach this point with a call whose callee is something else.
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
// NOTE(review): the statement that extends `branch` with `call` is not visible in this
// listing; the lookup below re-reads branch.back(), so presumably one precedes it.
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
return branch;
// Visitor hook for call nodes. It distinguishes calls to op_name_ that pass
// fis_supported_op_ from all other calls, whose arguments are iterated.
// NOTE(review): the bodies of both branches are missing from this listing.
void BranchGroupFinder::VisitExpr_(const CallNode* n) {
const Op& op = Op::Get(op_name_);
if (n->op.same_as(op) && fis_supported_op_(n)) {
} else {
for (size_t i = 0; i < n->args.size(); i++) {
// Records the name of the op to combine and the branch-count threshold that
// Combine() compares each group's size against.
ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
: op_name_(op_name),
min_num_branches_(min_num_branches) {
// Entry point of the combiner. A BranchGroupFinder is built whose two callbacks forward to
// the virtual IsSupportedOp / CanOpsBeCombined, the resulting groups are iterated, and the
// function returns ExprSubst(expr, subst_map_).
// NOTE(review): the Find() call and the loop body are not visible in this listing.
Expr ParallelOpCombiner::Combine(const Expr& expr) {
auto groups = BranchGroupFinder(op_name_,
[&](const CallNode* n) {
return IsSupportedOp(n);
[&](const CallNode* a, const CallNode* b) {
return CanOpsBeCombined(a, b);
for (const Group& group : groups) {
// Groups with fewer branches than min_num_branches_ are below the threshold.
if (group.size() < min_num_branches_) {
// subst_map_ is moved from here, leaving the member in a moved-from state afterwards.
return ExprSubst(expr, std::move(subst_map_));
// Combines the root ops of all branches, then walks the following levels of the branches
// in lockstep, folding each level into the combined call while CheckLevel allows it, and
// finally hands the combined result to UpdateGroupOutput.
void ParallelOpCombiner::CombineBranches(const Group& branches) {
Call combined = MakeCombinedOp(branches);
// The shortest branch bounds how many levels can be examined.
auto it = std::min_element(branches.begin(), branches.end(),
[](const Branch& branch_a,
const Branch& branch_b) {
return branch_a.size() < branch_b.size();
size_t depth = it->size();
// Declared outside the loop because its final value is used after the loop.
size_t i;
// starting from 1 to skip the op
for (i = 1; i < depth; i++) {
// Find which argument of branch 0's level-i call is the previous element of that branch.
size_t parent_index;
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
// The previous element must have been found among the arguments.
CHECK_NE(parent_index, branches[0][i]->args.size());
// Stop at the first level that cannot be combined across all branches.
if (!CheckLevel(branches, i, parent_index)) break;
combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index);
// i - 1 is passed as the depth whose per-branch outputs are to be substituted.
UpdateGroupOutput(combined, branches, i - 1, &subst_map_);
// Returns true when, for every branch other than branch 0, the call at `depth` has the same
// op, equal attrs and the same argument count as branch 0's call, takes its own previous
// element at position parent_index, and every other argument passes IsArgCompatible.
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
// Branch 0 is the reference that all other branches are compared with.
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
if (!branch[depth]->op.same_as(call->op) ||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
branch[depth]->args.size() != call->args.size()) {
return false;
// The branch's own previous call must sit at the same argument position.
if (branch[depth]->args[parent_index].get() != branch[depth - 1])
return false;
// Check args
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) continue;
// NOTE(review): attrs were already compared for this branch above, so the second
// operand below looks redundant -- confirm and consider removing it.
if (!IsArgCompatible(call, branch[depth], i) ||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
return false;
return true;
} // namespace relay
} // namespace tvm
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* Copyright (c) 2019 by Contributors
* \file combine_parallel_op.h
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include "./expr_subst.h"
#include "./pattern_util.h"
namespace tvm {
namespace relay {
// Ordered list of call nodes that make up one branch.
using Branch = std::vector<const CallNode*>;
// Collection of branches that are handled together.
using Group = std::vector<Branch>;
// Predicate over a single call node.
using FIsSupportedOp = std::function<bool (const CallNode* n)>;
// Predicate over a pair of call nodes.
using FAreCompatibleOps = std::function<bool (const CallNode* a, const CallNode* b)>;
// Expr -> Expr map that uses NodeHash / NodeEqual for hashing and key comparison.
using ExprSubstMap = std::unordered_map<Expr, Expr, NodeHash, NodeEqual>;
* Class to find parallel branches starting with op that are
* grouped if they are able to be combined. They are eligible to
* be combined if they have the same input data.
* Op can be followed by zero or more elemwise or broadcast ops,
* which are included in the group.
* Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
* which should be handled in ParallelOpCombiner.
* data
* / \
* op op
* | |
* elem-wise elem-wise
* | |
class BranchGroupFinder : private ExprVisitor {
* \brief Constructor
* \param op_name name of op to start each group
* \param fis_supported_op function that returns true if op
* is supported for combining
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
BranchGroupFinder(const std::string& op_name,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops);
* \brief Finds all groups that can be combined.
* \param expr Relay expression that represents function
* to look at for groups to be combined
* \return Vector of groups which can be combined.
std::vector<Group> Find(const Expr& expr);
/* \brief name of op to find parallel branches for */
std::string op_name_;
/* \brief function to return true if op is eligible to be combined,
* false otherwise
FIsSupportedOp fis_supported_op_;
/* \brief function to return true if two parallel ops are eligible
* to be combined, false otherwise
FAreCompatibleOps fare_compatible_ops_;
/* \brief ops that are on the first (logically, leftmost) branch
* of parallel ops and are eligible to be combined
std::unordered_set<Expr, NodeHash, NodeEqual> op_roots_;
/* \brief map of Expr to CallNodes that follow it */
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;
* \brief Creates new branch from op and its children that have
* elementwise or broadcast patterns
* \return New branch
Branch CreateBranch(const CallNode* op);
* \brief Expression visitor function
void VisitExpr_(const CallNode* n) final;
* Abstract class to find and combine parallel ops and the elementwise ops that follow.
class ParallelOpCombiner {
* \brief Constructor.
* \param op_name name of op to combine
* \param min_num_branches min number of parallel branches beginning with op
* to start combining
explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches);
* \brief Combines ops and following elementwise or broadcast ops
* \param expr function to modify
* \return new function with combined ops
Expr Combine(const Expr& expr);
* \brief Checks if node is supported to be combined
* \param n node in question
* \return True if the op represented by n is supported to be the root of a branch
* to be combined. False otherwise.
virtual bool IsSupportedOp(const CallNode* n) = 0;
* \brief Checks if two ops can be combined
* \param a node a
* \param b node b
* \return True if a and b can be combined. False otherwise.
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0;
* \brief Makes combined op from parallel ops in branches. This usually involves
* concatenating or stacking inputs, then creating a new call.
* \param branches branches that are to be combined
* \return new call with branches combined.
virtual Call MakeCombinedOp(const Group& branches) = 0;
* \brief Checks if argument of op following combined ops are able to be combined
* \param a node a
* \param b node b
* \param index index of argument in question
* \return True if argument of a and b and index can be combined
virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0;
* \brief Create combined call from ops that follow the initial combined op at the depth-th level.
* This usually involves concatenating or stacking inputs, then creating a new call.
* Only called if IsArgCompatible returns true for each arg.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to combine ops
* \param parent_index index of arg that corresponds to original input that was shared among
* all combined ops
* \return new combined call
virtual Call MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) = 0;
* \brief Updates map of expr to substitute with combined expr. This usually involves
* slicing or splitting data.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to substitute
* \param subst_map map of Expr to replace with Expr to replace it with
virtual void UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) = 0;
/* \brief name of op to be combined */
std::string op_name_;
/* \brief minimum number of parallel branches to combine */
uint64_t min_num_branches_;
/* \brief map of Expr to Expr to substitute it with after running pass */
ExprSubstMap subst_map_;
* \brief Combines parallel branches and updates subst_map_ with Exprs
* to be substituted
* \param branches branches to be combined
void CombineBranches(const Group& branches);
* \brief Checks whether the ops at the given depth of every branch
* can be combined with each other
* \param branches parallel branches to potentially be combined
* \param depth depth at which to look at op
* \param parent_index index of arg that corresponds to original input that was shared among
* all combined ops
* \return true if parallel ops at depth can be combined, false otherwise
bool CheckLevel(const Group& branches, size_t depth, size_t parent_index);
} // namespace relay
} // namespace tvm
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* Copyright (c) 2019 by Contributors
* \file
* \brief Combine parallel ops into a single batch op.
* This pass replaces ops that share the same input node and same shape
* with a single op that takes in batched input. The inputs of the new
* batched op are the stack of the original inputs. Elementwise and
* broadcast ops following the original op are also stacked
* and fused if possible. For example:
* data
* / \
* add (2,2) add (2,2)
* | |
* elemwise (2,2) elemwise (2,2)
* | |
* Would become:
* data
* |
* add+elemwise (2,2,2)
* / \
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
#include "./combine_parallel_op_batch.h"
namespace tvm {
namespace relay {
// Forwards op_name and min_num_branches to the ParallelOpCombiner base and stores the
// name of the batched op that MakeCombinedOp looks up.
ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches)
: ParallelOpCombiner(op_name, min_num_branches),
batch_op_name_(batch_op_name) {
// Accepts every node unconditionally; the argument is not inspected.
bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) {
return true;
// Two calls are combinable when they have the same number of arguments and every pair of
// corresponding arguments has equal dtype, equal rank and equal extents in each dimension.
// NOTE(review): type_as<TensorTypeNode>() is applied to every argument, so this presumably
// requires type information to be present and every argument to be a tensor -- confirm.
bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) {
if (a->args.size() != b->args.size()) {
return false;
AttrsEqual eq;
for (size_t i = 0; i < a->args.size(); i++) {
auto ta = a->args[i]->type_as<TensorTypeNode>();
auto tb = b->args[i]->type_as<TensorTypeNode>();
// Rank or dtype mismatch rules the pair out before dimensions are compared.
if (ta->shape.size() != tb->shape.size() || !eq(ta->dtype, tb->dtype)) {
return false;
for (size_t j = 0; j < ta->shape.size(); j++) {
if (!eq(ta->shape[j], tb->shape[j])) {
return false;
return true;
// Builds the call to the batched op. For each argument position of the first branch's root
// call, a MakeStack(..., 0) over a tuple of arg_from_all_branches is appended to new_args.
// NOTE(review): the statement that fills arg_from_all_branches inside the inner loop is
// not visible in this listing.
Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
const Op& batch_op = Op::Get(batch_op_name_);
Array<Expr> new_args;
// The argument count is taken from the root call of branch 0.
size_t num_args = branches[0][0]->args.size();
for (size_t i = 0; i < num_args; i++) {
Array<Expr> arg_from_all_branches;
for (const auto& branch : branches) {
new_args.push_back(MakeStack(TupleNode::make(arg_from_all_branches), 0));
// The new call uses a default-constructed Attrs(), not the attrs of the original calls.
return CallNode::make(batch_op, new_args, Attrs(), {});
// Returns true when args[index] of a and of b have equal dtype, equal rank and equal
// extents in every dimension. This is the same comparison CanOpsBeCombined performs per
// argument, applied to a single argument position.
bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
return false;
for (size_t i = 0; i < ta->shape.size(); i++) {
if (!eq(ta->shape[i], tb->shape[i]))
return false;
return true;
// Builds the combined call for level `depth`. It reuses the op and attrs of branch 0's call
// at that level and assembles new_args by walking that call's argument positions.
// NOTE(review): several statements are missing from this listing (what is pushed for the
// parent_index position, the pushes into `tuple`, and the push of `stack` into new_args),
// so the per-argument flow below cannot be fully confirmed from here.
Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) {
Array<Expr> new_args;
// Branch 0's call at this level supplies the argument count, op and attrs.
const CallNode* call = branches[0][depth];
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
Array<Expr> tuple;
for (const auto& branch : branches) {
// if the shape of the arg is of shape (j,),
// expand it to (1,j) so it can be properly broadcasted.
Expr arg = branch[depth]->args[i];
const TensorTypeNode* arg_tensor = arg->type_as<TensorTypeNode>();
if (arg_tensor->shape.size() == 1) {
Expr expanded_arg = MakeExpandDims(arg, 0, 1);
} else {
// The per-branch arguments collected in `tuple` are stacked along axis 0.
auto stack = MakeStack(TupleNode::make(tuple), 0);
return CallNode::make(call->op, new_args, call->attrs, {});
// Registers one substitution per branch: the call at `depth` in each branch is mapped to
// the matching piece of `data`, obtained by splitting along axis 0 into branches.size()
// sections and squeezing axis 0 of the selected piece.
void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) {
// Position of the current branch within the split result.
int index = 0;
auto split = MakeSplit(data, Integer(branches.size()), 0);
for (const auto& branch : branches) {
auto split_data = TupleGetItemNode::make(split, index++);
auto squeezed_data = MakeSqueeze(split_data, {0});
// insert() on the unordered_map keeps an existing entry if the key is already present.
subst_map->insert({GetRef<Expr>(branch[depth]), squeezed_data});
/*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */
// Expression-level entry point: builds a ParallelOpBatchCombiner from the two op names and
// the threshold, and returns the result of its Combine() on expr.
Expr CombineParallelOpBatch(const Expr& expr,
const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches) {
return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr);
namespace transform {
// Pass-level wrapper around the expression-level CombineParallelOpBatch; the lambda
// captures its enclosing parameters by value.
// NOTE(review): the argument lists of the calls on the last two lines are cut off in this
// listing, so the forwarded arguments cannot be confirmed from here.
Pass CombineParallelOpBatch(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(CombineParallelOpBatch(f,
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
} // namespace transform
} // namespace relay
} // namespace tvm
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* Copyright (c) 2019 by Contributors
* \file
* \brief Combine parallel ops into a single batch op.
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
namespace tvm {
namespace relay {
* Class to find and combine parallel ops and following element-wise
* and broadcast ops into a single batch op. Ops can be combined
* if they have the same input data. Batch op is formed by
* stacking inputs. Final results are retrieved by splitting output.
* For example:
* data
* / \
* dense (2,2) dense (2,2)
* | |
* elemwise/bcast (2,2) elemwise/bcast (2,2)
* Would become:
* data
* |
* batch_matmul+elemwise/bcast (2,2,2)
class ParallelOpBatchCombiner : public ParallelOpCombiner {
* \brief Constructor.
* \param op_name name of op to combine
* \param batch_op_name name of op that combined branches will be joined into
* \param min_num_branches min number of parallel branches beginning with op
* to start combining
ParallelOpBatchCombiner(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches);
* \brief Checks if node is supported to be combined
* \param n node in question
* \return True by default
virtual bool IsSupportedOp(const CallNode* n);
* \brief Checks if two ops can be combined
* \param a node a
* \param b node b
* \return True if shapes and dtypes of all args of a and b are the same
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b);
* \brief Makes combined op from parallel ops in branches. This usually involves
* concatenating or stacking inputs, then creating a new call.
* \param branches branches that are to be combined
* \return new call with branches combined as batch op by stacking args
Call MakeCombinedOp(const Group& branches) final;
* \brief Checks if argument of op following combined ops are able to be combined
* \param a node a
* \param b node b
* \param index index of argument in question
* \return True if shapes and dtypes of args[index] a and b are the same
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final;
* \brief Create combined call from ops that follow the initial combined op at the depth-th level.
* This usually involves concatenating or stacking inputs, then creating a new call.
* Only called if IsArgCompatible returns true for each arg.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to combine ops
* \param parent_index index of arg that corresponds to original input that was shared among
* all combined ops
* \return new combined call as batch op by stacking args
Call MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) final;
* \brief Updates map of expr to substitute with combined expr. This usually involves
* slicing or splitting data.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to substitute
* \param subst_map map of Expr to replace with Expr to replace it with
void UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) final;
/* \brief name of op to replace combined ops with. for example,
* for combining parallel dense, this will be set to
* nn.batch_matmul
std::string batch_op_name_;
} // namespace relay
} // namespace tvm
......@@ -497,6 +497,14 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
Expr MakeStack(Expr data, int axis);
Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis);
Expr MakeSqueeze(Expr data, Array<Integer> axis);
Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
Expr StopFusion(Expr data);
Expr CastHint(Expr data, DataType dtype);
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay import transform
def run_combine_parallel(expr, min_num_branches=3):
    """Wrap ``expr`` in a module, run CombineParallelDense on it, and return "main".

    Parameters
    ----------
    expr : the Relay expression to wrap in a module.
    min_num_branches : int
        Threshold handed to ``transform.CombineParallelDense``.
    """
    module = relay.Module.from_expr(expr)
    combine_pass = transform.CombineParallelDense(min_num_branches)
    return combine_pass(module)["main"]
def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to a module built from ``expr`` and return its "main" entry.

    ``opt_pass`` must be a ``transform.Pass`` instance; anything else trips the assert.
    """
    assert isinstance(opt_pass, transform.Pass)
    module = relay.Module.from_expr(expr)
    optimized = opt_pass(module)
    return optimized["main"]
def test_combine_parallel_dense():
"""Simple testcase. One dense cannot be combined due to shape mismatch"""
def before(x, w1, w2, w3, w4):
args = [x, w1, w2, w3, w4]
y1 = relay.nn.dense(x, w1)
y2 = relay.nn.dense(x, w2)
# y3 cannot be combined
y3 = relay.nn.dense(x, w3)
y4 = relay.nn.dense(x, w4)
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)
def expected(x, w1, w2, w3, w4):
# use a fixed order of args so alpha equal check can pass
args = [x, w1, w2, w3, w4]
x_stacked = relay.stack((x, x, x), axis=0)
w = relay.stack((w1, w2, w4), axis=0)
y = relay.nn.batch_matmul(x_stacked, w)
(y1, y2, y4) = relay.split(y, 3)
y1 = relay.squeeze(y1, [0])
y2 = relay.squeeze(y2, [0])
y4 = relay.squeeze(y4, [0])
# y3 cannot be combined
y3 = relay.nn.dense(x, w3)
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)
def check(i, j, k):
x = relay.var("x", shape=(i, k))
w1 = relay.var("w1", shape=(j, k))
w2 = relay.var("w2", shape=(j, k))
w3 = relay.var("w3", shape=(j + 1, k))
w4 = relay.var("w4", shape=(j, k))
y_before = before(x, w1, w2, w3, w4)
y = run_opt_pass(y_before,
y_expected = expected(x, w1, w2, w3, w4)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
check(3, 5, 4)
check(100, 200, 300)
def test_combine_parallel_dense_biasadd():
"""Testcase of combining dense + biasadd, with both 1d and 2d bias"""
def before(x, w1, w2, b1, b2):
args = [x, w1, w2, b1, b2]
y1 = relay.nn.dense(x, w1)
y2 = relay.nn.dense(x, w2)
y1 = relay.add(y1, b1)
y2 = relay.add(y2, b2)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def expected(x, w1, w2, b1, b2, is_2d_bias):
args = [x, w1, w2, b1, b2]
x_stacked = relay.stack((x, x), axis=0)
w = relay.stack((w1, w2), axis=0)
y = relay.nn.batch_matmul(x_stacked, w)
if not is_2d_bias:
b1 = relay.expand_dims(b1, 0)
b2 = relay.expand_dims(b2, 0)
b = relay.stack((b1, b2), axis=0)
y = relay.add(y, b)
(y1, y2) = relay.split(y, 2)
y1 = relay.squeeze(y1, [0])
y2 = relay.squeeze(y2, [0])
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def check(i, j, k, is_2d_bias):
x = relay.var("x", shape=(i, k))
w1 = relay.var("w1", shape=(j, k))
w2 = relay.var("w2", shape=(j, k))
if is_2d_bias:
b1 = relay.var("b1", shape=(i, j))
b2 = relay.var("b2", shape=(i, j))
b1 = relay.var("b1", shape=(j,))
b2 = relay.var("b2", shape=(j,))
y_before = before(x, w1, w2, b1, b2)
y = run_opt_pass(y_before,
y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
check(3, 5, 4, False)
check(100, 200, 300, False)
check(3, 5, 4, True)
check(100, 200, 300, True)
def test_combine_parallel_dense_biasadd_scale_reshape():
"""Testcase of combining dense + 1d biasadd + multiply with non-fused reshape"""
def before(x, w1, w2, b1, b2, scale1, scale2, newshape):
args = [x, w1, w2, b1, b2, scale1, scale2]
y1 = relay.nn.dense(x, w1)
y2 = relay.nn.dense(x, w2)
y1 = relay.add(y1, b1)
y2 = relay.add(y2, b2)
y1 = relay.multiply(y1, scale1)
y2 = relay.multiply(y2, scale2)
y1 = relay.reshape(y1, newshape=newshape)
y2 = relay.reshape(y2, newshape=newshape)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def expected(x, w1, w2, b1, b2, scale1, scale2, newshape):
args = [x, w1, w2, b1, b2, scale1, scale2]
x_stacked = relay.stack((x, x), axis=0)
w = relay.stack((w1, w2), axis=0)
y = relay.nn.batch_matmul(x_stacked, w)
b1 = relay.expand_dims(b1, 0)
b2 = relay.expand_dims(b2, 0)
b = relay.stack((b1, b2), axis=0)
y = relay.add(y, b)
scale1 = relay.expand_dims(scale1, 0)
scale2 = relay.expand_dims(scale2, 0)
scale = relay.stack((scale1, scale2), axis=0)
y = relay.multiply(y, scale)
(y1, y2) = relay.split(y, 2)
y1 = relay.squeeze(y1, [0])
y2 = relay.squeeze(y2, [0])
y1 = relay.reshape(y1, newshape=newshape)
y2 = relay.reshape(y2, newshape=newshape)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def check(i, j, k, scale1, scale2, newshape):
x = relay.var("x", shape=(i, k))
w1 = relay.var("w1", shape=(j, k))
w2 = relay.var("w2", shape=(j, k))
b1 = relay.var("b1", shape=(j,))
b2 = relay.var("b2", shape=(j,))
scale1 = relay.var("scale1", shape=(1,))
scale2 = relay.var("scale2", shape=(1,))
y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape)
y = run_opt_pass(y_before,
y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
if __name__ == "__main__":
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment