Commit ed9fdfb0 by Jon Soifer, committed by Jared Roesch

[Relay] Add new IR pass CombineParallelDense (#3862)

* Refactor to create abstract ParallelOpCombiner

* First draft of CombineParallelDense

* Begin to work on tests

* Test

* Refactor to move out more common code

* Clean up

* Fix

* Remove statics

* fix wording

* Start to add combine_parallel_op_batch

* Resolve PR comments

* Resolve PR comments

* dummy change to retrigger CI

* Change special case from bias_add to add

* Revert special case change

* Ignore units check

* dummy change to retrigger CI

* dummy change to re-trigger CI

* Improve docs

* Update docs

* Update docs
parent df6f54ac
...@@ -46,6 +46,8 @@ tvm.relay.transform
.. autofunction:: tvm.relay.transform.CombineParallelConv2D
.. autofunction:: tvm.relay.transform.CombineParallelDense
.. autofunction:: tvm.relay.transform.AlterOpLayout
.. autofunction:: tvm.relay.transform.Legalize
......
...@@ -483,6 +483,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
/*!
* \brief Combine parallel dense ops into a single batch_matmul if the
* number of branches of this dense operator is not less than
* `min_num_branches`.
*
* \param min_num_branches The minimum number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
......
...@@ -138,6 +138,7 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
}
fallback_device : int, str, or tvm.TVMContext, optional
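Because CombineParallelDense is registered at opt level 4, it only runs when the optimization level is at least 4, or when it is explicitly required. A minimal sketch of how it could be enabled during compilation (assuming an already-built Relay module `mod` and its `params`; names are illustrative):

import tvm
from tvm import relay

# Hypothetical setup: `mod` is a Relay module, `params` its weights.
# Either raise opt_level to 4, or keep opt_level 3 and require the pass.
with relay.build_config(opt_level=3, required_pass=["CombineParallelDense"]):
    graph, lib, params = relay.build(mod, target="llvm", params=params)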
...@@ -400,6 +401,35 @@ def CombineParallelConv2D(min_num_branches=3):
return _transform.CombineParallelConv2D(min_num_branches)
def CombineParallelDense(min_num_branches=3):
"""Combine multiple dense operators into one. For example:
data
/ \
dense (2,2) dense (2,2)
| |
elemwise/bcast (2,2) elemwise/bcast (2,2)
Would become:
data
|
batch_matmul+elemwise/bcast (2,2,2)
Parameters
----------
min_num_branches : int
The minimum number of required parallel branches for performing this
optimization.
Returns
-------
ret: tvm.relay.Pass
The registered pass that combines parallel dense operators.
"""
return _transform.CombineParallelDense(min_num_branches)
def AlterOpLayout():
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
......
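A minimal sketch of exercising the new pass from Python, mirroring the unit tests added at the end of this commit (shapes and variable names are illustrative):

import tvm
from tvm import relay

# Two dense branches that share the same input; the pass should rewrite them
# into a single nn.batch_matmul on stacked inputs.
x = relay.var("x", shape=(3, 4))
w1 = relay.var("w1", shape=(5, 4))
w2 = relay.var("w2", shape=(5, 4))
y = relay.Tuple((relay.nn.dense(x, w1), relay.nn.dense(x, w2)))
func = relay.Function([x, w1, w2], y)

mod = relay.Module.from_expr(func)
mod = relay.transform.CombineParallelDense(min_num_branches=2)(mod)
print(mod["main"])  # should contain stack + nn.batch_matmul instead of two dense calls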
...@@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
......
...@@ -18,7 +18,7 @@
*/
/*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2019 by Contributors
*
* \file combine_parallel_conv2d.cc
* \brief Combine parallel 2d convolutions into a single convolution.
...@@ -43,68 +43,25 @@
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
+ #include "./combine_parallel_op.h"
namespace tvm {
namespace relay {
- using Branch = std::vector<const CallNode*>;
- using Group = std::vector<Branch>;
- /*
- Find parallel branches starting with conv2d as shown below and then group branches by kernel
- shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops.
- Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
- which should be handled in ParallelConv2DCombiner.
- data
- / \
- conv2d conv2d
- | |
- op op
- | |
- */
- class BranchGroupFinder : private ExprVisitor {
+ class ParallelConv2DCombiner : public ParallelOpCombiner {
public:
- std::vector<Group> Find(const Expr& expr) {
- static const Op& conv2d = Op::Get("nn.conv2d");
- this->VisitExpr(expr);
- std::vector<Group> groups;
- for (const auto& root : conv_roots_) {
- const auto& children = children_map_.at(root);
- size_t ngroups = groups.size();
- for (const CallNode* child : children) {
- if (!child->op.same_as(conv2d)) continue;
- auto&& branch = CreateBranch(child);
- // add the branch to a group, or create a new group
- auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
- CHECK(!group.empty() && !group[0].empty());
- return IsCompatibleConv2D(child, group[0][0]);
- });
- if (it != groups.end()) {
- it->push_back(branch);
- } else {
- groups.emplace_back();
- // each group has at least one branch
- groups.back().push_back(branch);
- }
- }
- }
- return groups;
+ explicit ParallelConv2DCombiner(uint64_t min_num_branches)
+ : ParallelOpCombiner("nn.conv2d", min_num_branches) {
}
- private:
- std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_;
- std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;
- // Two 2d convolutions can be combined if they have the same attributes or
- // only have different output channels.
- bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
+ protected:
+ bool IsSupportedOp(const CallNode* n) {
+ return n->attrs.as<Conv2DAttrs>()->groups == 1;
+ }
+ bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
- static const Layout kOIHW("OIHW");
+ const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
CHECK(attrs_a);
...@@ -125,76 +82,8 @@ class BranchGroupFinder : private ExprVisitor {
eq(shape_a[3], shape_b[3]);
}
- // Create a branch starting from conv2d.
- Branch CreateBranch(const CallNode* conv) {
- static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
- // each branch has at least one element, the first element is always conv2d
- Branch branch{conv};
- auto it = children_map_.find(GetRef<Expr>(branch.back()));
- while (it != children_map_.end() && it->second.size() == 1) {
- const CallNode* call = it->second[0];
- auto pattern = fpattern[Downcast<Op>(call->op)];
- if (pattern <= kBroadcast) {
- branch.push_back(call);
- it = children_map_.find(GetRef<Expr>(branch.back()));
- } else {
- break;
- }
- }
- return branch;
- }
- void VisitExpr_(const CallNode* n) final {
- static const Op& conv2d = Op::Get("nn.conv2d");
- ExprVisitor::VisitExpr_(n);
- if (n->op.same_as(conv2d) && n->attrs.as<Conv2DAttrs>()->groups == 1) {
- conv_roots_.insert(n->args[0]);
- children_map_[n->args[0]].push_back(n);
- } else {
- for (size_t i = 0; i < n->args.size(); i++) {
- children_map_[n->args[i]].push_back(n);
- }
- }
- }
- };
- class ParallelConv2DCombiner {
- public:
- explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
- }
- Expr Combine(const Expr& expr) {
- auto groups = BranchGroupFinder().Find(expr);
- for (const Group& group : groups) {
- if (group.size() < min_num_branches_) {
- continue;
- }
- CombineBranches(group);
- }
- return ExprSubst(expr, std::move(subst_map_));
- }
- private:
- std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
- uint64_t min_num_branches_;
- std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
- int64_t num_filters = 0; // number of filters of the transformed weight
- Array<Expr> weights;
- for (const auto& branch : branches) {
- auto conv2d = branch[0];
- weights.push_back(conv2d->args[1]);
- auto channels = GetConv2DSuperChannelsDim(conv2d);
- num_filters += channels;
- }
- auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
- CHECK_NE(index, std::string::npos);
- return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
- MakeConstScalar(Int(32), num_filters));
- }
- Call MakeCombinedConv2D(const Group& branches) {
- static const Op& conv2d = Op::Get("nn.conv2d");
+ Call MakeCombinedOp(const Group& branches) {
+ const Op& conv2d = Op::Get("nn.conv2d");
Expr data = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_channels;
...@@ -215,10 +104,15 @@ class ParallelConv2DCombiner {
new_attrs->out_dtype = attrs->out_dtype;
new_attrs->channels = new_channels;
+ const std::string& layout =
+ new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout;
+ channel_pos_ = layout.find('C');
+ CHECK_NE(channel_pos_, std::string::npos);
return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
}
- bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) {
+ bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
...@@ -229,12 +123,12 @@
return false;
// Position of the 'C' dimension in the argument
- size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();
+ size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size();
// Channel super-dimension should be present and not broadcasted
- if ((arg_channel_pos > channel_pos) || // size_t overflow
- !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
- !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
+ if ((arg_channel_pos > channel_pos_) || // size_t overflow
+ !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) ||
+ !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_]))
return false;
for (size_t i = 0; i < ta->shape.size(); i++) {
...@@ -245,38 +139,10 @@
return true;
}
- // Check if ops in depth-th level can be combined
- bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) {
- const CallNode* call = branches[0][depth];
- AttrsEqual attrs_equal;
- // check if all branches in current depth can be combined
- for (auto it = branches.begin() + 1; it != branches.end(); it++) {
- const Branch& branch = *it;
- if (!branch[depth]->op.same_as(call->op) ||
- !attrs_equal(branch[depth]->attrs, call->attrs) ||
- branch[depth]->args.size() != call->args.size()) {
- return false;
- }
- if (branch[depth]->args[parent_index].get() != branch[depth - 1])
- return false;
- // Check args
- for (size_t i = 0; i < call->args.size(); i++) {
- if (i == parent_index) continue;
- if (!IsArgCompatible(call, branch[depth], i, channel_pos) ||
- !attrs_equal(call->attrs, branch[depth]->attrs)) {
- return false;
- }
- }
- }
- return true;
- }
- // Combine args and make the combined CallNode
- Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos,
- size_t parent_index) {
+ Call MakeCombinedCallFromFollowingOps(const Expr& data,
+ const Group& branches,
+ size_t depth,
+ size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
size_t ndim = call->type_as<TensorTypeNode>()->shape.size();
...@@ -286,28 +152,32 @@ class ParallelConv2DCombiner {
new_args.push_back(data);
continue;
}
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
- size_t arg_channel_pos = channel_pos - ndim + arg_ndim;
+ size_t arg_channel_pos = channel_pos_ - ndim + arg_ndim;
Array<Expr> tuple;
for (const auto& branch : branches) {
tuple.push_back(branch[depth]->args[i]);
}
auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
new_args.push_back(std::move(concat));
}
return CallNode::make(call->op, new_args, call->attrs, {});
}
- // Replace output of each branch with slices of the combined output
- void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
- size_t channel_pos) {
+ void UpdateGroupOutput(const Expr& data,
+ const Group& branches,
+ size_t depth,
+ ExprSubstMap* subst_map) {
int64_t index = 0;
for (const auto& branch : branches) {
const CallNode* conv2d = branch[0];
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
Array<Integer> begin;
Array<Integer> end;
- for (size_t i = 0; i < channel_pos; i++) {
+ for (size_t i = 0; i < channel_pos_; i++) {
begin.push_back(0);
end.push_back(NullValue<Integer>());
}
...@@ -315,38 +185,27 @@ class ParallelConv2DCombiner {
index += channels;
end.push_back(index);
auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
- subst_map_[GetRef<Expr>(branch[depth])] = slice;
+ subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}
- // Combine branches in a group. Conv2d in different branches in the same group are safe to
- // combine. Subsequent ops may or may not be combined. We start from conv2d and try to
- // combine ops from all branches in the same depth.
- void CombineBranches(const Group& branches) {
- Call combined = MakeCombinedConv2D(branches);
- auto conv_param = combined->attrs.as<Conv2DAttrs>();
- const std::string& layout =
- conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout;
- size_t channel_pos = layout.find('C');
- CHECK_NE(channel_pos, std::string::npos);
- auto it = std::min_element(branches.begin(), branches.end(),
- [](const Branch& branch_a,
- const Branch& branch_b) {
- return branch_a.size() < branch_b.size();
- });
- size_t depth = it->size();
- size_t i;
- // starting from 1 to skip the conv2d
- for (i = 1; i < depth; i++) {
- size_t parent_index;
- for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
- if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
- }
- CHECK_NE(parent_index, branches[0][i]->args.size());
- if (!CheckLevel(branches, i, channel_pos, parent_index)) break;
- combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index);
+ private:
+ /* \brief index of channel dimension */
+ size_t channel_pos_;
+ std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
+ int64_t num_filters = 0; // number of filters of the transformed weight
+ Array<Expr> weights;
+ for (const auto& branch : branches) {
+ auto conv2d = branch[0];
+ weights.push_back(conv2d->args[1]);
+ auto channels = GetConv2DSuperChannelsDim(conv2d);
+ num_filters += channels;
}
- UpdateGroupOutput(combined, branches, i - 1, channel_pos);
+ auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
+ CHECK_NE(index, std::string::npos);
+ return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
+ MakeConstScalar(Int(32), num_filters));
}
};
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file combine_parallel_dense.cc
* \brief Combine parallel dense ops into a single dense.
*
* This pass replaces dense ops that share the same input node and the same shape,
* and that don't have "units" defined, with a single batch matrix multiplication.
* The inputs of the new batch_matmul are the stacked versions of the original inputs.
* Elemwise and broadcast ops following dense are also combined if possible.
*
* This prevents launching multiple kernels in networks with multiple
* dense branches, such as BERT.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op_batch.h"
namespace tvm {
namespace relay {
class ParallelDenseCombiner : public ParallelOpBatchCombiner {
public:
explicit ParallelDenseCombiner(uint64_t min_num_branches)
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {
}
protected:
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
CHECK(attrs_a);
CHECK(attrs_b);
const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
return eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(weight_a->shape[0], weight_b->shape[0]) &&
eq(weight_a->shape[1], weight_b->shape[1]);
}
};
/*! \brief Combine parallel dense if number of branches >= min_num_branches */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
return ParallelDenseCombiner(min_num_branches).Combine(expr);
}
namespace transform {
Pass CombineParallelDense(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
{ir::StringImm::make("InferType")});
}
TVM_REGISTER_API("relay._transform.CombineParallelDense")
.set_body_typed(CombineParallelDense);
} // namespace transform
} // namespace relay
} // namespace tvm
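As a sanity check on the rewrite this file performs, the following NumPy sketch (illustrative shapes only, not TVM code) shows why stacking the shared data and the per-branch weights and running a single batched matmul reproduces each branch's dense result; nn.batch_matmul computes x[b] @ w[b]^T per batch, matching nn.dense's x @ w^T:

import numpy as np

i, j, k = 3, 5, 4
x = np.random.rand(i, k).astype("float32")
w1 = np.random.rand(j, k).astype("float32")
w2 = np.random.rand(j, k).astype("float32")

# nn.dense(x, w) computes x @ w.T -> (i, j)
y1 = x @ w1.T
y2 = x @ w2.T

# Combined form: stack the (shared) data and the weights along a new batch
# axis, then one batched matmul; split along the batch axis to recover results.
xs = np.stack([x, x])              # (2, i, k)
ws = np.stack([w1, w2])            # (2, j, k)
ys = xs @ ws.transpose(0, 2, 1)    # (2, i, j), same as nn.batch_matmul(xs, ws)

assert np.allclose(ys[0], y1) and np.allclose(ys[1], y2)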
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file combine_parallel_op.cc
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
namespace tvm {
namespace relay {
BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops)
: op_name_(op_name),
fis_supported_op_(fis_supported_op),
fare_compatible_ops_(fare_compatible_ops) {
}
std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const Op& op = Op::Get(op_name_);
this->VisitExpr(expr);
std::vector<Group> groups;
for (const auto& root : op_roots_) {
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(op)) continue;
auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return fare_compatible_ops_(child, group[0][0]);
});
if (it != groups.end()) {
it->push_back(branch);
} else {
groups.emplace_back();
// each group has at least one branch
groups.back().push_back(branch);
}
}
}
return groups;
}
// Create a branch starting from op.
Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always op
Branch branch{op};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
while (it != children_map_.end() && it->second.size() == 1) {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
branch.push_back(call);
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
break;
}
}
return branch;
}
void BranchGroupFinder::VisitExpr_(const CallNode* n) {
const Op& op = Op::Get(op_name_);
ExprVisitor::VisitExpr_(n);
if (n->op.same_as(op) && fis_supported_op_(n)) {
op_roots_.insert(n->args[0]);
children_map_[n->args[0]].push_back(n);
} else {
for (size_t i = 0; i < n->args.size(); i++) {
children_map_[n->args[i]].push_back(n);
}
}
}
ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
: op_name_(op_name),
min_num_branches_(min_num_branches) {
}
Expr ParallelOpCombiner::Combine(const Expr& expr) {
auto groups = BranchGroupFinder(op_name_,
[&](const CallNode* n) {
return IsSupportedOp(n);
},
[&](const CallNode* a, const CallNode* b) {
return CanOpsBeCombined(a, b);
}).Find(expr);
for (const Group& group : groups) {
if (group.size() < min_num_branches_) {
continue;
}
CombineBranches(group);
}
return ExprSubst(expr, std::move(subst_map_));
}
void ParallelOpCombiner::CombineBranches(const Group& branches) {
Call combined = MakeCombinedOp(branches);
auto it = std::min_element(branches.begin(), branches.end(),
[](const Branch& branch_a,
const Branch& branch_b) {
return branch_a.size() < branch_b.size();
});
size_t depth = it->size();
size_t i;
// starting from 1 to skip the op
for (i = 1; i < depth; i++) {
size_t parent_index;
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
}
CHECK_NE(parent_index, branches[0][i]->args.size());
if (!CheckLevel(branches, i, parent_index)) break;
combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index);
}
UpdateGroupOutput(combined, branches, i - 1, &subst_map_);
}
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
if (!branch[depth]->op.same_as(call->op) ||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
branch[depth]->args.size() != call->args.size()) {
return false;
}
if (branch[depth]->args[parent_index].get() != branch[depth - 1])
return false;
// Check args
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) continue;
if (!IsArgCompatible(call, branch[depth], i) ||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
return false;
}
}
}
return true;
}
} // namespace relay
} // namespace tvm
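The control flow in BranchGroupFinder::Find is compact; the following toy Python sketch (purely illustrative, not the TVM implementation) restates the grouping step: among the ops that consume a shared input, those matching the target op are grouped by a pairwise compatibility predicate, and only groups with enough branches are later combined:

def group_branches(children, is_target_op, are_compatible):
    """Toy restatement of BranchGroupFinder::Find: partition the ops that
    consume a shared input into groups of mutually compatible branches."""
    groups = []
    for child in children:
        if not is_target_op(child):
            continue
        # Find an existing group whose representative is compatible.
        for group in groups:
            if are_compatible(child, group[0]):
                group.append(child)
                break
        else:
            groups.append([child])  # start a new group
    return groups

# Example: pretend each op is just a (name, weight_shape) tuple.
children = [("dense", (5, 4)), ("dense", (5, 4)), ("dense", (6, 4)), ("relu", None)]
groups = group_branches(
    children,
    is_target_op=lambda c: c[0] == "dense",
    are_compatible=lambda a, b: a[1] == b[1],
)
# -> [[("dense", (5, 4)), ("dense", (5, 4))], [("dense", (6, 4))]]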
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file combine_parallel_op.h
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/
#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_
#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include "./expr_subst.h"
#include "./pattern_util.h"
namespace tvm {
namespace relay {
using Branch = std::vector<const CallNode*>;
using Group = std::vector<Branch>;
using FIsSupportedOp = std::function<bool (const CallNode* n)>;
using FAreCompatibleOps = std::function<bool (const CallNode* a, const CallNode* b)>;
using ExprSubstMap = std::unordered_map<Expr, Expr, NodeHash, NodeEqual>;
/*
* Class to find parallel branches that start with the given op and to
* group those that can be combined. Branches are eligible to be
* combined if they share the same input data.
* Op can be followed by zero or more elemwise or broadcast ops,
* which are included in the group.
* Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
* which should be handled in ParallelOpCombiner.
*
* data
* / \
* op op
* | |
* elem-wise elem-wise
* | |
*/
class BranchGroupFinder : private ExprVisitor {
public:
/*
* \brief Constructor
* \param op_name name of op to start each group
* \param fis_supported_op function that returns true if op
* is supported for combining
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
*/
BranchGroupFinder(const std::string& op_name,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops);
/*
* \brief Finds all groups that can be combined.
* \param expr Relay expression that represents function
* to look at for groups to be combined
* \return Vector of groups which can be combined.
*/
std::vector<Group> Find(const Expr& expr);
private:
/* \brief name of op to find parallel branches for */
std::string op_name_;
/* \brief function to return true if op is eligible to be combined,
* false otherwise
*/
FIsSupportedOp fis_supported_op_;
/* \brief function to return true if two parallel ops are eligible
* to be combined, false otherwise
*/
FAreCompatibleOps fare_compatible_ops_;
/* \brief ops that are on the first (logically, leftmost) branch
* of parallel ops and are eligible to be combined
*/
std::unordered_set<Expr, NodeHash, NodeEqual> op_roots_;
/* \brief map of Expr to CallNodes that follow it */
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;
/*
* \brief Creates new branch from op and its children that have
* elementwise or broadcast patterns
* \return New branch
*/
Branch CreateBranch(const CallNode* op);
/*
* \brief Expression visitor function
*/
void VisitExpr_(const CallNode* n) final;
};
/*
* Abstract class to find and combine parallel ops and the elementwise ops that follow.
*/
class ParallelOpCombiner {
public:
/*
* \brief Constructor.
* \param op_name name of op to combine
* \param min_num_branches min number of parallel branches beginning with op
* to start combining
*/
explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches);
/*
* \brief Combines ops and following elementwise or broadcast ops
* \param expr function to modify
* \return new function with combined ops
*/
Expr Combine(const Expr& expr);
protected:
/*
* \brief Checks if node is supported to be combined
* \param n node in question
* \return True if the op represented by n is supported to be the root of a branch
* to be combined. False otherwise.
*/
virtual bool IsSupportedOp(const CallNode* n) = 0;
/*
* \brief Checks if two ops can be combined
* \param a node a
* \param b node b
* \return True if a and b can be combined. False otherwise.
*/
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0;
/*
* \brief Makes combined op from parallel ops in branches. This usually involves
* concatenating or stacking inputs, then creating a new call.
* \param branches branches that are to be combined
* \return new call with branches combined.
*/
virtual Call MakeCombinedOp(const Group& branches) = 0;
/*
* \brief Checks whether the arguments of the ops that follow the combined ops can be combined
* \param a node a
* \param b node b
* \param index index of argument in question
* \return True if the arguments of a and b at the given index can be combined
*/
virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0;
/*
* \brief Create combined call from ops that follow the initial combined op at the depth-th level.
* This usually involves concatenating or stacking inputs, then creating a new call.
* Only called if IsArgCompatible returns true for each arg.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to combine ops
* \param parent_index index of arg that corresponds to original input that was shared among
* all combined ops
* \return new combined call
*/
virtual Call MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) = 0;
/*
* \brief Updates map of expr to substitute with combined expr. This usually involves
* slicing or splitting data.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to substitute
* \param subst_map map from each original Expr to the Expr that replaces it
*/
virtual void UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) = 0;
private:
/* \brief name of op to be combined */
std::string op_name_;
/* \brief minimum number of parallel branches to combine */
uint64_t min_num_branches_;
/* \brief map of Expr to Expr to substitute it with after running pass */
ExprSubstMap subst_map_;
/*
* \brief Combine parallel branches and updates subst_map_ with Exprs
* to be substituted
* \param branches branches to be combined
*/
void CombineBranches(const Group& branches);
/*
* \brief Checks whether the ops at the given depth can be combined
* across all parallel branches
* \param branches parallel branches to potentially be combined
* \param depth depth at which to look at op
* \param parent_index index of arg that corresponds to original input that was shared among
* all combined ops
* \return true if parallel ops at depth can be combined, false otherwise
*/
bool CheckLevel(const Group& branches, size_t depth, size_t parent_index);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file combine_parallel_op_batch.cc
* \brief Combine parallel ops into a single batch op.
*
* This pass replaces ops that share the same input node and same shape
* with a single op that takes in batched input. The inputs of the new
* batched op are the stack of the original inputs. Elementwise and
* broadcast ops following the original op are also stacked
* and fused if possible. For example:
*
* data
* / \
* add (2,2) add (2,2)
* | |
* elemwise (2,2) elemwise (2,2)
* | |
*
* Would become:
*
* data
* |
* add+elemwise (2,2,2)
* / \
*
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
#include "./combine_parallel_op_batch.h"
namespace tvm {
namespace relay {
ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches)
: ParallelOpCombiner(op_name, min_num_branches),
batch_op_name_(batch_op_name) {
}
bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) {
return true;
}
bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) {
if (a->args.size() != b->args.size()) {
return false;
}
AttrsEqual eq;
for (size_t i = 0; i < a->args.size(); i++) {
auto ta = a->args[i]->type_as<TensorTypeNode>();
auto tb = b->args[i]->type_as<TensorTypeNode>();
if (ta->shape.size() != tb->shape.size() || !eq(ta->dtype, tb->dtype)) {
return false;
}
for (size_t j = 0; j < ta->shape.size(); j++) {
if (!eq(ta->shape[j], tb->shape[j])) {
return false;
}
}
}
return true;
}
Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
const Op& batch_op = Op::Get(batch_op_name_);
Array<Expr> new_args;
size_t num_args = branches[0][0]->args.size();
for (size_t i = 0; i < num_args; i++) {
Array<Expr> arg_from_all_branches;
for (const auto& branch : branches) {
arg_from_all_branches.push_back(branch[0]->args[i]);
}
new_args.push_back(MakeStack(TupleNode::make(arg_from_all_branches), 0));
}
return CallNode::make(batch_op, new_args, Attrs(), {});
}
bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
return false;
for (size_t i = 0; i < ta->shape.size(); i++) {
if (!eq(ta->shape[i], tb->shape[i]))
return false;
}
return true;
}
Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
new_args.push_back(data);
continue;
}
Array<Expr> tuple;
for (const auto& branch : branches) {
// if the shape of the arg is of shape (j,),
// expand it to (1,j) so it can be properly broadcasted.
Expr arg = branch[depth]->args[i];
const TensorTypeNode* arg_tensor = arg->type_as<TensorTypeNode>();
if (arg_tensor->shape.size() == 1) {
Expr expanded_arg = MakeExpandDims(arg, 0, 1);
tuple.push_back(expanded_arg);
} else {
tuple.push_back(arg);
}
}
auto stack = MakeStack(TupleNode::make(tuple), 0);
new_args.push_back(std::move(stack));
}
return CallNode::make(call->op, new_args, call->attrs, {});
}
void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) {
int index = 0;
auto split = MakeSplit(data, Integer(branches.size()), 0);
for (const auto& branch : branches) {
auto split_data = TupleGetItemNode::make(split, index++);
auto squeezed_data = MakeSqueeze(split_data, {0});
subst_map->insert({GetRef<Expr>(branch[depth]), squeezed_data});
}
}
/*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */
Expr CombineParallelOpBatch(const Expr& expr,
const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches) {
return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr);
}
namespace transform {
Pass CombineParallelOpBatch(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(CombineParallelOpBatch(f,
op_name,
batch_op_name,
min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
{ir::StringImm::make("InferType")});
}
TVM_REGISTER_API("relay._transform.CombineParallelOpBatch")
.set_body_typed(CombineParallelOpBatch);
} // namespace transform
} // namespace relay
} // namespace tvm
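Two details of this file are easy to miss: 1-D arguments of follower ops are expanded to 2-D before stacking so that they broadcast against the batched output, and per-branch results are recovered with split followed by squeeze. A small NumPy sketch of both steps (illustrative shapes, not TVM code):

import numpy as np

b, i, j = 2, 3, 5                      # branches, rows, output features
batched = np.random.rand(b, i, j)      # output of the combined batch op

# A per-branch 1-D bias of shape (j,) is expanded to (1, j) before stacking,
# so the stacked bias has shape (b, 1, j) and broadcasts against (b, i, j).
bias1 = np.random.rand(j)
bias2 = np.random.rand(j)
stacked_bias = np.stack([bias1[None, :], bias2[None, :]])   # (b, 1, j)
combined = batched + stacked_bias

# UpdateGroupOutput: split along the batch axis and squeeze it away to get
# each branch's original-rank result back.
parts = np.split(combined, b, axis=0)
outs = [p.squeeze(0) for p in parts]   # each (i, j)
assert np.allclose(outs[0], batched[0] + bias1)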
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file combine_parallel_op_batch.h
* \brief Combine parallel ops into a single batch op.
*/
#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_
#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
namespace tvm {
namespace relay {
/*
* Class to find and combine parallel ops and following element-wise
* and broadcast ops into a single batch op. Ops can be combined
* if they have the same input data. Batch op is formed by
* stacking inputs. Final results are retrieved by splitting output.
* For example:
*
* data
* / \
* dense (2,2) dense (2,2)
* | |
* elemwise/bcast (2,2) elemwise/bcast (2,2)
*
* Would become:
*
* data
* |
* batch_matmul+elemwise/bcast (2,2,2)
*/
class ParallelOpBatchCombiner : public ParallelOpCombiner {
public:
/*
* \brief Constructor.
* \param op_name name of op to combine
* \param batch_op_name name of op that combined branches will be joined into
* \param min_num_branches min number of parallel branches beginning with op
* to start combining
*/
ParallelOpBatchCombiner(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches);
protected:
/*
* \brief Checks if node is supported to be combined
* \param n node in question
* \return True by default
*/
virtual bool IsSupportedOp(const CallNode* n);
/*
* \brief Checks if two ops can be combined
* \param a node a
* \param b node b
* \return True if shapes and dtypes of all args of a and b are the same
*/
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b);
/*
* \brief Makes combined op from parallel ops in branches. This usually involves
* concatenating or stacking inputs, then creating a new call.
* \param branches branches that are to be combined
* \return new call with branches combined as batch op by stacking args
*/
Call MakeCombinedOp(const Group& branches) final;
/*
* \brief Checks whether the arguments of the ops that follow the combined ops can be combined
* \param a node a
* \param b node b
* \param index index of argument in question
* \return True if the shapes and dtypes of args[index] of a and b are the same
*/
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final;
/*
* \brief Create combined call from ops that follow the initial combined op at the depth-th level.
* This usually involves concatenating or stacking inputs, then creating a new call.
* Only called if IsArgCompatible returns true for each arg.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to combine ops
* \param parent_index index of arg that corresponds to original input that was shared among
* all combined ops
* \return new combined call as batch op by stacking args
*/
Call MakeCombinedCallFromFollowingOps(const Expr& data,
const Group& branches,
size_t depth,
size_t parent_index) final;
/*
* \brief Updates map of expr to substitute with combined expr. This usually involves
* slicing or splitting data.
* \param data combined op
* \param branches branches of parallel ops to be combined
* \param depth depth at which to substitute
* \param subst_map map from each original Expr to the Expr that replaces it
*/
void UpdateGroupOutput(const Expr& data,
const Group& branches,
size_t depth,
ExprSubstMap* subst_map) final;
private:
/* \brief name of op to replace combined ops with. For example,
* for combining parallel dense, this will be set to
* nn.batch_matmul
*/
std::string batch_op_name_;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_
...@@ -497,6 +497,14 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
Expr MakeStack(Expr data, int axis);
Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis);
Expr MakeSqueeze(Expr data, Array<Integer> axis);
Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
Expr StopFusion(Expr data);
Expr CastHint(Expr data, DataType dtype);
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay import transform
def run_combine_parallel(expr, min_num_branches=3):
mod = relay.Module.from_expr(expr)
mod = transform.CombineParallelDense(min_num_branches)(mod)
return mod["main"]
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
def test_combine_parallel_dense():
"""Simple testcase. One dense cannot be combined due to shape mismatch"""
def before(x, w1, w2, w3, w4):
args = [x, w1, w2, w3, w4]
y1 = relay.nn.dense(x, w1)
y2 = relay.nn.dense(x, w2)
# y3 cannot be combined
y3 = relay.nn.dense(x, w3)
y4 = relay.nn.dense(x, w4)
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)
def expected(x, w1, w2, w3, w4):
# use a fixed order of args so alpha equal check can pass
args = [x, w1, w2, w3, w4]
x_stacked = relay.stack((x, x, x), axis=0)
w = relay.stack((w1, w2, w4), axis=0)
y = relay.nn.batch_matmul(x_stacked, w)
(y1, y2, y4) = relay.split(y, 3)
y1 = relay.squeeze(y1, [0])
y2 = relay.squeeze(y2, [0])
y4 = relay.squeeze(y4, [0])
# y3 cannot be combined
y3 = relay.nn.dense(x, w3)
y = relay.Tuple((y1, y2, y3, y4))
return relay.Function(args, y)
def check(i, j, k):
x = relay.var("x", shape=(i, k))
w1 = relay.var("w1", shape=(j, k))
w2 = relay.var("w2", shape=(j, k))
w3 = relay.var("w3", shape=(j + 1, k))
w4 = relay.var("w4", shape=(j, k))
y_before = before(x, w1, w2, w3, w4)
y = run_opt_pass(y_before,
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, w3, w4)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
check(3, 5, 4)
check(100, 200, 300)
def test_combine_parallel_dense_biasadd():
"""Testcase of combining dense + 1d biasadd"""
def before(x, w1, w2, b1, b2):
args = [x, w1, w2, b1, b2]
y1 = relay.nn.dense(x, w1)
y2 = relay.nn.dense(x, w2)
y1 = relay.add(y1, b1)
y2 = relay.add(y2, b2)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def expected(x, w1, w2, b1, b2, is_2d_bias):
args = [x, w1, w2, b1, b2]
x_stacked = relay.stack((x, x), axis=0)
w = relay.stack((w1, w2), axis=0)
y = relay.nn.batch_matmul(x_stacked, w)
if not is_2d_bias:
b1 = relay.expand_dims(b1, 0)
b2 = relay.expand_dims(b2, 0)
b = relay.stack((b1, b2), axis=0)
y = relay.add(y, b)
(y1, y2) = relay.split(y, 2)
y1 = relay.squeeze(y1, [0])
y2 = relay.squeeze(y2, [0])
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def check(i, j, k, is_2d_bias):
x = relay.var("x", shape=(i, k))
w1 = relay.var("w1", shape=(j, k))
w2 = relay.var("w2", shape=(j, k))
if is_2d_bias:
b1 = relay.var("b1", shape=(i, j))
b2 = relay.var("b2", shape=(i, j))
else:
b1 = relay.var("b1", shape=(j,))
b2 = relay.var("b2", shape=(j,))
y_before = before(x, w1, w2, b1, b2)
y = run_opt_pass(y_before,
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
check(3, 5, 4, False)
check(100, 200, 300, False)
check(3, 5, 4, True)
check(100, 200, 300, True)
def test_combine_parallel_dense_biasadd_scale_reshape():
"""Testcase of combining dense + 1d biasadd + multiply with non-fused reshape"""
def before(x, w1, w2, b1, b2, scale1, scale2, newshape):
args = [x, w1, w2, b1, b2, scale1, scale2]
y1 = relay.nn.dense(x, w1)
y2 = relay.nn.dense(x, w2)
y1 = relay.add(y1, b1)
y2 = relay.add(y2, b2)
y1 = relay.multiply(y1, scale1)
y2 = relay.multiply(y2, scale2)
y1 = relay.reshape(y1, newshape=newshape)
y2 = relay.reshape(y2, newshape=newshape)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def expected(x, w1, w2, b1, b2, scale1, scale2, newshape):
args = [x, w1, w2, b1, b2, scale1, scale2]
x_stacked = relay.stack((x, x), axis=0)
w = relay.stack((w1, w2), axis=0)
y = relay.nn.batch_matmul(x_stacked, w)
b1 = relay.expand_dims(b1, 0)
b2 = relay.expand_dims(b2, 0)
b = relay.stack((b1, b2), axis=0)
y = relay.add(y, b)
scale1 = relay.expand_dims(scale1, 0)
scale2 = relay.expand_dims(scale2, 0)
scale = relay.stack((scale1, scale2), axis=0)
y = relay.multiply(y, scale)
(y1, y2) = relay.split(y, 2)
y1 = relay.squeeze(y1, [0])
y2 = relay.squeeze(y2, [0])
y1 = relay.reshape(y1, newshape=newshape)
y2 = relay.reshape(y2, newshape=newshape)
y = relay.Tuple((y1, y2))
return relay.Function(args, y)
def check(i, j, k, scale1, scale2, newshape):
x = relay.var("x", shape=(i, k))
w1 = relay.var("w1", shape=(j, k))
w2 = relay.var("w2", shape=(j, k))
b1 = relay.var("b1", shape=(j,))
b2 = relay.var("b2", shape=(j,))
scale1 = relay.var("scale1", shape=(1,))
scale2 = relay.var("scale2", shape=(1,))
y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape)
y = run_opt_pass(y_before,
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
if __name__ == "__main__":
test_combine_parallel_dense()
test_combine_parallel_dense_biasadd()
test_combine_parallel_dense_biasadd_scale_reshape()