/*!
 *  Copyright (c) 2018 by Contributors
 * \file remap_thread_axis.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_map>


namespace tvm {
namespace ir {

// Mutator that swaps thread-axis IterVars according to a map keyed by
// thread tag.
//
// For every attr::thread_extent AttrStmt whose IterVar's thread_tag has an
// entry in the tag map, the statement is rebuilt around the mapped IterVar,
// and every Variable that refers to the old IterVar's var is replaced by
// the mapped IterVar's var.
class ThreadAxisRewriter : private IRMutator {
 public:
  // tmap: thread tag -> replacement IterVar. It is stored by reference,
  // so the caller's map must outlive this rewriter.
  explicit ThreadAxisRewriter(
      const std::unordered_map<std::string, IterVar>& tmap)
      : tmap_(tmap) {
  }

  // Public entry point: runs the mutator over stmt and returns the result.
  Stmt Rewrite(Stmt stmt) {
    return this->Mutate(stmt);
  }

 private:
  // Rewrites thread_extent attributes whose tag is present in tmap_;
  // any other AttrStmt goes through the default IRMutator handling.
  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key != attr::thread_extent) {
      return IRMutator::Mutate_(op, stmt);
    }

    IterVar iv(op->node.node_);
    CHECK_NE(iv->thread_tag.length(), 0U);

    const auto target = tmap_.find(iv->thread_tag);
    if (target == tmap_.end()) {
      // No replacement registered for this tag.
      return IRMutator::Mutate_(op, stmt);
    }

    const IterVar& new_iv = target->second;
    const Variable* v = iv->var.get();
    if (vmap_.count(v)) {
      // This variable was already remapped; it must map to the same target.
      CHECK(vmap_[v].same_as(new_iv->var));
    } else {
      vmap_[v] = new_iv->var;
    }

    // The substitution is recorded before the body is visited, so uses of
    // the old variable inside the body are replaced.
    Stmt new_body = this->Mutate(op->body);
    return AttrStmt::make(new_iv, op->attr_key, op->value, new_body);
  }

  // Substitutes a variable that has a recorded replacement; otherwise
  // falls back to the default IRMutator handling.
  Expr Mutate_(const Variable* op, const Expr& expr) final {
    const auto replacement = vmap_.find(op);
    if (replacement == vmap_.end()) {
      return IRMutator::Mutate_(op, expr);
    }
    return replacement->second;
  }

  // Thread tag -> replacement IterVar (not owned).
  const std::unordered_map<std::string, IterVar>& tmap_;
  // Old thread variable -> replacement Var, filled in as thread_extent
  // attributes are visited.
  std::unordered_map<const Variable*, Var> vmap_;
};

// Builds a new LoweredFunc from f in which thread axes are replaced
// according to thread_map.
//
// f:          must have func_type == kDeviceFunc (checked).
// thread_map: every key must be a StringImm (checked); its string value is
//             matched against thread_tag, and the mapped IterVar is the
//             replacement.
//
// Returns a LoweredFunc wrapping a copy-constructed node whose matching
// thread_axis entries and body have been rewritten.
LoweredFunc
RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
  // Re-key the map by the plain tag string for direct lookup.
  std::unordered_map<std::string, IterVar> tmap;
  for (const auto& entry : thread_map) {
    const StringImm* str = entry.first.as<StringImm>();
    CHECK(str != nullptr);
    tmap[str->value] = entry.second;
  }

  CHECK_EQ(f->func_type, kDeviceFunc);
  auto node = std::make_shared<LoweredFuncNode>(*f.operator->());

  // Swap each declared thread axis whose tag has a replacement.
  const size_t num_axes = node->thread_axis.size();
  for (size_t idx = 0; idx != num_axes; ++idx) {
    const auto found = tmap.find(node->thread_axis[idx]->thread_tag);
    if (found == tmap.end()) continue;
    node->thread_axis.Set(idx, found->second);
  }

  // Apply the same replacement to the function body.
  ThreadAxisRewriter rewriter(tmap);
  node->body = rewriter.Rewrite(node->body);
  return LoweredFunc(node);
}

}  // namespace ir
}  // namespace tvm