/*!
 *  Copyright (c) 2016 by Contributors
 * \file buffer.cc
 */
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include <list>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {

// Returns `array` with ir::Simplify applied to every element, keeping the
// element order. The argument is taken by value and updated in place via Set().
Array<Expr> SimplifyArray(Array<Expr> array) {
  const size_t count = array.size();
  for (size_t idx = 0; idx != count; ++idx) {
    Expr simplified = ir::Simplify(array[idx]);
    array.Set(idx, simplified);
  }
  return array;
}

// Declares a fresh buffer of the given shape and element type. The data
// pointer is a new handle Var that shares the buffer's name. Strides are left
// empty and every remaining attribute is passed as "unset" (empty scope, zero
// alignment/offset factor, undefined elem_offset), so BufferNode::make
// substitutes its defaults.
Buffer decl_buffer(Array<Expr> shape,
                   Type dtype,
                   std::string name) {
  Var data(name, Handle());
  return BufferNode::make(data,
                          dtype,
                          shape,
                          /*strides=*/Array<Expr>(),
                          /*elem_offset=*/Expr(),
                          name,
                          /*scope=*/"",
                          /*data_alignment=*/0,
                          /*offset_factor=*/0);
}

// Split the given expression w.r.t the add operator
inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
  using namespace ir;
  std::vector<const Expr*> ret;
  std::stack<const Expr*> split_buffer;
  split_buffer.push(&expr);
  while (!split_buffer.empty()) {
    const Expr* top_ele = split_buffer.top();
    split_buffer.pop();
    auto expr_add_match = top_ele->as<Add>();
    if (expr_add_match) {
      split_buffer.push(&expr_add_match->b);
      split_buffer.push(&expr_add_match->a);
    } else {
      ret.emplace_back(top_ele);
    }
  }
  return ret;
}


// Tries to merge one Mul summand with one Mod summand of a sum. It matches
//   mult_expr  = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1) * kt * ... * ki
//   mod_l_expr = c
//   mod_r_expr = k1 * k2 * ... * ki
// and, relying on the identity c / K * K + c % K == c (K = k1 * ... * ki),
// rewrites mult_expr + mod_l_expr % mod_r_expr as
//   (a1 + a2 + ... + aj) * kt * ... * ki + c
// Returns (true, rewritten_expr) on a match and (false, Expr()) otherwise.
// Currently we do not search the add/mult combinations exhaustively, as that
// would take too much computation:
//   - only the rhs of each Add is followed;
//   - an Add found below an inner Mul aborts the match;
//   - the multipliers are compared with Equal() exactly as they are re-built
//     below, so the factor order/association matters.
inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
                                              const Expr &mod_l_expr,
                                              const Expr &mod_r_expr) {
  using namespace ir;
  const Mul* mult_ptr = mult_expr.as<Mul>();
  // The candidate term must itself be a product.
  if (!mult_ptr) return std::make_pair(false, Expr());
  Expr mult_outer = mult_ptr->b;
  const Expr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier: walk down the lhs chain of Mul nodes and
  //    prepend every rhs factor to mult_outer (this is kt * ... * ki).
  while (true) {
    mult_ptr = inner->as<Mul>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  //   If Add is found, we need to continue our search for the rhs
  //   If Mult is found, we will expand the inner multiplication factor
  //   If Div is found, we will go on testing whether lhs matches the lhs of mod expr
  //      and returns the optimization result.
  const Expr* search_ptr = inner;
  Expr mult_inner;  // The inner multiplication factor (k1 * ... * kt-1)
  Expr no_opt_sum;  // Sum of the exprs that cannot be optimized (a1 + ... + aj)
  while (true) {
    auto inner_div_ptr = search_ptr->as<Div>();
    auto inner_mult_ptr = search_ptr->as<Mul>();
    auto inner_add_ptr = search_ptr->as<Add>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      // Any other node type cannot be part of the pattern.
      return std::make_pair(false, Expr());
    } else if (inner_div_ptr) {
      // The full product k1 * ... * ki; it must equal both the divisor and
      // the Mod's rhs, and the dividend must equal the Mod's lhs.
      Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (Equal(overall_mult, inner_div_ptr->b)
          && Equal(overall_mult, mod_r_expr)
          && Equal(inner_div_ptr->a, mod_l_expr)) {
        // Found! The c / K * K part cancels against c % K, leaving c.
        Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, Expr());
      }
    } else if (inner_mult_ptr) {
      // Prepend this rhs factor to the inner multiplier and keep walking lhs.
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      // An Add below an inner Mul is not handled by this matcher.
      if (mult_inner.get()) {
        return std::make_pair(false, Expr());
      }
      // Keep the lhs as an unoptimizable summand; continue with the rhs.
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, Expr());
}

// Distributes the summands in `eles` into the buckets used by MergeMulMod:
//   - a Mod term (lhs % rhs) is appended to *mod_exprs as the pair (lhs, rhs);
//   - a Mul term is appended to *mult_exprs;
//   - any other term is added into *no_opt_sum (it becomes the sum if
//     *no_opt_sum is still undefined).
// Both flags are reset on entry. On return, *has_mult / *has_mod tell whether
// this call appended at least one Mul / Mod term.
inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
                                      std::list<Expr>* mult_exprs,
                                      std::list<std::pair<Expr, Expr> >* mod_exprs,
                                      Expr* no_opt_sum,
                                      bool* has_mult,
                                      bool* has_mod) {
  using namespace ir;
  *has_mult = false;
  *has_mod = false;
  for (const Expr* ele : eles) {
    if (const Mod* mod_ptr = ele->as<Mod>()) {
      *has_mod = true;
      // mod_ptr points to a const node, so its operands can only be copied;
      // wrapping them in std::move would still copy and merely mislead.
      mod_exprs->emplace_back(mod_ptr->a, mod_ptr->b);
    } else if (ele->as<Mul>()) {
      *has_mult = true;
      mult_exprs->emplace_back(*ele);
    } else {
      *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele;
    }
  }
}

// Searches the summands of `base` for pairs of the form
//   (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   + c % (k1 * k2 * ... * ki)
// and simplifies each such pair to (a1 + a2 + ... + aj) * kt * ... * ki + c.
// The search is performed repeatedly on the rewritten terms until no pattern
// is found.
// Return: Simplify(base) unchanged if nothing could be merged; otherwise the
//         re-assembled sum: plain (non-Mul, non-Mod) terms first, then the
//         remaining Mul terms, then the remaining Mod terms.
inline Expr MergeMulMod(const Expr &base) {
  using namespace ir;
  // 1. Prepare the lists.
  // We store two lists: one that contains all the elements that match Mul and
  //                     one that contains all the elements that match Mod.
  // The elements in the Mod list will be matched against the elements in Mul.
  // Each merge result is split again and pushed back into these two lists.
  Expr simplified_base = Simplify(base);
  std::vector<const Expr*> eles = ExprSplitAddition(simplified_base);
  std::list<Expr> mult_exprs;
  std::list<std::pair<Expr, Expr> > mod_exprs;
  Expr no_opt_sum;
  bool has_mult;
  bool has_mod;
  MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
                            &no_opt_sum, &has_mult, &has_mod);
  bool find_opt = false;
  std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin();
  // 2. Search: try every Mul term against the current Mod term.
  while (search_mod_it != mod_exprs.end()) {
    std::list<Expr>::iterator mult_it = mult_exprs.begin();
    bool inner_find_opt = false;
    while (mult_it != mult_exprs.end()) {
      std::pair<bool, Expr> ret = MergeMulModInner(*mult_it,
                                                   search_mod_it->first,
                                                   search_mod_it->second);
      if (ret.first) {
        inner_find_opt = true;
        // Advance before erasing so search_mod_it stays valid.
        auto temp_mod_it = search_mod_it;
        ++search_mod_it;
        mod_exprs.erase(temp_mod_it);
        mult_exprs.erase(mult_it);
        // Feed the merged result back into the buckets (appended at the ends).
        std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second);
        MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
                                  &no_opt_sum, &has_mult, &has_mod);
        if (has_mult) {
          // A new Mul term may pair with any Mod term: restart the scan.
          search_mod_it = mod_exprs.begin();
        } else if (has_mod && search_mod_it == mod_exprs.end()) {
          // We were at the end, but new Mod terms were appended: step back
          // onto the last one.
          // NOTE(review): if several Mod terms were appended here, only the
          // last one seems to be revisited - confirm this is acceptable.
          search_mod_it--;
        }
        break;
      } else {
        ++mult_it;
      }
    }
    find_opt = find_opt || inner_find_opt;
    if (!inner_find_opt) {
      ++search_mod_it;
    }
  }
  if (!find_opt) {
    return simplified_base;
  }
  // 3. Re-assemble: plain terms, then remaining Mul terms, then Mod terms.
  for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
  }
  for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
                                                   it != mod_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
  }
  return no_opt_sum;
}

// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
//
// For a stride-less (compact, row-major) buffer, the flat index
//   ((i0 * s1 + i1) * s2 + i2) ... + i(d-1)
// is built first (Horner's scheme over the shape) and only then added to
// elem_offset. Adding elem_offset before the loop would wrongly scale it by
// every inner dimension. For example, offset 5 with shape (2, 3) used to give
// (5 + i) * 3 + j instead of 5 + 3 * i + j.
// With explicit strides, the offset is elem_offset + sum(index[k] * strides[k]).
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
  Expr base = n->elem_offset;
  if (n->strides.size() == 0) {
    CHECK_EQ(n->shape.size(), index.size());
    if (index.size() == 0) {
      // 0-d buffer: only the element offset contributes.
      base = MergeMulMod(base);
    } else {
      // Row-major flat index of `index`, excluding elem_offset.
      Expr offset = MergeMulMod(index[0]);
      for (size_t i = 1; i < index.size(); ++i) {
        offset = MergeMulMod(offset * n->shape[i] + index[i]);
      }
      base = is_zero(base) ? offset : MergeMulMod(base + offset);
    }
  } else {
    CHECK_EQ(n->strides.size(), index.size());
    if (is_zero(base)) {
      base = MergeMulMod(index[0] * n->strides[0]);
    } else {
      base = MergeMulMod(base + index[0] * n->strides[0]);
    }
    for (size_t i = 1; i < index.size(); ++i) {
      base = MergeMulMod(base + index[i] * n->strides[i]);
    }
  }
  return base;
}

// Computes the Load/Store index for an access of type `dtype` at `index`.
// ElemOffset counts buffer elements. When the buffer's own element type is a
// vector (n->dtype.lanes() > 1), each element spans n->dtype.lanes() scalars,
// so that count is converted to scalar units with the *buffer's* lane count.
// Scaling by the access's lane count (dtype.lanes()) gives the wrong address
// whenever a wider-than-element access is requested, which vload/vstore allow
// (dtype.lanes() only has to be a multiple of n->dtype.lanes()).
// A vector access (dtype.lanes() > 1) becomes a unit-stride Ramp starting at
// that scalar offset; a scalar access returns the offset directly.
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
  Expr offset = ElemOffset(n, index);
  if (n->dtype.lanes() != 1) {
    offset = offset * make_const(offset.type(), n->dtype.lanes());
  }
  if (dtype.lanes() != 1) {
    return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
  }
  return offset;
}

// Builds a Load of type `dtype` from this buffer at multi-dimensional index
// `begin`, with an all-true predicate. `dtype` must have the buffer's scalar
// type, and its lane count must be a multiple of the buffer element's lane count.
Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
  const BufferNode* n = operator->();
  CHECK(dtype.element_of() == n->dtype.element_of() &&
        dtype.lanes() % n->dtype.lanes() == 0)
      << "Cannot load " << dtype
      << " from buffer of " << n->dtype;
  Expr load_index = BufferOffset(n, begin, dtype);
  Expr all_lanes = const_true(dtype.lanes());
  return ir::Load::make(dtype, n->data, load_index, all_lanes);
}

// Builds a Store of `value` into this buffer at multi-dimensional index
// `begin`, with an all-true predicate. The value's scalar type must match the
// buffer's, and its lane count must be a multiple of the buffer element's lane
// count. The failure message now describes a store rather than a load.
Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
  const BufferNode* n = operator->();
  Type dtype = value.type();
  CHECK(dtype.element_of() == n->dtype.element_of() &&
        dtype.lanes() % n->dtype.lanes() == 0)
      << "Cannot store " << dtype
      << " to buffer of " << n->dtype;
  return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
                         const_true(dtype.lanes()));
}

// Returns a copy of this buffer with explicit row-major strides filled in.
// A buffer that already has strides, or that is 0-d, is returned unchanged.
Buffer Buffer::MakeStrideView() const {
  if ((*this)->strides.size() != 0 || (*this)->shape.size() == 0) {
    return *this;
  }
  auto node = std::make_shared<BufferNode>(*operator->());
  const size_t ndim = node->shape.size();
  // Strides from the innermost dimension outwards: 1, s[d-1], s[d-1]*s[d-2], ...
  std::vector<Expr> innermost_first;
  innermost_first.reserve(ndim);
  Expr running = make_const(node->DefaultIndexType(), 1);
  for (size_t dim = ndim; dim > 0; --dim) {
    innermost_first.push_back(running);
    running = running * node->shape[dim - 1];
  }
  // Store them outermost first, matching the dimension order of shape.
  for (auto it = innermost_first.rbegin(); it != innermost_first.rend(); ++it) {
    node->strides.push_back(*it);
  }
  return Buffer(node);
}

// Returns a view of the region starting at `begins` with size `extents`.
// The view shares the data pointer and has elem_offset set to the flat offset
// of `begins`. A stride-less buffer stays stride-less only when every dimension
// after the first non-unit extent is taken whole. Otherwise the slice is taken
// from MakeStrideView() so the view carries explicit strides.
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
  const BufferNode* self = operator->();
  begins = SimplifyArray(begins);
  Expr slice_offset = ir::Simplify(ElemOffset(self, begins));
  Array<Expr> strides = self->strides;
  if (strides.size() == 0) {
    bool past_non_unit = false;
    bool need_stride = false;
    for (size_t dim = 0; dim < extents.size(); ++dim) {
      if (past_non_unit) {
        // A whole dimension starts at 0 and spans its full shape.
        bool whole_dim = is_zero(begins[dim]) &&
            is_zero(ir::Simplify(extents[dim] - self->shape[dim]));
        if (!whole_dim) need_stride = true;
      }
      if (!is_one(extents[dim])) past_non_unit = true;
    }
    if (need_stride) {
      return MakeStrideView().MakeSlice(begins, extents);
    }
  }
  return BufferNode::make(self->data,
                          self->dtype,
                          extents,
                          strides,
                          slice_offset,
                          self->name + "_slice",
                          self->scope,
                          self->data_alignment,
                          /*offset_factor=*/0);
}

// Builds a tvm_access_ptr intrinsic call describing access to this buffer.
//   access_mask   : passed through as an Int(32) constant
//                   (presumably read/write bits - see tvm_access_ptr docs).
//   ptr_type      : type of the returned Call.
//   content_lanes : if > 1, the element type is widened to that many lanes,
//                   and the offset and extent are divided by content_lanes.
//   offset        : extra element offset added to the buffer's elem_offset.
// The extent is 1 for a 0-d buffer, strides[0] * shape[0] when strides are
// explicit, and otherwise the product of the shape.
// Fix: when content_lanes > 1, the combined offset (elem_offset + offset) is
// now rescaled. Previously only self->elem_offset was rescaled, which silently
// dropped the caller's `offset`.
// NOTE(review): extent still covers the whole buffer and is not reduced by
// `offset` - confirm whether consumers expect the remaining extent instead.
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
  const BufferNode* self = operator->();
  Expr e_dtype;
  Expr extent;
  if (self->shape.size() == 0) {
    extent = make_const(self->DefaultIndexType(), 1);
  } else if (self->strides.size() == self->shape.size()) {
    int highest_dim = 0;
    extent = arith::ComputeExpr<ir::Mul>(
        self->strides[highest_dim], self->shape[highest_dim]);
  } else {
    extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
  }
  Expr elem_offset = self->elem_offset + offset;
  if (content_lanes > 1) {
    e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
    extent = extent / make_const(self->elem_offset.type(), content_lanes);
    elem_offset = elem_offset / make_const(self->elem_offset.type(),
                                           content_lanes);
  } else {
    e_dtype = make_zero(self->dtype);
  }
  Array<Expr> acc_args{
    e_dtype, self->data, elem_offset,
        extent, make_const(Int(32), access_mask)};
  return ir::Call::make(
      ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
}

// Constructs a Buffer node from its parts, substituting defaults for unset
// attributes: an empty scope becomes "global", an undefined elem_offset becomes
// a zero constant of the default index type, a non-positive data_alignment
// becomes runtime::kAllocAlignment, and an offset_factor of 0 becomes 1.
Buffer BufferNode::make(Var data,
                        Type dtype,
                        Array<Expr> shape,
                        Array<Expr> strides,
                        Expr elem_offset,
                        std::string name,
                        std::string scope,
                        int data_alignment,
                        int offset_factor) {
  auto node = std::make_shared<BufferNode>();
  node->data = std::move(data);
  node->dtype = dtype;
  node->shape = std::move(shape);
  node->strides = std::move(strides);
  node->name = std::move(name);
  node->scope = scope.empty() ? std::string("global") : std::move(scope);
  // DefaultIndexType() is queried only after the fields above are populated.
  node->elem_offset = elem_offset.defined()
      ? std::move(elem_offset)
      : make_const(node->DefaultIndexType(), 0);
  node->data_alignment = data_alignment > 0 ? data_alignment : runtime::kAllocAlignment;
  node->offset_factor = offset_factor == 0 ? 1 : offset_factor;
  return Buffer(node);
}

// IRPrinter hook: a BufferNode is printed as "buffer(<name>, <node address>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) {
    p->stream << "buffer(" << op->name << ", " << op << ")";
});

// Registers BufferNode through TVM_REGISTER_NODE_TYPE (see the macro's
// definition for exactly what the registration enables).
TVM_REGISTER_NODE_TYPE(BufferNode);

}  // namespace tvm