combine_parallel_conv2d.cc 8.61 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 *
 * \file combine_parallel_conv2d.cc
 * \brief Combine parallel 2d convolutions into a single convolution.
 *
 * This pass replaces convolutions that share the same input node and the same
 * arguments (except that the number of output channels can be different) with a
 * single convolution. The weight of the new 2d convolution is the concatenation
 * of the original weights. Elemwise and broadcast ops following conv2d are also
 * combined if possible.
 *
 * This prevents launching multiple kernels in networks with multiple
 * convolution branches, such as Inception block.
 */

Zhi committed
36
#include <tvm/relay/analysis.h>
37 38 39 40
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
41
#include <tvm/relay/transform.h>
42 43 44 45
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
46
#include "./combine_parallel_op.h"
47 48 49 50

namespace tvm {
namespace relay {

51
class ParallelConv2DCombiner : public ParallelOpCombiner {
52
 public:
53 54
  /*!
   * \brief Create a combiner that targets the "nn.conv2d" operator.
   * \param min_num_branches Passed through to ParallelOpCombiner together with the op name.
   */
  explicit ParallelConv2DCombiner(uint64_t min_num_branches)
      : ParallelOpCombiner("nn.conv2d", min_num_branches) {}

57 58 59 60
 protected:
  bool IsSupportedOp(const CallNode* n) {
    return n->attrs.as<Conv2DAttrs>()->groups == 1;
  }
61

62
  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
63
    AttrsEqual eq;
64
    const Layout kOIHW("OIHW");
65 66 67 68 69 70
    const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
    const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
    CHECK(attrs_a);
    CHECK(attrs_b);
    const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
71 72 73 74
    const auto shape_a = BijectiveLayoutNode::make(
      Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
    const auto shape_b = BijectiveLayoutNode::make(
      Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
75 76 77 78

    return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
           eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
           eq(attrs_a->data_layout, attrs_b->data_layout) &&
79
           eq(attrs_a->kernel_layout, attrs_b->kernel_layout) &&
80 81 82 83 84
           eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
           eq(shape_a[3], shape_b[3]);
  }

85 86
  /*!
   * \brief Build the single conv2d call that stands in for the first op of every branch.
   *
   * The new call reads the data input of the first branch, uses the weight and
   * channel count produced by TransformWeight, and copies every other attribute
   * from the first branch's conv2d. As a side effect, channel_pos_ is set to the
   * index of 'C' in the output layout (or in the data layout when the output
   * layout is empty).
   *
   * \param branches The group of branches being combined; branches[i][0] is a conv2d.
   * \return The combined conv2d call.
   */
  Call MakeCombinedOp(const Group& branches) {
    const Op& conv2d_op = Op::Get("nn.conv2d");
    const CallNode* root = branches[0][0];
    Expr input = root->args[0];

    Expr combined_weight;
    IndexExpr combined_channels;
    std::tie(combined_weight, combined_channels) = TransformWeight(branches);

    const auto* root_attrs = root->attrs.as<Conv2DAttrs>();
    CHECK(root_attrs);

    // Clone the first branch's attributes; only `channels` takes a new value.
    const auto combined_attrs = make_node<Conv2DAttrs>();
    combined_attrs->channels = combined_channels;
    combined_attrs->kernel_size = root_attrs->kernel_size;
    combined_attrs->strides = root_attrs->strides;
    combined_attrs->padding = root_attrs->padding;
    combined_attrs->dilation = root_attrs->dilation;
    combined_attrs->groups = root_attrs->groups;
    combined_attrs->data_layout = root_attrs->data_layout;
    combined_attrs->kernel_layout = root_attrs->kernel_layout;
    combined_attrs->out_layout = root_attrs->out_layout;
    combined_attrs->out_dtype = root_attrs->out_dtype;

    // An empty out_layout means the output uses the data layout.
    const std::string& result_layout = combined_attrs->out_layout.empty()
                                           ? combined_attrs->data_layout
                                           : combined_attrs->out_layout;
    channel_pos_ = result_layout.find('C');
    CHECK_NE(channel_pos_, std::string::npos);

    return CallNode::make(conv2d_op, {input, combined_weight}, Attrs{combined_attrs}, {});
  }

115
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
116 117 118 119 120 121 122 123 124 125
    AttrsEqual eq;
    auto ta = a->args[index]->type_as<TensorTypeNode>();
    auto tb = b->args[index]->type_as<TensorTypeNode>();
    auto toutput_a = a->type_as<TensorTypeNode>();
    auto toutput_b = b->type_as<TensorTypeNode>();

    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
      return false;

    // Position of the 'C' dimension in the argument
126
    size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size();
127 128

    // Channel super-dimension shoule be present and not broadcasted
129 130 131
    if ((arg_channel_pos > channel_pos_) ||  // size_t overflow
        !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) ||
        !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_]))
132 133 134 135 136 137 138 139 140 141
      return false;

    for (size_t i = 0; i < ta->shape.size(); i++) {
      if (i == arg_channel_pos) continue;
      if (!eq(ta->shape[i], tb->shape[i]))
        return false;
    }
    return true;
  }

142 143 144 145
  /*!
   * \brief Build one call that replaces the op found at `depth` in every branch.
   *
   * The argument at `parent_index` becomes `data`. Every other argument is the
   * concatenation, along that argument's channel axis, of the corresponding
   * arguments of all branches. The operator and attributes are taken from the
   * first branch.
   *
   * \param data The already-combined expression feeding this op.
   * \param branches The group of branches being combined.
   * \param depth Index of the op inside each branch.
   * \param parent_index Argument slot through which the previous op's result flows.
   * \return The combined call.
   */
  Call MakeCombinedCallFromFollowingOps(const Expr& data,
                                        const Group& branches,
                                        size_t depth,
                                        size_t parent_index) {
    const CallNode* ref_call = branches[0][depth];
    const size_t out_ndim = ref_call->type_as<TensorTypeNode>()->shape.size();

    Array<Expr> combined_args;
    for (size_t arg_idx = 0; arg_idx < ref_call->args.size(); ++arg_idx) {
      if (arg_idx == parent_index) {
        combined_args.push_back(data);
      } else {
        // Shift the output's channel index by the rank difference to obtain the
        // channel axis of this argument.
        const size_t arg_ndim = ref_call->args[arg_idx]->type_as<TensorTypeNode>()->shape.size();
        const size_t concat_axis = channel_pos_ - out_ndim + arg_ndim;

        Array<Expr> per_branch_args;
        for (const auto& branch : branches) {
          per_branch_args.push_back(branch[depth]->args[arg_idx]);
        }
        combined_args.push_back(MakeConcatenate(TupleNode::make(per_branch_args), concat_axis));
      }
    }

    return CallNode::make(ref_call->op, combined_args, ref_call->attrs, {});
  }

170 171 172 173
  void UpdateGroupOutput(const Expr& data,
                         const Group& branches,
                         size_t depth,
                         ExprSubstMap* subst_map) {
174 175 176 177 178 179
    int64_t index = 0;
    for (const auto& branch : branches) {
      const CallNode* conv2d = branch[0];
      int64_t channels = GetConv2DSuperChannelsDim(conv2d);
      Array<Integer> begin;
      Array<Integer> end;
180
      for (size_t i = 0; i < channel_pos_; i++) {
181 182 183 184 185 186 187
        begin.push_back(0);
        end.push_back(NullValue<Integer>());
      }
      begin.push_back(index);
      index += channels;
      end.push_back(index);
      auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
188
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
189 190 191
    }
  }

192 193 194 195 196 197 198 199 200 201 202 203
 private:
  /*!
   * \brief Index of the channel ('C') dimension in the combined conv2d's output layout.
   *
   * Assigned in MakeCombinedOp from the position of 'C' in out_layout, or in
   * data_layout when out_layout is empty.
   * NOTE(review): there is no initializer, so the methods reading this value
   * appear to rely on MakeCombinedOp having run first - confirm against
   * ParallelOpCombiner.
   */
  size_t channel_pos_;

  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
    int64_t num_filters = 0;  // number of filters of the transformed weight
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto conv2d = branch[0];
      weights.push_back(conv2d->args[1]);
      auto channels = GetConv2DSuperChannelsDim(conv2d);
      num_filters += channels;
204
    }
205 206 207 208
    auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
    CHECK_NE(index, std::string::npos);
    return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
                           MakeConstScalar(Int(32), num_filters));
209 210 211
  }
};

212 213 214 215
/*!
 * \brief Combine parallel conv2d if number of branches >= min_num_branches.
 * \param expr The expression to rewrite.
 * \param min_num_branches Branch-count threshold handed to the combiner.
 * \return The result of running ParallelConv2DCombiner over `expr`.
 */
Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
  ParallelConv2DCombiner combiner(min_num_branches);
  return combiner.Combine(expr);
}
216

217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
namespace transform {

/*!
 * \brief Create the function pass that combines parallel conv2d ops.
 *
 * The pass is created at opt level 4 under the name "CombineParallelConv2d" and
 * lists "InferType" as a required pass.
 *
 * \param min_num_branches Branch-count threshold forwarded to the rewrite.
 * \return The function pass.
 */
Pass CombineParallelConv2D(uint64_t min_num_branches) {
  using PassFunc = runtime::TypedPackedFunc<Function(Function, Module, PassContext)>;

  // The two-argument call below selects the expression-level CombineParallelConv2D
  // declared in the enclosing relay namespace, not this pass factory.
  PassFunc rewrite = [min_num_branches](Function func, Module mod, PassContext ctx) {
    return Downcast<Function>(CombineParallelConv2D(func, min_num_branches));
  };

  return CreateFunctionPass(rewrite, 4, "CombineParallelConv2d",
                            {ir::StringImm::make("InferType")});
}

// Bind transform::CombineParallelConv2D (the pass factory above) as the typed body
// of the API entry named "relay._transform.CombineParallelConv2D".
// NOTE(review): presumably this is the name the frontend looks up - confirm.
TVM_REGISTER_API("relay._transform.CombineParallelConv2D")
.set_body_typed(CombineParallelConv2D);

}  // namespace transform

233 234
}  // namespace relay
}  // namespace tvm