/*!
 *  Copyright (c) 2017 by Contributors
 * \file narrow_channel_access.cc
 * \brief Narrow channel access to a smaller range
 *  when possible by bringing it to the internal loop.
 */
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/channel.h>
#include "./ir_util.h"

namespace tvm {
namespace ir {
using namespace arith;

// Bound deducer for channel access.
class ChannelAccessBound : public IRVisitor {
 public:
  ChannelAccessBound(const Variable* buf_var, bool read_access)
      : buf_var_(buf_var), read_access_(read_access) {}

  void Visit_(const Store* op) final {
    if (!read_access_ && buf_var_ == op->buffer_var.get()) {
      ret_.emplace_back(EvalSet(op->index, dom_map_));
    }
    IRVisitor::Visit_(op);
  }
  void Visit_(const For* op) final {
    CHECK(is_zero(op->min));
    // We know that the extent of the loop won't depend on relaxed scope.
    // TODO(tqchen) have a verification pass.
    dom_map_[op->loop_var.get()] = IntSet::interval(op->min, op->extent - 1);
    IRVisitor::Visit_(op);
  }
  void Visit_(const Load* op) final {
    if (read_access_ && buf_var_ == op->buffer_var.get()) {
      ret_.emplace_back(EvalSet(op->index, dom_map_));
    }
    IRVisitor::Visit_(op);
  }
  void Visit_(const Let* op) final {
    LOG(FATAL) << "cannot pass through let";
  }
  void Visit_(const LetStmt* op) final {
    LOG(FATAL) << "cannot pass through let";
  }
  IntSet Eval(const Stmt& stmt) {
    Visit(stmt);
    return Union(ret_);
  }

 private:
  // The buffer variable.
  const Variable* buf_var_;
  // read or write
  bool read_access_{true};
  // Box
  std::vector<IntSet> ret_;
  // Domain map.
  std::unordered_map<const Variable*, IntSet> dom_map_;
};

// Mutator that subtracts a fixed offset from the index of every load
// (read mode) or every store (write mode) that targets one buffer variable.
// Accesses to other buffers, and accesses of the other kind, are returned
// as produced by the default mutation.
class ChannelAccessIndexRewriter : public IRMutator {
 public:
  ChannelAccessIndexRewriter(const Variable* buf_var,
                             Expr min,
                             bool read_access)
      : buf_var_(buf_var), min_(min), read_access_(read_access) {}

  Expr Mutate_(const Load* op, const Expr& e) final {
    // Mutate the children first, then inspect the resulting load.
    Expr mutated = IRMutator::Mutate_(op, e);
    const Load* load = mutated.as<Load>();
    if (!read_access_ || load->buffer_var.get() != buf_var_) {
      return mutated;
    }
    return Load::make(
        load->type, load->buffer_var, Shift(load->index), load->predicate);
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    // Mutate the children first, then inspect the resulting store.
    Stmt mutated = IRMutator::Mutate_(op, s);
    const Store* store = mutated.as<Store>();
    if (read_access_ || store->buffer_var.get() != buf_var_) {
      return mutated;
    }
    return Store::make(
        store->buffer_var, store->value, Shift(store->index),
        store->predicate);
  }

 private:
  // Simplified form of (index - min_).
  Expr Shift(const Expr& index) const {
    return ir::Simplify(index - min_);
  }

  // Buffer variable whose accesses are rewritten.
  const Variable* buf_var_;
  // Offset subtracted from every matching access index.
  Expr min_;
  // true: rewrite loads, false: rewrite stores.
  bool read_access_{true};
};


// Rewrite channel access pattern.
//
// Looks for a channel window attribute (channel_read_scope or
// channel_write_scope) whose body is directly the matching advance
// attribute (channel_read_advance / channel_write_advance). Each such pair
// is recorded as a pending task and offered to the first For loop reached
// below it. When the loop accepts the task, a new scope/advance pair is
// created around the loop body and the original pair is dropped.
class ChannelAccessRewriter : public IRMutator {
 public:
  // Detect a scope attribute directly wrapping its advance attribute,
  // register it as a task while the subtree is mutated, and strip the two
  // original attributes if some loop below reported a successful rewrite.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    Stmt ret;
    // Null when the body is not an AttrStmt; checked in the condition below.
    const AttrStmt* adv = op->body.as<AttrStmt>();
    if ((op->attr_key == ir::attr::channel_read_scope &&
         adv && adv->attr_key == ir::attr::channel_read_advance) ||
        (op->attr_key == ir::attr::channel_write_scope &&
         adv && adv->attr_key == ir::attr::channel_write_advance)) {
      RewriteEntry e;
      e.window = op;
      e.advance = adv;
      e.read_access = op->attr_key == ir::attr::channel_read_scope;
      tasks_.push_back(e);
      ret = IRMutator::Mutate_(op, s);
      // The For handler sets rewrite_success on the entry pushed above.
      if (tasks_.back().rewrite_success) {
        // Skip both original attribute statements and keep only what is
        // inside the advance attribute.
        ret = ret.as<AttrStmt>()->body.as<AttrStmt>()->body;
      }
      tasks_.pop_back();
      return ret;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  // Try to apply every pending task to this loop. The pending list is
  // swapped out while the loop is processed, so tasks_ is empty during any
  // nested mutation, and swapped back before returning.
  Stmt Mutate_(const For* op, const Stmt& s) final {
    std::vector<RewriteEntry> tasks;
    // Take ownership of the pending tasks; tasks_ is now empty.
    std::swap(tasks_, tasks);
    Stmt body = op->body;
    // Statements collected by RewriteAccess and merged with the rebuilt
    // loop through MergeNest below.
    std::vector<Stmt> nest;
    for (RewriteEntry& e : tasks) {
      body = RewriteAccess(op, body, &e, &nest);
    }

    if (!body.same_as(op->body)) {
      // At least one task was applied: mutate the new body (which now
      // contains freshly created scope/advance attributes), rebuild the
      // loop around it and merge in the collected nest statements.
      body = Mutate(body);
      body = For::make(
          op->loop_var, op->min, op->extent,
          op->for_type, op->device_api, body);
      body = MergeNest(nest, body);
    } else {
      // Nothing was rewritten, so no nest statement may have been produced.
      CHECK_EQ(nest.size(), 0U);
      body = IRMutator::Mutate_(op, s);
    }
    // Put the tasks back so the AttrStmt handler can read rewrite_success.
    std::swap(tasks_, tasks);
    return body;
  }

 private:
  // One pending scope/advance pair found by the AttrStmt handler.
  struct RewriteEntry {
    // true for a read scope, false for a write scope.
    bool read_access;
    // The outer scope attribute; its value is used as the window extent.
    const AttrStmt* window;
    // The advance attribute directly inside the scope attribute.
    const AttrStmt* advance;
    // Set to true by RewriteAccess once a loop has taken over the task.
    bool rewrite_success{false};
  };

  // Try to narrow the channel access described by e to the body of for_op.
  //
  // for_op: the loop being processed.
  // body: the current (possibly already rewritten) loop body.
  // e: the task; rewrite_success is set to true on success.
  // outer_nest: receives an extra advance attribute when the amount
  //   computed as `left` below is not zero.
  //
  // Returns body unchanged when any of the checks fails, otherwise the
  // rewritten body wrapped in a new scope/advance attribute pair.
  Stmt RewriteAccess(const For* for_op,
                     Stmt body,
                     RewriteEntry* e,
                     std::vector<Stmt>* outer_nest) {
    const AttrStmt* adv_op = e->advance;
    const Expr& window = e->window->value;
    bool read_access = e->read_access;
    Var var(for_op->loop_var);
    Channel ch(adv_op->node.node_);
    // Bound the indices accessed through the channel handle. Note that the
    // original loop body is analyzed here, not the `body` argument.
    ChannelAccessBound acc(ch->handle_var.get(), read_access);
    IntSet iset = acc.Eval(for_op->body);
    Range r = iset.cover_range(Range::make_by_min_extent(0, window));
    r = Range::make_by_min_extent(
        ir::Simplify(r->min), ir::Simplify(r->extent));
    // Give up if the extent of the accessed range depends on the loop var.
    if (ExprUseVar(r->extent, var)) return body;
    // The range min must be linear in the loop variable with a zero base,
    // i.e. of the form coeff * var.
    Array<Expr> linear_eq = DetectLinearEquation(r->min, {var});
    if (linear_eq.size() == 0) return body;
    Expr coeff = linear_eq[0];
    Expr base = linear_eq[1];
    if (!is_zero(base)) return body;
    // Original advance value minus coeff * loop extent; it has to be
    // provably non-negative for the rewrite to go ahead.
    Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
    if (!can_prove(left >= 0)) return body;
    // Rewrite the access indices by subtracting var * coeff from them.
    ChannelAccessIndexRewriter rw(
        ch->handle_var.get(), var * coeff, read_access);
    body = rw.Mutate(body);

    // Wrap the body in a new scope attribute of extent r->extent and a new
    // advance attribute of value coeff.
    if (read_access) {
      body = AttrStmt::make(
          ch, ir::attr::channel_read_scope, r->extent,
          AttrStmt::make(ch, ir::attr::channel_read_advance, coeff,
                         body));
    } else {
      body = AttrStmt::make(
          ch, ir::attr::channel_write_scope, r->extent,
          AttrStmt::make(ch, ir::attr::channel_write_advance, coeff,
                         body));
    }

    // A non-zero remainder is emitted as a separate advance attribute with a
    // placeholder body; the caller combines it with the loop via MergeNest.
    // NOTE(review): presumably MergeNest substitutes the loop for the
    // placeholder body -- confirm against ir_util.h.
    if (!is_zero(left)) {
      Stmt no_op = Evaluate::make(0);
      if (read_access) {
        outer_nest->emplace_back(
            AttrStmt::make(ch, ir::attr::channel_read_advance, left, no_op));
      } else {
        outer_nest->emplace_back(
            AttrStmt::make(ch, ir::attr::channel_write_advance, left, no_op));
      }
    }

    e->rewrite_success = true;
    return body;
  }

  // Scope/advance pairs seen above the current position that have not yet
  // been offered to a loop.
  std::vector<RewriteEntry> tasks_;
};

// Pass entry point: run a fresh ChannelAccessRewriter over stmt and return
// the mutated statement.
Stmt NarrowChannelAccess(Stmt stmt) {
  ChannelAccessRewriter rewriter;
  return rewriter.Mutate(stmt);
}

}  // namespace ir
}  // namespace tvm