/*!
 *  Copyright (c) 2018 by Contributors
 * \file alter_op_layout.cc
 * \brief Alter the operator layouts. Keep inferred layouts (if any) from previous stages.
 *        e.g., convolution may run faster with an NCHW16c layout.
 */
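// A note on layout strings (following the nnvm::Layout convention): an
// upper-case letter names a primal axis, and the matching lower-case letter,
// prefixed by its split factor, names a sub-axis derived from it. For example,
// "NCHW16c" is "NCHW" with the C axis split into blocks of 16 channels.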
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/pass_functions.h>
#include <tvm/tvm.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./compile_engine.h"
#include "./graph_transform.h"

namespace nnvm {
namespace compiler {
namespace {

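// Build placeholder tensors describing the shape and dtype of each output of
// node `nid`, taken from the results of the InferShape/InferType passes.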
tvm::Array<tvm::Tensor> GetTensorInfo(const IndexedGraph& idx_graph,
                                      const uint32_t nid,
                                      const ShapeVector& shape_vec,
                                      const DTypeVector& dtype_vec) {
  tvm::Array<tvm::Tensor> vec;
  for (uint32_t i = 0; i < idx_graph[nid].source->num_outputs(); ++i) {
    tvm::Array<tvm::Expr> shape;
    for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) {
      CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
      shape.push_back(tvm::make_const(tvm::Int(32), x));
    }
    vec.push_back(tvm::placeholder(
      shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)])));
  }
  return vec;
}

Graph AlterOpLayout(const Graph& src) {
  static auto& falter_op_layout =
    Op::GetAttr<nnvm::compiler::FTVMAlterOpLayout>("FTVMAlterOpLayout");
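
  // For reference: a callback registered under "FTVMAlterOpLayout" matches the
  // invocation further below. A registration might look like this (an
  // illustrative sketch only, not a registration made by this file):
  //
  //   NNVM_REGISTER_OP(conv2d)
  //   .set_attr<nnvm::compiler::FTVMAlterOpLayout>("FTVMAlterOpLayout",
  //     [](const nnvm::NodeAttrs& attrs, const nnvm::Symbol& inputs,
  //        const tvm::Array<tvm::Tensor>& tinfos, nnvm::Symbol* ret) -> bool {
  //       // Fill *ret and return true to replace the node;
  //       // return false to keep it unchanged.
  //       return false;
  //     });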

  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
  const DTypeVector& dtype_vec = src.GetAttr<DTypeVector>("dtype");
  const IndexedGraph& idx_graph = src.indexed_graph();

  std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
  std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
  std::unordered_map<const Node*, uint32_t> new_nodes;

  if (src.HasAttr("layout")) {
    // Record layouts so that the LayoutTransform pass can fix them correctly,
    // e.g., conv2d may be replaced by a contrib implementation
    // whose layout differs from the original one
    // (which was imported from a model file).
    const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
    for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
      const auto &inode = idx_graph[nid];
      if (falter_op_layout.count(inode.source->op())) {
        // Do not record layouts of nodes that will be replaced.
        continue;
      }
      std::vector<Layout> in_layout;
      for (const auto& e : inode.inputs) {
        in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
      }
      in_layouts_of_node[nid] = in_layout;

      std::vector<Layout> out_layout;
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        out_layout.emplace_back(layouts[idx_graph.entry_id(nid, i)]);
      }
      out_layouts_of_node[nid] = out_layout;
    }
  }

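  // Callback for GraphTransform: returning true (with replacement entries in
  // *ret) substitutes the node; returning false keeps it. Kept nodes are
  // recorded in new_nodes so their layouts can be restored after the pass.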
  auto transform = [&](uint32_t nid,
                       const NodePtr& n,
                       std::vector<NodeEntry>* ret) {
    nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
      falter_op_layout.get(n->op(), nullptr);
    if (fn_alter_op_layout == nullptr) {
      new_nodes[n.get()] = nid;
      return false;
    }

    // construct parameters for the registered function
    std::vector<Symbol> op_inputs;
    tvm::Array<tvm::Tensor> tensor_infos;
    CHECK_EQ(n->num_inputs(), idx_graph[nid].inputs.size());
    for (uint32_t i = 0; i < n->num_inputs(); ++i) {
      const nnvm::NodeEntry& input = n->inputs[i];
      // input operator
      Symbol op_input;
      op_input.outputs.push_back(input);
      op_inputs.push_back(op_input);

      // Input tinfo is extracted from the original graph,
      // since that is where infer_shape & infer_type were applied.
      tvm::Array<tvm::Tensor> op_output_tinfos =
        GetTensorInfo(idx_graph, idx_graph[nid].inputs[i].node_id,
                      shape_vec, dtype_vec);
      tensor_infos.push_back(op_output_tinfos[input.index]);
    }
    // Call the registered function to obtain a new operator, if any.
    Symbol op;
    bool do_alter =
      fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
    if (do_alter) *ret = op.outputs;
    return do_alter;
  };

  Graph ret = nnvm::compiler::GraphTransform(src, transform);

  if (src.HasAttr("layout")) {
    // restore the recorded layouts to the returned graph
    const auto& ret_idx = ret.indexed_graph();
    std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
    for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
      const auto& inode = ret_idx[nid];
      if (new_nodes.count(inode.source)) {
        const std::vector<Layout>& in_layouts =
          in_layouts_of_node[new_nodes[inode.source]];
        for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
          const auto& e = inode.inputs[i];
          ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
        }
        const std::vector<Layout>& out_layouts =
          out_layouts_of_node[new_nodes[inode.source]];
        for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
          ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
        }
      }
    }

    // indexed_graph() has already been called on `ret`, so we cannot attach
    // new attributes and return it directly; create a new graph instead.
    nnvm::Graph new_ret;
    new_ret.outputs = ret.outputs;
    new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
    return new_ret;
  }

  return ret;
}

// register pass
NNVM_REGISTER_PASS(AlterOpLayout)
.set_body(AlterOpLayout)
.set_change_graph(true);
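
// Example (illustrative): the pass is applied by name through the generic
// pass machinery declared in nnvm/pass_functions.h, e.g.
//   nnvm::Graph dst = nnvm::ApplyPass(std::move(src), "AlterOpLayout");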

}  // namespace
}  // namespace compiler
}  // namespace nnvm