/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file combine_parallel_op_batch.h
 * \brief Combine parallel ops into a single batch op.
 */
#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_
#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"

namespace tvm {
namespace relay {

/*!
 * \brief Class to find and combine parallel ops, together with the
 *        element-wise and broadcast ops that follow them, into a single
 *        batch op.
 *
 * Ops can be combined if they have the same input data. The batch op is
 * formed by stacking the inputs; the final results are retrieved by
 * splitting the output.
 * For example:
 *
 *               data
 *         /              \
 *    dense (2,2)         dense (2,2)
 *        |                 |
 *   elemwise/bcast (2,2)  elemwise/bcast (2,2)
 *
 *   Would become:
 *
 *            data
 *             |
 *     batch_matmul+elemwise/bcast (2,2,2)
 */
class ParallelOpBatchCombiner : public ParallelOpCombiner {
 public:
  /*!
   * \brief Constructor.
   * \param op_name name of op to combine
   * \param batch_op_name name of op that combined branches will be joined into
   * \param min_num_branches min number of parallel branches beginning with op
   *                         to start combining
   */
  ParallelOpBatchCombiner(const std::string& op_name,
                          const std::string& batch_op_name,
                          uint64_t min_num_branches);

 protected:
  /*!
   * \brief Checks if node is supported to be combined.
   *        Declared virtual (not final), so derived classes may override it.
   * \param n node in question
   * \return True by default
   */
  virtual bool IsSupportedOp(const CallNode* n);

  /*!
   * \brief Checks if two ops can be combined.
   *        Declared virtual (not final), so derived classes may override it.
   * \param a node a
   * \param b node b
   * \return True if shapes and dtypes of all args of a and b are the same
   */
  virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b);

  /*!
   * \brief Makes combined op from parallel ops in branches. This usually involves
   *        concatenating or stacking inputs, then creating a new call.
   * \param branches branches that are to be combined
   * \return new call with branches combined as batch op by stacking args
   */
  Call MakeCombinedOp(const Group& branches) final;

  /*!
   * \brief Checks if the arguments of the ops following the combined ops can be combined.
   * \param a node a
   * \param b node b
   * \param index index of argument in question
   * \return True if shapes and dtypes of args[index] of a and b are the same
   */
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final;

  /*!
   * \brief Create combined call from ops that follow the initial combined op at the depth-th level.
   *        This usually involves concatenating or stacking inputs, then creating a new call.
   *        Only called if IsArgCompatible returns true for each arg.
   * \param data combined op
   * \param branches branches of parallel ops to be combined
   * \param depth depth at which to combine ops
   * \param parent_index index of arg that corresponds to original input that was shared among
   *                     all combined ops
   * \return new combined call as batch op by stacking args
   */
  Call MakeCombinedCallFromFollowingOps(const Expr& data,
                                        const Group& branches,
                                        size_t depth,
                                        size_t parent_index) final;

  /*!
   * \brief Updates map of expr to substitute with combined expr. This usually involves
   *        slicing or splitting data.
   * \param data combined op
   * \param branches branches of parallel ops to be combined
   * \param depth depth at which to substitute
   * \param subst_map map from each Expr to be replaced to the Expr that replaces it
   */
  void UpdateGroupOutput(const Expr& data,
                         const Group& branches,
                         size_t depth,
                         ExprSubstMap* subst_map) final;

 private:
  /*!
   * \brief Name of the op to replace the combined ops with. For example,
   *        for combining parallel dense, this will be set to
   *        nn.batch_matmul.
   */
  std::string batch_op_name_;
};

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_