/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file combine_parallel_op_batch.h
 * \brief Combine parallel ops (and the element-wise/broadcast ops that follow
 *        them) into a single batch op.
 */
#ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
#define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include "./expr_subst.h"
#include "./combine_parallel_op.h"
#include "pattern_util.h"

namespace tvm {
namespace relay {

/*!
 * \brief Finds parallel ops, together with the element-wise and broadcast
 *        ops that follow them, and combines them into a single batch op.
 *
 * Ops can be combined if they have the same input data. The batch op is
 * formed by stacking the inputs of the parallel ops; the final per-branch
 * results are retrieved by splitting the batch op's output.
 *
 * For example:
 *
 *               data
 *         /              \
 *    dense (2,2)         dense (2,2)
 *        |                 |
 *   elemwise/bcast (2,2)  elemwise/bcast (2,2)
 *
 *   Would become:
 *
 *            data
 *             |
 *     batch_matmul+elemwise/bcast (2,2,2)
 */
class ParallelOpBatchCombiner : public ParallelOpCombiner {
 public:
  /*!
   * \brief Constructor.
   * \param op_name name of the op to combine
   * \param batch_op_name name of the op that the combined branches will be
   *                      joined into (e.g. nn.batch_matmul when combining
   *                      parallel dense ops)
   * \param min_num_branches minimum number of parallel branches beginning
   *                         with op_name required to start combining
   */
  ParallelOpBatchCombiner(const std::string& op_name,
                          const std::string& batch_op_name,
                          uint64_t min_num_branches);

 protected:
  /*!
   * \brief Checks whether a node is supported to be combined.
   *        Declared virtual (and not final) so that subclasses may override it.
   * \param n node in question
   * \return True by default
   */
  virtual bool IsSupportedOp(const CallNode* n);

  /*!
   * \brief Checks whether two ops can be combined.
   *        Declared virtual (and not final) so that subclasses may override it.
   * \param a node a
   * \param b node b
   * \return True if the shapes and dtypes of all args of a and b are the same
   */
  virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b);

  /*!
   * \brief Makes the combined op from the parallel ops in branches. This
   *        usually involves concatenating or stacking inputs, then creating
   *        a new call.
   * \param branches branches that are to be combined
   * \return new call with the branches combined as a batch op by stacking args
   */
  Call MakeCombinedOp(const Group& branches) final;

  /*!
   * \brief Checks whether the index-th argument of the ops that follow the
   *        combined ops can be combined.
   * \param a node a
   * \param b node b
   * \param index index of the argument in question
   * \return True if the shapes and dtypes of args[index] of a and b are the same
   */
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final;

  /*!
   * \brief Creates a combined call from the ops that follow the initial
   *        combined op at the depth-th level. This usually involves
   *        concatenating or stacking inputs, then creating a new call.
   *        Only called if IsArgCompatible returns true for each arg.
   * \param data combined op
   * \param branches branches of parallel ops to be combined
   * \param depth depth at which to combine ops
   * \param parent_index index of the arg that corresponds to the original
   *                     input that was shared among all combined ops
   * \return new combined call as a batch op formed by stacking args
   */
  Call MakeCombinedCallFromFollowingOps(const Expr& data,
                                        const Group& branches,
                                        size_t depth,
                                        size_t parent_index) final;

  /*!
   * \brief Updates the map of exprs to substitute with the combined expr.
   *        This usually involves slicing or splitting data.
   * \param data combined op
   * \param branches branches of parallel ops to be combined
   * \param depth depth at which to substitute
   * \param subst_map map from each Expr to be replaced to the Expr that
   *                  replaces it; updated through this pointer
   */
  void UpdateGroupOutput(const Expr& data,
                         const Group& branches,
                         size_t depth,
                         ExprSubstMap* subst_map) final;

 private:
  /*!
   * \brief Name of the op to replace combined ops with. For example, when
   *        combining parallel dense ops this will be set to nn.batch_matmul.
   */
  std::string batch_op_name_;
};

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_