/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2017 by Contributors
 * \file pattern_util.h
 * \brief Utilities for doing various pattern matching in graph.
*/
#ifndef NNVM_COMPILER_PATTERN_UTIL_H_
#define NNVM_COMPILER_PATTERN_UTIL_H_

#include <nnvm/graph.h>
#include <vector>
#include <utility>
#include <string>
#include <unordered_map>

namespace nnvm {
namespace compiler {

/*!
 * \brief find axis in oshape, such that:
 *  bias_shape = [1,1, ... oshape[axis], 1,1,]
 *
 *  This is used to detect bias or scaling factor on channel dimension.
 * \param oshape The output shape
 * \param bias_shape The shape of bias or scaling factor.
 * \return Pair of matched axis in o shape and bias_shape if found.
 */
inline std::pair<int, int> MatchBroadcast1DAxis(
    const TShape& oshape, const TShape& bias_shape) {
  dim_t axis_dim = bias_shape.ndim();
  for (dim_t i = bias_shape.ndim(); i != 0; --i, --axis_dim) {
    if (bias_shape[i - 1] != 1) break;
  }
  // everything is 1
  if (axis_dim == 0) {
    return {oshape.ndim()  - bias_shape.ndim(), 0};
  }
  axis_dim = axis_dim - 1;
  // The bias shape is not 1D
  for (dim_t i = 0; i < axis_dim; ++i) {
    if (bias_shape[i] != 1) return {-1, -1};
  }
  int axis = static_cast<int>(
      oshape.ndim() - bias_shape.ndim() + axis_dim);
  if (oshape[axis] != bias_shape[axis_dim]) return {-1, -1};
  return {axis, axis_dim};
}

/*!
 * \brief Reshape a bias entry so that it can be matched on a given axis.
 *
 *  If bias_dim is not 1, the bias is first wrapped in a "squeeze" node.
 *  Then, if out_dim - axis - 1 is positive, the result is wrapped in an
 *  "expand_dims" node with axis "1" and that many new axes.
 *
 * \param bias The bias NodeEntry
 * \param out_dim output dimension.
 * \param bias_dim The current bias dimension.
 * \param axis The axis we want to match on.
 * \return The bias entry, possibly wrapped in the nodes described above.
 */
inline NodeEntry
ExpandBiasToMatchAxis(NodeEntry bias,
                      int out_dim,
                      int bias_dim,
                      int axis) {
  const bool needs_squeeze = (bias_dim != 1);
  if (needs_squeeze) {
    bias = MakeNode("squeeze", bias.node->attrs.name + "_sqz", {bias});
  }

  // Count of output axes that come after the matched axis.
  const int trailing_axes = out_dim - axis - 1;
  if (trailing_axes <= 0) {
    // Nothing to pad: hand back the (possibly squeezed) bias as is.
    return bias;
  }

  std::unordered_map<std::string, std::string> expand_attrs;
  expand_attrs["axis"] = "1";
  expand_attrs["num_newaxis"] = std::to_string(trailing_axes);
  return MakeNode("expand_dims", bias.node->attrs.name + "_expand",
                  {bias}, expand_attrs);
}

/*!
 * \brief Count how many times each node of the graph is referenced.
 *
 *  A node gains one reference for every input entry of a non-variable
 *  node that points at it, and one more for every graph output entry
 *  that points at it.
 *
 * \param idx The IndexedGraph
 * \return Vector of reference counts, one slot per node, indexed by node id.
 */
inline std::vector<uint32_t>
GetNodeRefCounts(const IndexedGraph& idx) {
  const auto node_total = idx.num_nodes();
  std::vector<uint32_t> counts(node_total, 0);

  // References coming from the inputs of every non-variable node.
  for (uint32_t node_id = 0; node_id < node_total; ++node_id) {
    const auto& current = idx[node_id];
    if (!current.source->is_variable()) {
      for (const auto& input : current.inputs) {
        counts[input.node_id] += 1;
      }
    }
  }

  // Each graph output counts as one extra reference, so output nodes
  // always end up with a non-zero count.
  for (const auto& out_entry : idx.outputs()) {
    ++counts[out_entry.node_id];
  }
  return counts;
}
}  // namespace compiler
}  // namespace nnvm
#endif  //  NNVM_COMPILER_PATTERN_UTIL_H_