/*!
 * Copyright (c) 2017 by Contributors
 * \file fold_scale_axis.cc
 * \brief Fold scaling parameter of axis into weight of conv/dense
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "pattern_util.h"
#include "graph_transform.h"

namespace nnvm {
namespace compiler {

enum FoldScaleKind {
  // No folding is applied
  kNone,
  // The folding decision is pending; the scale can still be folded along this chain.
  kPending,
  // The original operator that contains the scale.
  kProvider,
  // The final consumer of the axis scale using multiply.
  // Likely a conv or dense operator.
  kMulConsumer,
  // The final consumer of the axis scale using division.
  kDivConsumer
};

struct FoldChainInfo {
  // Entry kind
  FoldScaleKind kind{kNone};
  // The output axis to be folded
  int axis{0};
  // Source node in the fold chain
  int source{0};
};

// An entry of the folding chain
// on which we should perform folding
struct FoldChainEntry {
  // Fold information
  FoldChainInfo info;
  // Number of outgoing forks
  // in forward propagation.
  int fork_count{0};
  // The following fields are only used by the provider.
  // The index of the input that replaces the provider (the unscaled data input).
  int fold_input_index{1};
  // The scale entry
  NodeEntry scale_entry;
};

// Try to pass axis scaling backward,
// given that we know the fold status of the current output axis.
// Returns whether the scale signal is consumed.
using FScaleAxisBackward = std::function<
  bool(const NodeAttrs& attrs,
       const std::vector<TShape>& in_shape,
       const std::vector<TShape>& out_shape,
       const FoldChainInfo& out_info,
       std::vector<FoldChainInfo>* in_info)>;


// Try to pass axis scaling forward,
// given that we know one of the inputs is in the pending state;
// also update the info of the other inputs.
// Returns whether the scale signal is consumed.
using FScaleAxisForward = std::function<
  bool(const NodeAttrs& attrs,
       const std::vector<TShape>& in_shape,
       const std::vector<TShape>& out_shape,
       std::vector<FoldChainInfo>* in_info,
       FoldChainInfo* out_info)>;


// Detect whether this node performs an axis scaling that we can try to fold.
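// A typical instance (illustrative shapes) is
//   broadcast_mul(data[N, C, H, W], expand_dims(scale[C], axis=1, num_newaxis=2)),
// where a 1D scale broadcasts along one axis of the data; the scale operand may
// also be a variable whose shape already broadcasts along that axis.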
bool DetectScaleAxis(const IndexedGraph& idx,
                     uint32_t nid,
                     const ShapeVector& shape_vec,
                     const std::vector<uint32_t>& ref_count,
                     bool is_forward,
                     std::vector<FoldChainEntry>* chain) {
  const IndexedGraph::Node& inode = idx[nid];
  static const Op* bcast_mul = Op::Get("broadcast_mul");
  static const Op* expand_dims = Op::Get("expand_dims");
  if (inode.source->op() != bcast_mul) return false;
  const TShape& oshape = shape_vec[idx.entry_id(nid, 0)];
  CHECK_NE(oshape.ndim(), 0);
  if (oshape.ndim() <= 1) return false;
  for (int i = 0; i < 2; ++i) {
    const IndexedGraph::NodeEntry& a = inode.inputs[i];
    const IndexedGraph::NodeEntry& b = inode.inputs[1 - i];
    std::pair<int, int> axis =
        MatchBroadcast1DAxis(oshape, shape_vec[idx.entry_id(a)]);
    if (axis.first != -1 &&
        shape_vec[idx.entry_id(b)] == oshape) {
      if (ref_count[a.node_id] != 1) return false;
      if (is_forward && ref_count[nid] != 1) return false;
      if (!is_forward && ref_count[b.node_id] != 1) return false;
      const IndexedGraph::Node& anode = idx[a.node_id];
      // mark the current entry.
      FoldChainEntry& e = (*chain)[nid];
      if (anode.source->is_variable()) {
        e.fold_input_index = 1 - i;
        e.scale_entry = inode.source->inputs[i];
      } else if (anode.source->op()  == expand_dims &&
                 shape_vec[idx.entry_id(anode.source->inputs[0])].ndim() == 1) {
        e.fold_input_index = 1 - i;
        e.scale_entry = anode.source->inputs[0];
      } else {
        return false;
      }
      e.info.axis = axis.first;
      e.info.kind = kPending;
      e.info.source = nid;
      e.fork_count = 1;
      // In backward message passing,
      // we need to eagerly pass the fold info to the input.
      // In forward message passing,
      // we will "pull" the message from the input.
      if (!is_forward) {
        FoldChainEntry& enext = (*chain)[b.node_id];
        enext.info.axis = e.info.axis;
        enext.info.kind = kPending;
        enext.info.source = nid;
      }
      return true;
    }
  }
  return false;
}

Graph FoldScaleAxis(Graph src) {
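  // The pass runs in three phases: a backward scan that pushes detected axis
  // scales into upstream consumers (e.g. conv2d/dense weights), a forward scan
  // for scales that appear before such consumers, and a final graph transform
  // that removes provider nodes and rescales consumer parameters instead.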
  // Operator pattern
  static auto& fbackward =
      nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward");
  static auto& fforward =
      nnvm::Op::GetAttr<FScaleAxisForward>("FScaleAxisForward");
  const IndexedGraph& idx = src.indexed_graph();
  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
  std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
  std::vector<FoldChainEntry> bwd_chain(idx.num_nodes());
  std::vector<FoldChainEntry> fwd_chain(idx.num_nodes());
  // shape hint for the inference.
  std::vector<TShape> in_shape, out_shape;

  // perform backward folding.
  for (uint32_t i = idx.num_nodes(); i != 0; --i) {
    uint32_t nid = i - 1;
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    if (DetectScaleAxis(idx, nid, shape_vec,
                        ref_count, false, &bwd_chain)) continue;
    if (bwd_chain[nid].info.kind != kPending) continue;
    // if the node is referred to by multiple nodes, we cannot propagate
    if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) {
      bwd_chain[nid].info.kind = kNone; continue;
    }
    // get input shape and output shape.
    in_shape.clear(); out_shape.clear();
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      in_shape.push_back(shape_vec[idx.entry_id(e)]);
    }
    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
      out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
    }
    std::vector<FoldChainInfo> in_info(in_shape.size(), FoldChainInfo());
    bool consumed = fbackward[inode.source->op()](
        inode.source->attrs,
        in_shape,
        out_shape,
        bwd_chain[nid].info,
        &in_info);
    CHECK_EQ(in_info.size(), in_shape.size());
    // propagate back.
    bool can_prop = true;
    for (size_t i = 0; i < in_info.size(); ++i) {
      const IndexedGraph::NodeEntry& e = inode.inputs[i];
      if (ref_count[e.node_id] != 1 ||
          idx[e.node_id].source->num_outputs() != 1) {
        can_prop = false; break;
      }
    }
    if (!can_prop) continue;
    for (size_t i = 0; i < in_info.size(); ++i) {
      const IndexedGraph::NodeEntry& e = inode.inputs[i];
      bwd_chain[e.node_id].info = in_info[i];
    }
    // mark as consumed by making the source a provider.
    if (consumed) {
      bwd_chain[bwd_chain[nid].info.source].info.kind = kProvider;
    }
  }


  // perform forward folding.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    // skip scales that were already folded by the backward pass.
    if (bwd_chain[nid].info.kind == kProvider) continue;
    if (DetectScaleAxis(idx, nid, shape_vec,
                        ref_count, true, &fwd_chain)) continue;
    if (inode.source->num_outputs() != 1) continue;
    // Do state update
    // get input shape and output shape.
    std::vector<FoldChainInfo> in_info;
    FoldChainInfo out_info;
    int num_inpending = 0;
    in_shape.clear(); out_shape.clear();
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      in_shape.push_back(shape_vec[idx.entry_id(e)]);
      // input information
      in_info.push_back(fwd_chain[e.node_id].info);
      if (fwd_chain[e.node_id].info.kind == kPending) {
        ++num_inpending;
      }
    }
    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
      out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
    }
    if (num_inpending != 1 ||
        !fforward.count(inode.source->op())) continue;
    bool consumed = fforward[inode.source->op()](
        inode.source->attrs,
        in_shape,
        out_shape,
        &in_info,
        &out_info);
    // update input info
    for (size_t i = 0; i < in_info.size(); ++i) {
      fwd_chain[inode.inputs[i].node_id].info = in_info[i];
    }
    if (consumed) {
      fwd_chain[nid].info = out_info;
      for (size_t i = 0; i < in_info.size(); ++i) {
        if (in_info[i].kind == kPending) {
          if (--fwd_chain[in_info[i].source].fork_count == 0) {
            fwd_chain[in_info[i].source].info.kind = kProvider;
          }
        }
      }
    } else {
      // condition for continuing propagation
      if (inode.source->num_outputs() == 1) {
        fwd_chain[nid].info = out_info;
        if (out_info.kind == kPending) {
          // When there are multiple references to the input,
          // every path has to be consumed.
          fwd_chain[out_info.source].fork_count += ref_count[nid] - 1;
        }
      }
    }
  }

  auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
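    // Apply the fold decisions: provider nodes (the original broadcast_mul) are
    // replaced by their unscaled data input, while consumer nodes are wrapped in a
    // broadcast_mul (or broadcast_div) with the appropriately reshaped scale.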
    NodeEntry rvalue = NodeEntry{n, 0, 0};
    {
      // Backward chain
      const FoldChainEntry& e = bwd_chain[nid];
      if (e.info.kind == kMulConsumer &&
          bwd_chain[e.info.source].info.kind == kProvider) {
        const FoldChainEntry& se = bwd_chain[e.info.source];
        CHECK_EQ(n->num_outputs(), 1);
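        // Reshape the 1D scale so it broadcasts along e.info.axis of this node's
        // output, e.g. a [C] scale becomes [C, 1, 1] when the fold axis is 1 of a
        // 4D NCHW tensor (illustrative shapes).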
        NodeEntry scale = ExpandBiasToMatchAxis(
            se.scale_entry,
            shape_vec[idx.entry_id(nid, 0)].ndim(),
            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
            e.info.axis);
        rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
                          {rvalue, scale});
      } else if (e.info.kind == kProvider) {
        rvalue = n->inputs[e.fold_input_index];
      }
    }
    // Note that the value might get transformed twice if it
    // folds values from both the forward and backward chains.
    {
      // forward chain
      const FoldChainEntry& e = fwd_chain[nid];
      if (e.info.kind == kMulConsumer &&
          fwd_chain[e.info.source].info.kind == kProvider) {
        const FoldChainEntry& se = fwd_chain[e.info.source];
        CHECK_EQ(n->num_outputs(), 1);
        NodeEntry scale = ExpandBiasToMatchAxis(
            se.scale_entry,
            shape_vec[idx.entry_id(nid, 0)].ndim(),
            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
            e.info.axis);
        rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
                          {rvalue, scale});
      } else if (e.info.kind == kDivConsumer &&
                 fwd_chain[e.info.source].info.kind == kProvider) {
        const FoldChainEntry& se = fwd_chain[e.info.source];
        CHECK_EQ(n->num_outputs(), 1);
        NodeEntry scale = ExpandBiasToMatchAxis(
            se.scale_entry,
            shape_vec[idx.entry_id(nid, 0)].ndim(),
            shape_vec[idx.entry_id(se.scale_entry)].ndim(),
            e.info.axis);
        rvalue = MakeNode("broadcast_div", n->attrs.name + "_sc",
                          {rvalue, scale});
      } else if (e.info.kind == kProvider) {
        rvalue = n->inputs[e.fold_input_index];
      }
    }
    if (rvalue.node == n) {
      return false;
    } else {
      *ret = {rvalue};
      return true;
    }
  };
  return GraphTransform(src, transform);
}

NNVM_REGISTER_PASS(FoldScaleAxis)
.set_body(FoldScaleAxis);

// property registration.
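// relu and leaky_relu simply forward the fold information between output and input;
// the scale itself is never consumed here (both callbacks return false).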
bool ReluScaleAxisBackward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    const FoldChainInfo& out_info,
    std::vector<FoldChainInfo>* in_axis) {
  (*in_axis)[0] = out_info;
  return false;
}

bool ReluScaleAxisForward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    std::vector<FoldChainInfo>* in_info,
    FoldChainInfo* out_info) {
  *out_info = (*in_info)[0];
  return false;
}

NNVM_REGISTER_OP(relu)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);

NNVM_REGISTER_OP(leaky_relu)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);

NNVM_REGISTER_OP(relu)
.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);

NNVM_REGISTER_OP(leaky_relu)
.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);

// property registration.
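// max_pool2d and avg_pool2d forward the fold information only when the fold axis
// is the channel axis (axis 1 of the NCHW layout); the scale is never consumed here.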
template <typename T>
bool Pool2DBackward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    const FoldChainInfo& out_info,
    std::vector<FoldChainInfo>* in_axis) {
  const T& param = nnvm::get<T>(attrs.parsed);
  if (out_info.axis == 1 && param.layout == "NCHW") {
    (*in_axis)[0] = out_info;
  }
  return false;
}

template <typename T>
bool Pool2DForward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    std::vector<FoldChainInfo>* in_info,
    FoldChainInfo* out_info) {
  const T& param = nnvm::get<T>(attrs.parsed);
  if ((*in_info)[0].axis == 1 && param.layout == "NCHW") {
    *out_info = (*in_info)[0];
  }
  return false;
}

NNVM_REGISTER_OP(max_pool2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward<top::MaxPool2DParam>);

NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward<top::AvgPool2DParam>);

NNVM_REGISTER_OP(max_pool2d)
.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward<top::MaxPool2DParam>);

NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward<top::AvgPool2DParam>);



bool BroadcastAddSubScaleAxisBackward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    const FoldChainInfo& out_info,
    std::vector<FoldChainInfo>* in_axis) {
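  // Backward direction: for y = x + b (or x - b) with b broadcast along the fold
  // axis, y * s == x * s ± b * s, so the scale keeps flowing through the full-shaped
  // operand (kPending) while the broadcast operand must itself be multiplied by the
  // scale (kMulConsumer).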
  if (out_info.kind != kPending) return false;
  for (int i = 0; i < 2; ++i) {
    std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
    if (m.second != -1 &&
        in_shape[i] == out_shape[0] &&
        m.first == out_info.axis) {
      (*in_axis)[i].kind = kPending;
      (*in_axis)[i].axis = out_info.axis;
      (*in_axis)[i].source = out_info.source;
      (*in_axis)[1 - i].kind = kMulConsumer;
      (*in_axis)[1 - i].axis = m.second;
      (*in_axis)[1 - i].source = out_info.source;
      return false;
    }
  }
  return false;
}

bool BroadcastAddSubScaleAxisForward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    std::vector<FoldChainInfo>* in_info,
    FoldChainInfo* out_info) {
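  // Forward direction: if x already carries a pending scale, x == x' * s, then
  // x ± b == (x' ± b / s) * s, so the scale keeps flowing to the output (kPending)
  // while the broadcast operand is divided by the scale (kDivConsumer).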
  for (int i = 0; i < 2; ++i) {
    if ((*in_info)[i].kind == kPending) {
      std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
      if (m.second != -1 &&
          in_shape[i] == out_shape[0] &&
          m.first == (*in_info)[i].axis) {
        out_info->kind = kPending;
        out_info->axis = m.first;
        out_info->source = (*in_info)[i].source;
        (*in_info)[1 - i].kind = kDivConsumer;
        (*in_info)[1 - i].axis = m.second;
        (*in_info)[1 - i].source = (*in_info)[i].source;
        return false;
      }
    }
  }
  return false;
}

NNVM_REGISTER_OP(broadcast_add)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);

NNVM_REGISTER_OP(broadcast_sub)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);

NNVM_REGISTER_OP(broadcast_add)
.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);

NNVM_REGISTER_OP(broadcast_sub)
.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);

bool Conv2DScaleAxisBackward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    const FoldChainInfo& out_info,
    std::vector<FoldChainInfo>* in_axis) {
  using top::Conv2DParam;
  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
  if (out_info.kind != kPending) return false;
  // only optimize for kernel layout OIHW for now
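  // Scaling the conv2d output per output channel is equivalent to scaling the
  // weight along its output-channel axis (axis 0 of OIHW) and, when present, the
  // bias: conv2d(x, w, b) * s == conv2d(x, w * s, b * s) with s broadcast per
  // output channel.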
  if (param.kernel_layout == "OIHW" && out_info.axis == 1) {
    (*in_axis)[1].kind = kMulConsumer;
    (*in_axis)[1].axis = 0;
    (*in_axis)[1].source = out_info.source;
    if (param.use_bias) {
      (*in_axis)[2].kind = kMulConsumer;
      (*in_axis)[2].axis = 0;
      (*in_axis)[2].source = out_info.source;
    }
    return true;
  } else {
    return false;
  }
}

bool Conv2DScaleAxisForward(
    const NodeAttrs& attrs,
    const std::vector<TShape>& in_shape,
    const std::vector<TShape>& out_shape,
    std::vector<FoldChainInfo>* in_info,
    FoldChainInfo* out_info) {
  using top::Conv2DParam;
  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
  if ((*in_info)[0].kind != kPending) return false;
  // only optimize for NCHW data layout (OIHW kernel layout) for now
  if (param.kernel_layout == "OIHW" && (*in_info)[0].axis == 1) {
    // Check whether it is depthwise conv2d
    if (param.use_bias) {
      CHECK_EQ(in_shape.size(), 3U) << "Input:[data, weight, bias]";
    } else {
      CHECK_EQ(in_shape.size(), 2U) << "Input:[data, weight]";
    }

    auto dshape = in_shape.at(0);
    CHECK_EQ(dshape.ndim(), 4U) << "Input data shape should be 4D";

    // TODO(FrozenGene): Currently, we don't support grouped conv2d with groups != in_channels.
    if (param.groups > 1 && dshape[1] != param.groups) {
      LOG(WARNING) << "FoldScaleAxis optimization doesn't support conv2d "
                   << "with groups != in channels. We will skip FoldScaleAxis "
                   << "optimization for this op.";
      return false;
    }


    // input channels equal groups, which means depthwise conv2d
    bool is_depthwise_conv2d = (dshape[1] == param.groups);

    // If this is a depthwise convolution, the weight fold axis should be axis 0.
    // For example:
    // data shape [1,54,63,127], weight shape [54,1,3,3], scale shape [54].
    // A depthwise convolution's weight shape means the data's channels have been
    // divided into `groups` parts; here 54 channels are divided into 54 parts of
    // size 1 each. The first dimension of the weight shape is the number of parts
    // (mapping to the input channels), so for depthwise convolution we must not
    // fold along axis 1 as we do for an ordinary OIHW convolution.

    // Background of this algorithm:

    // Original Graph:
    //    Graph(%x,
    //          %in_scale,
    //          %weight,
    //          %bias,
    //          %out_scale) {
    //      %1 = __add_scalar__(%x, scalar='1')
    //      %3 = expand_dims(%in_scale, num_newaxis='2', axis='1')
    //      %4 = broadcast_mul(%1, %3)
    //      %7 = conv2d(%4, %weight, %bias, padding='(1, 1)', kernel_size='(3, 3)', channels='2')
    //      %8 = relu(%7)
    //      %10 = expand_dims(%out_scale, num_newaxis='2', axis='1')
    //      %11 = broadcast_mul(%8, %10)
    //      ret %11
    //    }

    // Optimized Graph:
    //    Graph(%x,
    //          %weight,
    //          %out_scale,
    //          %in_scale,
    //          %bias) {
    //      %1 = __add_scalar__(%x, scalar='1')
    //      %4 = expand_dims(%out_scale, num_newaxis='3', axis='1')
    //      %5 = broadcast_mul(%weight, %4)
    //      %7 = expand_dims(%in_scale, num_newaxis='2', axis='1')
    //      %8 = broadcast_mul(%5, %7)
    //      %10 = broadcast_mul(%bias, %out_scale)
    //      %11 = conv2d(%1, %8, %10, padding='(1, 1)', kernel_size='(3, 3)', channels='2')
    //      %12 = relu(%11)
    //      ret %12
    //    }

    // Conv2DScaleAxisForward handles in_scale; Conv2DScaleAxisBackward handles out_scale.
    // in_scale is applied to the input data's channels (in_channels), while out_scale is
    // applied to conv2d's result, i.e. to the weight's output channels.
    // So by default Conv2DScaleAxisForward folds along axis 1 (the weight's input channels)
    // and Conv2DScaleAxisBackward folds along axis 0 (the weight's output channels).
    // Depthwise convolution is the exception described above.
    (*in_info)[1].kind = kMulConsumer;
    (*in_info)[1].axis = is_depthwise_conv2d ? 0 : 1;
    (*in_info)[1].source = (*in_info)[0].source;
    return true;
  } else {
    return false;
  }
}

NNVM_REGISTER_OP(conv2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward);

NNVM_REGISTER_OP(conv2d)
.set_attr<FScaleAxisForward>("FScaleAxisForward", Conv2DScaleAxisForward);

}  // namespace compiler
}  // namespace nnvm