combine_parallel_dense.cc 3.23 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file combine_parallel_dense.cc
 * \brief Combine parallel dense ops into a single dense.
 *
 * This pass replaces dense ops that share the same input node, have the same
 * weight shape, and don't have "units" defined with a single batch matrix
 * multiplication. The inputs of the new batch_matmul are stacks of the
 * original inputs.
 * Elemwise and broadcast ops following dense are also combined if possible.
 *
 * This prevents launching multiple kernels in networks with multiple
 * dense branches, such as BERT.
 */

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "pattern_util.h"
#include "./combine_parallel_op_batch.h"

namespace tvm {
namespace relay {

class ParallelDenseCombiner : public ParallelOpBatchCombiner {
 public:
  explicit ParallelDenseCombiner(uint64_t min_num_branches)
    : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {
  }

 protected:
  virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
    AttrsEqual eq;
    const auto* attrs_a = a->attrs.as<DenseAttrs>();
    const auto* attrs_b = b->attrs.as<DenseAttrs>();
    CHECK(attrs_a);
    CHECK(attrs_b);
    const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();

    return eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(weight_a->shape[0], weight_b->shape[0]) &&
           eq(weight_a->shape[1], weight_b->shape[1]);
  }
};

/*!
 * \brief Rewrite parallel "nn.dense" branches in an expression into a single
 *        "nn.batch_matmul".
 * \param expr The expression to transform.
 * \param min_num_branches A group of parallel dense branches is combined only
 *        if it contains at least this many branches.
 * \return The transformed expression.
 */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
  ParallelDenseCombiner combiner(min_num_branches);
  return combiner.Combine(expr);
}

namespace transform {

/*!
 * \brief Create the CombineParallelDense function pass.
 * \param min_num_branches A group of parallel dense branches is combined only
 *        if it contains at least this many branches.
 * \return A function pass registered at opt level 4 that lists "InferType"
 *         as a required pass.
 */
Pass CombineParallelDense(uint64_t min_num_branches) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      // Delegates to the Expr-level CombineParallelDense overload; the result
      // is downcast back to a Function.
      return Downcast<Function>(CombineParallelDense(f, min_num_branches));
  };
  return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
                            {tir::StringImmNode::make("InferType")});
}

87
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
88 89 90 91 92 93
.set_body_typed(CombineParallelDense);

}  // namespace transform

}  // namespace relay
}  // namespace tvm