/*!
 * Copyright (c) 2018 by Contributors
 * \file alter_op_layout.cc
 * \brief Alter the layouts of operators or replace primitive operators with
          other expressions. This pass can be used to compute convolutions in
          custom layouts or to perform other general weight pre-transformations.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/tvm.h>
#include <tuple>
#include <vector>
#include <functional>
#include <string>

#include "alter_op_layout.h"

namespace tvm {
namespace relay {

namespace alter_op_layout {

// Make a layout_transform CallNode that converts raw from src_layout to dst_layout
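// For example, TransformLayout(x, Layout("NCHW"), Layout("NCHW16c")) returns a call
// roughly equivalent to the Relay expression
//   layout_transform(x, src_layout="NCHW", dst_layout="NCHW16c")
// while TransformLayout(x, l, l) with identical layouts returns x unchanged.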
Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
  if (src_layout.Equals(dst_layout)) { return raw; }
  CHECK(src_layout.defined() && dst_layout.defined())
    << "Cannot insert layout transform because there are undefined layouts";
  CHECK(src_layout.Convertible(dst_layout))
    << "Cannot insert layout transform because there are inconvertible layouts: "
    << src_layout << " vs. " << dst_layout;
  static auto &transform_op = Op::Get("layout_transform");
  NodePtr<LayoutTransformAttrs> attrs = make_node<LayoutTransformAttrs>();
  attrs->src_layout = src_layout.name();
  attrs->dst_layout = dst_layout.name();
  Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs});
  return transform;
}

// Memorize (cache) layout transforms so that internally created transformed nodes can be reused
class TransformMemorizerNode : public Node {
 public:
  // map from (Expr, src_layout, dst_layout) to transformed Expr
  using TransformKey = std::tuple<const Node*, std::string, std::string>;
  struct key_hash {
    std::size_t operator()(const TransformKey& k) const {
      return dmlc::HashCombine<std::string>(dmlc::HashCombine<std::string>(
              std::hash<const Node*>()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k)));
    }
  };

  std::unordered_map<TransformKey, Expr, key_hash> memo;
  static constexpr const char *_type_key = "relay.alter_op_layout.TransformMemorizerNode";
  TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node);
};

class TransformMemorizer : public NodeRef {
 public:
  TransformMemorizer() {}
  explicit TransformMemorizer(NodePtr<Node> n) : NodeRef(n) {}

  TransformMemorizerNode* operator->() {
    return static_cast<TransformMemorizerNode*>(node_.get());
  }

  // Transform layout with memorizer
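  // The memo key is (raw expression node, src layout name, dst layout name), so
  // transforming the same node to the same layout from multiple consumers reuses a
  // single layout_transform node instead of creating duplicates.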
  Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
    if (src_layout.Equals(dst_layout)) { return raw; }

    std::tuple<const Node*, std::string, std::string> key =
        std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
    auto& memo = operator->()->memo;

    auto iter = memo.find(key);
    if (iter != memo.end()) {
      return iter->second;
    } else {
      Expr transform = TransformLayout(raw, src_layout, dst_layout);
      memo[key] = transform;
      return transform;
    }
  }

  using ContainerType = TransformMemorizerNode;
};


// A TempExprNode used during the layout-transform pass.
// Instances of this expression are eventually Realized back into normal expressions.
class LayoutAlternatedExprNode : public TempExprNode {
 public:
  Expr value;
  Layout old_layout;
  Layout new_layout;
  TransformMemorizer memorizer;

  Expr Realize() const final {
    // NOTE: use a copy to discard the "const" qualifier
    TransformMemorizer tmp_memorizer = memorizer;
    // fallback to old layout
    return tmp_memorizer.Transform(value, new_layout, old_layout);
  }

  void VisitAttrs(AttrVisitor *v) final {
    v->Visit("value", &value);
    v->Visit("old_layout", &old_layout);
    v->Visit("new_layout", &new_layout);
  }

  static constexpr const char *_type_key = "relay.alter_op_layout.LayoutAlternatedExprNode";
  TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
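
// AlterOpLayoutRewrite (below) wraps each rewritten output in a
// LayoutAlternatedExprNode recording the value together with its original and
// transformed layouts, so downstream calls can keep propagating the new layout.
// Whenever a value in the transformed layout cannot be consumed as-is, Realize()
// falls back by converting it from new_layout back to old_layout.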

// Call the registered FInferCorrectLayout of an op.
// Parameters are the same as those of FInferCorrectLayout.
// Returns (inferred_input_layouts, inferred_output_layouts, success).
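// For example, element-wise ops whose FInferCorrectLayout is ElemwiseArbitraryLayout
// (see alter_op_layout.h) simply adopt whatever layout their first input provides,
// so a tensor already transformed to, say, NCHW16c can flow through them unchanged.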
std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
    const Call& call,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr> > &old_in_shapes) {
  static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");

  Op op = Downcast<Op>(call->op);
  if (finfer_layout.count(op)) {
    Array<Array<Layout> > inferred_layouts;
    inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts,
                                         old_in_layouts, old_in_shapes);
    CHECK_EQ(inferred_layouts.size(), 2)
      << "FInferCorrectLayout should return exactly two arrays: "
      << "the input layouts and the output layouts";
    for (auto x : inferred_layouts) {
      for (auto y : x) {
        if (!y.defined()) {  // inference fails
          return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
        }
      }
    }
    return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true);
  } else {
    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
  }
}

// Call the registered FTVMAlterOpLayout of an op.
// Returns the altered expression, or the original op re-created with new_args
// when no alteration applies.
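// FTVMAlterOpLayout implementations are typically registered from the Python frontend
// (e.g. to rewrite a conv2d into a packed-weight variant); when no function is
// registered, or the function returns an undefined Expr, the original op is simply
// re-created with new_args.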
Call CallAlter(const Call& ref_call,
               const std::vector<Expr>& new_args) {
  static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
  Op op = Downcast<Op>(ref_call->op);

  Expr new_e;
  bool modified = false;
  if (falter_layout.count(op)) {
    tvm::Array<tvm::Tensor> tinfos;
    for (auto expr : ref_call->args) {
      auto ttype = expr->type_as<TensorTypeNode>();
      tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
    }
    Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
    if (altered_value.defined()) {
      new_e = altered_value;
      modified = true;
    }
  }
  if (!modified) {
    new_e = CallNode::make(ref_call->op, new_args,
                           ref_call->attrs);
  }

  const CallNode *new_call = new_e.as<CallNode>();
  CHECK(new_call) << "Can only replace the original operator with another call node";
  return GetRef<Call>(new_call);
}

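// Rewrite callback used by ForwardRewrite. For each call it:
//   1. unwraps/flattens the arguments into LayoutAlternatedExprs,
//   2. infers the layouts the original op expects (old_in/old_out),
//   3. applies the registered alter function, if any (CallAlter),
//   4. re-infers layouts for the possibly altered op (new_in2/new_out),
//   5. inserts memoized layout_transforms to bridge new_in -> new_in2, and
//   6. wraps the result in LayoutAlternatedExpr(s) recording old/new output layouts.
// Returns an undefined Expr when layout inference fails or the altered op is no longer
// a primitive op, letting ForwardRewrite fall back to its default behaviour.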
Expr AlterOpLayoutRewrite(const Call &ref_call,
                          const Array<Expr> &new_args,
                          const NodeRef& ctx) {
  std::vector<LayoutAlternatedExpr> inputs;
  std::vector<Expr> normal_new_args;
  Array<Array<IndexExpr> > input_shapes;

  // NOTE: discard the "const" qualifier
  TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);

  // fill incomplete state and flatten tuple
  auto push_back_one_arg = [&inputs, memorizer](Expr arg) {
    // We normally expect a LayoutAlternatedExpr here; plain Exprs (e.g. vars or
    // constants) are wrapped into one so the rest of the pass sees a uniform type.
    if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
      inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
      return inp->value;
    } else {
      auto inode = make_node<LayoutAlternatedExprNode>();
      inode->value = arg;
      inode->memorizer = memorizer;
      inputs.push_back(LayoutAlternatedExpr(inode));
      return arg;
    }
  };

  for (auto new_arg : new_args) {
    // NOTE: nested tuples are not supported
    if (new_arg->is_type<TupleNode>()) {
      Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
      std::vector<Expr> fields;
      for (auto x : tuple_new_arg->fields) {
        Expr tmp = push_back_one_arg(x);
        fields.push_back(tmp);
      }
      normal_new_args.push_back(TupleNode::make(fields));
    } else {
      Expr tmp = push_back_one_arg(new_arg);
      normal_new_args.push_back(tmp);
    }
  }
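  // After flattening, `inputs` holds one LayoutAlternatedExpr per tensor argument,
  // while `normal_new_args` mirrors ref_call->args with the unwrapped plain expressions.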

  // old_in, new_in = state[inputs]
  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
  for (auto inp : inputs) {
    old_in.push_back(inp->old_layout);
    new_in.push_back(inp->new_layout);
  }

  for (auto arg : ref_call->args) {
    if (arg->is_type<TupleNode>()) {  // flatten tuple
      Tuple tuple_arg = Downcast<Tuple>(arg);
      for (auto x : tuple_arg->fields) {
        input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
      }
    } else {
      input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
    }
  }

  // old_in, old_out = op.infer(old_in)
  bool success = false;
  std::tie(old_in, old_out, success) = CallInfer(ref_call,
                                                 Array<Layout>(nullptr),
                                                 old_in, input_shapes);
  if (!success) { return Expr(nullptr); }
  CHECK_EQ(old_in.size(), new_in.size());

  // if new_in == 'undef':  new_in = old_in
  for (size_t i = 0; i < new_in.size(); ++i) {
    if (!new_in[i].defined()) {
      new_in.Set(i, old_in[i]);
    }
  }

  // new_op = alter(op)
  Call new_call = CallAlter(ref_call, normal_new_args);

  // new_in2, new_out = op.infer(new_in)
  if (new_call->op->is_type<OpNode>()) {
    success = false;
    std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes);
    if (!success) { return Expr(nullptr); }
  } else {
    return Expr(nullptr);
  }

  CHECK_EQ(new_out.size(), old_out.size())
    << "The number of output nodes must remain the same during alter_op_layout";
  CHECK_EQ(new_in.size(), new_in2.size())
    << "The number of input nodes must remain the same during alter_op_layout";

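  // Here new_in[i] is the layout in which input i is currently produced, while
  // new_in2[i] is the layout the (possibly altered) op expects for that input; any
  // mismatch is bridged with a memoized layout_transform.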
  // if (new_in != new_in2): insert transform (new_in -> new_in2)
  Array<Expr> transformed_args;
  size_t pt = 0;
  for (auto arg : new_call->args) {
    if (arg->is_type<TupleNode>()) {  // unflatten tuple
      Tuple tuple_arg = Downcast<Tuple>(arg);
      std::vector<Expr> transformed_tuple_arg;
      for (auto arg_item : tuple_arg->fields) {
          transformed_tuple_arg.push_back(
                  memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
          pt++;
      }
      transformed_args.push_back(TupleNode::make(transformed_tuple_arg));
    } else {
      transformed_args.push_back(
              memorizer.Transform(arg, new_in[pt], new_in2[pt]));
      pt++;
    }
  }
  CHECK_EQ(pt, inputs.size());

  // state[node] = (old_out, new_out)
  // (handle tuple output)
  if (ref_call->checked_type()->is_type<TupleTypeNode>()) {
    Expr tuple_output = CallNode::make(new_call->op, transformed_args,
                                       new_call->attrs);
    Array<Expr> fields;
    for (size_t i = 0; i < new_out.size(); ++i) {
      auto rnode = make_node<LayoutAlternatedExprNode>();
      rnode->value = TupleGetItemNode::make(tuple_output, i);
      rnode->old_layout = old_out[i];
      rnode->new_layout = new_out[i];
      rnode->memorizer = memorizer;
      fields.push_back(Expr(rnode));
    }
    return TupleNode::make(fields);
  } else {
    auto rnode = make_node<LayoutAlternatedExprNode>();
    CHECK_EQ(new_out.size(), 1);
    rnode->value = CallNode::make(new_call->op, transformed_args,
                                  new_call->attrs);
    rnode->old_layout = old_out[0];
    rnode->new_layout = new_out[0];
    rnode->memorizer = memorizer;
    return Expr(rnode);
  }
}

// Limitations:
// 1. The altered op must have the same number of arguments as the original one.
// 2. Nested tuple arguments are not supported.
TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  TransformMemorizer transform_memorizer(make_node<TransformMemorizerNode>());
  auto fcontext = [&](const Call& call) -> NodeRef {
    return transform_memorizer;
  };

  *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext);
});
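
// On the Python side this pass is typically exposed through the relay.ir_pass frontend
// (the exact wrapper name is assumed here), e.g.:
//   new_expr = relay.ir_pass.alter_op_layout(expr)
// Schematically, assuming conv2d registers an NCHW -> NCHW16c alteration, a program
//   conv2d(%data, %weight, data_layout="NCHW")
// may be rewritten into
//   %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NCHW16c")
//   %1 = conv2d(%0, %packed_weight, data_layout="NCHW16c")
//   layout_transform(%1, src_layout="NCHW16c", dst_layout="NCHW")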

}  // namespace alter_op_layout

}  // namespace relay
}  // namespace tvm