/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 *
 * \file combine_parallel_conv2d.cc
 * \brief Combine parallel 2d convolutions into a single convolution.
 *
 * This pass replaces convolutions that share the same input node and the same
 * arguments (except that the number of output channels can be different) with a
 * single convolution. The weight of the new 2d convolution is the concatenation
 * of the original weights. Elemwise and broadcast ops following conv2d are also
 * combined if possible.
 *
 * This prevents launching multiple kernels in networks with multiple
 * convolution branches, such as Inception blocks.
 */
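
/*
 * Illustrative sketch of the rewrite (pseudo-IR; the variable names and the
 * channel counts k1/k2 below are assumptions for the example, not taken from
 * this file). Two branches that share the input %x and agree on every
 * attribute except the number of output channels
 *
 *   %y1 = nn.conv2d(%x, %w1)   // %w1 has k1 output channels
 *   %y2 = nn.conv2d(%x, %w2)   // %w2 has k2 output channels
 *
 * are replaced by one convolution over the concatenated weight, with
 * strided_slice recovering each branch's output along the channel axis:
 *
 *   %w  = concatenate((%w1, %w2), axis=O)   // 'O' axis of the kernel layout
 *   %y  = nn.conv2d(%x, %w)                 // k1 + k2 output channels
 *   %y1 = strided_slice(%y, channels [0, k1))
 *   %y2 = strided_slice(%y, channels [k1, k1 + k2))
 */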

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"

namespace tvm {
namespace relay {

class ParallelConv2DCombiner : public ParallelOpCombiner {
 public:
  explicit ParallelConv2DCombiner(uint64_t min_num_branches)
    : ParallelOpCombiner("nn.conv2d", min_num_branches) {
  }

 protected:
  // Only combine conv2d calls that are not grouped convolutions.
  bool IsSupportedOp(const CallNode* n) {
    return n->attrs.as<Conv2DAttrs>()->groups == 1;
  }

  // Two conv2d calls can be combined if all attributes and the kernel H/W shape match.
  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
    AttrsEqual eq;
    const Layout kOIHW("OIHW");
    const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
    const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
    CHECK(attrs_a);
    CHECK(attrs_b);
    const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
    const auto shape_a = BijectiveLayoutNode::make(
        Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
    const auto shape_b = BijectiveLayoutNode::make(
        Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);

    return eq(attrs_a->strides, attrs_b->strides) &&
           eq(attrs_a->padding, attrs_b->padding) &&
           eq(attrs_a->dilation, attrs_b->dilation) &&
           eq(attrs_a->groups, attrs_b->groups) &&
           eq(attrs_a->data_layout, attrs_b->data_layout) &&
           eq(attrs_a->kernel_layout, attrs_b->kernel_layout) &&
           eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(attrs_a->out_layout, attrs_b->out_layout) &&
           eq(shape_a[2], shape_b[2]) &&
           eq(shape_a[3], shape_b[3]);
  }

  // Create a single conv2d call over the concatenated weight of all branches.
  Call MakeCombinedOp(const Group& branches) {
    const Op& conv2d = Op::Get("nn.conv2d");
    Expr data = branches[0][0]->args[0];
    Expr new_weight;
    IndexExpr new_channels;
    std::tie(new_weight, new_channels) = TransformWeight(branches);

    const CallNode* group_root = branches[0][0];
    const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
    CHECK(attrs);
    const auto new_attrs = make_node<Conv2DAttrs>();
    new_attrs->strides = attrs->strides;
    new_attrs->padding = attrs->padding;
    new_attrs->dilation = attrs->dilation;
    new_attrs->groups = attrs->groups;
    new_attrs->kernel_size = attrs->kernel_size;
    new_attrs->data_layout = attrs->data_layout;
    new_attrs->kernel_layout = attrs->kernel_layout;
    new_attrs->out_layout = attrs->out_layout;
    new_attrs->out_dtype = attrs->out_dtype;
    new_attrs->channels = new_channels;

    const std::string& layout =
        new_attrs->out_layout == "" ? new_attrs->data_layout : new_attrs->out_layout;
    channel_pos_ = layout.find('C');
    CHECK_NE(channel_pos_, std::string::npos);

    return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
  }

  // Check that argument `index` of the following ops can be concatenated
  // along the channel dimension across branches.
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
    AttrsEqual eq;
    auto ta = a->args[index]->type_as<TensorTypeNode>();
    auto tb = b->args[index]->type_as<TensorTypeNode>();
    auto toutput_a = a->type_as<TensorTypeNode>();
    auto toutput_b = b->type_as<TensorTypeNode>();

    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
      return false;

    // Position of the 'C' dimension in the argument
    size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size();

    // Channel super-dimension should be present and not broadcasted
    if ((arg_channel_pos > channel_pos_) ||  // size_t overflow
        !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) ||
        !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_]))
      return false;

    for (size_t i = 0; i < ta->shape.size(); i++) {
      if (i == arg_channel_pos) continue;
      if (!eq(ta->shape[i], tb->shape[i]))
        return false;
    }
    return true;
  }

  // Combine a following elemwise/broadcast op by concatenating its
  // non-parent arguments along the channel dimension.
  Call MakeCombinedCallFromFollowingOps(const Expr& data,
                                        const Group& branches,
                                        size_t depth,
                                        size_t parent_index) {
    Array<Expr> new_args;
    const CallNode* call = branches[0][depth];
    size_t ndim = call->type_as<TensorTypeNode>()->shape.size();

    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        new_args.push_back(data);
        continue;
      }

      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
      size_t arg_channel_pos = channel_pos_ - ndim + arg_ndim;
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        tuple.push_back(branch[depth]->args[i]);
      }

      auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
      new_args.push_back(std::move(concat));
    }

    return CallNode::make(call->op, new_args, call->attrs, {});
  }

  // Replace each branch's output with a strided_slice of the combined output
  // along the channel axis.
  void UpdateGroupOutput(const Expr& data,
                         const Group& branches,
                         size_t depth,
                         ExprSubstMap* subst_map) {
    int64_t index = 0;

    for (const auto& branch : branches) {
      const CallNode* conv2d = branch[0];
      int64_t channels = GetConv2DSuperChannelsDim(conv2d);
      Array<Integer> begin;
      Array<Integer> end;
      for (size_t i = 0; i < channel_pos_; i++) {
        begin.push_back(0);
        end.push_back(NullValue<Integer>());
      }
      begin.push_back(index);
      index += channels;
      end.push_back(index);
      auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
    }
  }

 private:
  /* \brief index of channel dimension */
  size_t channel_pos_;

  // Concatenate the weights of all branches along the output-channel ('O')
  // axis and return the new weight together with the total number of filters.
  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
    int64_t num_filters = 0;  // number of filters of the transformed weight
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto conv2d = branch[0];
      weights.push_back(conv2d->args[1]);
      auto channels = GetConv2DSuperChannelsDim(conv2d);
      num_filters += channels;
    }
    auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
    CHECK_NE(index, std::string::npos);
    return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
                           MakeConstScalar(Int(32), num_filters));
  }
};

/*!
 * \brief Combine parallel conv2d if number of branches >= min_num_branches
 */
Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
}

namespace transform {

Pass CombineParallelConv2D(uint64_t min_num_branches) {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
  };
  return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.CombineParallelConv2D")
.set_body_typed(CombineParallelConv2D);

}  // namespace transform

}  // namespace relay
}  // namespace tvm
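
// Usage sketch (assumptions: the module `mod` and the surrounding build code
// are hypothetical; it relies on transform::Pass being directly callable on a
// Module, as in this version's pass infrastructure):
//
//   relay::Module mod = ...;  // module containing parallel conv2d branches
//   relay::transform::Pass pass = relay::transform::CombineParallelConv2D(3);
//   mod = pass(mod);          // runs under the current PassContext
//
// The same pass is reachable from Python through the
// "relay._transform.CombineParallelConv2D" global registered above.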