/*!
 *  Copyright (c) 2017 by Contributors
 * \file layout_transform.cc
 * \brief Transforms layout.
 */
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/contrib_op_param.h>

namespace nnvm {
namespace compiler {

/*!
 * \brief Get the layout value this file treats as "no specific layout".
 * \return Reference to a function-local static initialized from "default";
 *  the reference stays valid for the lifetime of the program.
 */
const TLayoutInfo& GetDefaultLayout() {
  // Function-local static: constructed once, on first call.
  static const TLayoutInfo kDefaultLayout = "default";
  return kDefaultLayout;
}

/*!
 * \brief Create a fresh "layout_transform" node converting src to dst.
 *
 *  The node has no inputs attached; the caller is expected to add the
 *  entry to be transformed. Its name is "<src>_to_<dst><k>", where k is a
 *  process-wide counter incremented on every call (not synchronized).
 *
 * \param src The source layout, stored in the "src_layout" attribute.
 * \param dst The destination layout, stored in the "dst_layout" attribute.
 * \return The newly created node with its attributes already parsed.
 */
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
                                        const std::string& dst) {
  // Resolve the operator once; later calls reuse the cached pointer.
  static const nnvm::Op* layout_transform_op = nnvm::Op::Get("layout_transform");
  // Suffix that keeps generated names distinct across calls.
  static int serial = 0;

  nnvm::NodePtr node = nnvm::Node::Create();
  auto& attrs = node->attrs;
  attrs.op = layout_transform_op;
  attrs.name = src + "_to_" + dst + std::to_string(serial++);
  attrs.dict["src_layout"] = src;
  attrs.dict["dst_layout"] = dst;
  // Run the operator's attribute parser over the dict filled in above.
  node->op()->attr_parser(&attrs);
  return node;
}

/*!
 * \brief A simple layout transform pass that inserts layout transform
 *  nodes automatically.
 *
 *  Every node of the source graph is mirrored into a new node. For each
 *  input edge whose produced layout differs from the layout requested by
 *  the consumer, a "layout_transform" node is spliced in between. Ops
 *  classified as "map" ops (pattern <= kBroadcast, with matching shapes
 *  for kBroadcast) whose inputs all share one layout simply forward that
 *  layout to their outputs and get no transform nodes. Graph outputs that
 *  are not in the default layout are converted back to it.
 *
 * \param src The source graph. It must carry a "shape" attribute
 *  (ShapeVector) and a "layout_inputs" attribute
 *  (std::vector<TLayoutInfo>, indexed by position in the indexed graph's
 *  input_nodes()).
 * \return A new graph; only its `outputs` field is populated.
 */
nnvm::Graph LayoutTransform(nnvm::Graph src) {
  static auto& op_layout_request =
    nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
  static auto& op_vecop =
    nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
  static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");

  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
  const std::vector<TLayoutInfo>& input_layouts =
      src.GetAttr<std::vector<TLayoutInfo> >("layout_inputs");

  const IndexedGraph& idx = src.indexed_graph();
  // produce_vec[eid]: layout in which node entry eid is produced.
  std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
  // mirror_vec[nid]: node of the new graph that replaces source node nid.
  std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);

  // use op pattern to decide whether an op is map
  auto is_map_op = [&](size_t nid) {
    TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
    bool is_map = (pt <= kBroadcast);
    if (pt == kBroadcast) {
      // A broadcast op only counts as a map op when every input already
      // has the shape of output 0.
      for (const auto& e : idx[nid].inputs) {
        if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
          is_map = false;
          break;
        }
      }
    }
    return is_map;
  };

  // The loop reads mirror_vec[e.node_id] for every input, so it relies on
  // inputs having smaller node ids than their consumers.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    nnvm::NodePtr new_node = nnvm::Node::Create();
    // NOTE(review): this copies the whole source node, so any control
    // dependencies it holds would still refer to nodes of the source
    // graph -- confirm whether they need remapping through mirror_vec.
    *new_node = *(inode.source);
    if (new_node->is_variable()) {
      // Variables take their layout from the "layout_inputs" attribute.
      auto input_iter = std::find(
        idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
      CHECK(input_iter != idx.input_nodes().cend());
      size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
      CHECK_LT(input_id, input_layouts.size())
          << "layout_inputs has no entry for input " << input_id;
      produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
      mirror_vec[nid] = new_node;
      continue;
    }

    if (op_vecop.count(inode.source->op())) {
      // Replace the node by its registered vectorized variant.
      new_node = op_vecop[inode.source->op()](inode.source);
      new_node->inputs.resize(new_node->num_inputs());
    }

    // set up output and input layouts
    std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
    if (op_layout_request.count(new_node->op())) {
      std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
      CHECK(op_layout_request[new_node->op()](
          new_node->attrs, &request_ilayouts, &produce_olayouts))
          << "Layout request fail";

      CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
      CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
      for (size_t i = 0; i < new_node->num_outputs(); ++i) {
        produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
      }
    }

    // The input loop below indexes request_ilayouts and new_node->inputs
    // by the source node's input positions; both are sized from
    // new_node->num_inputs() once a vectorized op has been substituted.
    CHECK_LE(inode.inputs.size(), request_ilayouts.size())
        << "Node has more inputs than its operator declares";
    CHECK_LE(inode.inputs.size(), new_node->inputs.size())
        << "Node has more inputs than its operator declares";

    // A map op whose inputs all share one layout forwards that layout to
    // its outputs. An op without inputs has no layout to forward, and
    // reading inputs[0] would be out of bounds, so it skips this path.
    bool map_layout = !inode.inputs.empty() && is_map_op(nid);
    if (map_layout) {
      const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
      for (const auto& e : inode.inputs) {
        if (produce_vec[idx.entry_id(e)] != layout) {
          map_layout = false;
          break;
        }
      }
      if (map_layout) {
        for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
          produce_vec[idx.entry_id(nid, i)] = layout;
        }
      }
    }

    for (size_t i = 0; i < inode.inputs.size(); ++i) {
      const auto& e = inode.inputs[i];
      const nnvm::NodePtr& in = mirror_vec[e.node_id];
      new_node->inputs[i] =
        nnvm::NodeEntry{in, e.index, e.version};

      // produce_vec and request_ilayouts are not modified inside this
      // loop, so references stay valid and avoid per-edge copies.
      const TLayoutInfo& produce = produce_vec[idx.entry_id(e)];
      const TLayoutInfo& request = request_ilayouts[i];
      if (!map_layout && (produce != request)) {
        nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
        // Override the generated name with one derived from the producer.
        tnode->attrs.name =
          idx[e.node_id].source->attrs.name + "_" + request;
        tnode->inputs.emplace_back(new_node->inputs[i]);
        new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
      }
    }
    mirror_vec[nid] = new_node;
  }

  // Convert every graph output back to the default layout if needed.
  std::vector<nnvm::NodeEntry> outputs;
  for (const auto& e : idx.outputs()) {
    const TLayoutInfo& produce = produce_vec[idx.entry_id(e)];
    if (produce != GetDefaultLayout()) {
      nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
      tnode->attrs.name =
        idx[e.node_id].source->attrs.name + "_default";
      tnode->inputs.emplace_back(
        nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
      outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
    } else {
      outputs.emplace_back(
        nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
    }
  }

  nnvm::Graph ret;
  ret.outputs = std::move(outputs);
  return ret;
}

}  // namespace compiler
}  // namespace nnvm