/*!
 *  Copyright (c) 2017 by Contributors
 * \file verilog_ir.cc
 */
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./verilog_ir.h"
#include "../../arithmetic/compute_expr.h"

namespace tvm {
namespace codegen {
namespace verilog {

using namespace ir;

/*!
 * \brief Construct a control signal node.
 * \param type The kind of control signal (stored in ctrl_type).
 * \param advance_size The advance amount carried by the signal.
 * \return A ControlSignal handle wrapping the new node.
 */
ControlSignal ControlSignalNode::make(
    ControlSignalType type, int advance_size) {
  std::shared_ptr<ControlSignalNode> node =
      std::make_shared<ControlSignalNode>();
  node->advance_size = advance_size;
  node->ctrl_type = type;
  return ControlSignal(node);
}

/*!
 * \brief Construct a stage input descriptor.
 * \param var The variable that feeds the stage.
 * \param input_type How the variable enters the stage (e.g. kChannel, kLoopVar).
 * \return A StageInput handle wrapping the new node.
 */
StageInput StageInputNode::make(Var var, StageInputType input_type) {
  auto node = std::make_shared<StageInputNode>();
  node->input_type = input_type;
  node->var = var;
  return StageInput(node);
}

// Replace stage inputs by placeholder, update the input map.
class StageInputReplacer : public IRMutator {
 public:
  explicit StageInputReplacer(
      const std::unordered_map<const Variable*, StageInput>& var_info)
      : var_info_(var_info) {}

  Expr Mutate_(const Variable* op, const Expr& e) final {
    if (replace_.count(op)) {
      return replace_.at(op);
    }
    auto it = var_info_.find(op);
    if (it == var_info_.end()) return e;
    Var new_var(it->second->var->name_hint + ".sync", op->type);
    inputs_.Set(new_var, it->second);
    replace_[op] = new_var;
    return new_var;
  }
  Expr Mutate_(const Load* op, const Expr& e) final {
    CHECK(is_zero(op->index))
        << "Load should be in its own stage.";
    if (replace_.count(op->buffer_var.get())) {
      return replace_.at(op->buffer_var.get());
    }
    auto it = var_info_.find(op->buffer_var.get());
    CHECK(it != var_info_.end())
        << "Load from unknown channel";
    Var data(it->second->var->name_hint + ".load.sync", op->type);
    inputs_.Set(data, it->second);
    replace_[op->buffer_var.get()] = data;
    return data;
  }
  // inputs that get replaced.
  Map<Var, StageInput> inputs_;
  // replacement map
  std::unordered_map<const Variable*, Var> replace_;
  // Variable replacement plan.
  const std::unordered_map<const Variable*, StageInput>& var_info_;
};

/*!
 * \brief Extract the pipeline (module block) structure of a lowered function.
 *
 *  The body of the function is visited and a Pipeline is built with:
 *   - one ComputeBlock per Store. Each block holds the enclosing loop nest,
 *     the replaced stage inputs, and the signal triggers it fires.
 *   - one ChannelBlock per channel. A channel is either opened by a
 *     channel_read_scope/channel_write_scope attribute or is a handle-typed
 *     function argument that is accessed directly.
 */
class PipelineExtractor: public IRVisitor {
 public:
  /*!
   * \brief Run the extraction on a lowered function.
   * \param f The lowered function to analyze.
   * \return The extracted pipeline. Its args are copied from f->args.
   */
  Pipeline Extract(LoweredFunc f) {
    // Handle-typed arguments are treated as memory-mapped channels.
    // TODO(tqchen) move the logic to explicit specification.
    for (const auto& arg : f->args) {
      if (arg.type().is_handle()) {
        arg_handle_[arg.get()] = arg;
      }
    }
    pipeline_ = std::make_shared<PipelineNode>();
    this->Visit(f->body);
    // Publish every channel that was discovered during the walk.
    for (const auto& kv : cmap_) {
      pipeline_->channels.Set(
          kv.second.node->channel->handle_var,
          ChannelBlock(kv.second.node));
    }
    pipeline_->args = f->args;
    return Pipeline(pipeline_);
  }

  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == attr::pipeline_stage_scope) {
      CHECK(!in_pipeline_stage_)
          << "Nested pipeline stages are not supported";
      in_pipeline_stage_ = true;
      // The loop depth is recorded so that stores inside can build the
      // trigger predicate from the loops entered after this point.
      trigger_.emplace_back(loop_.size(), op);
      IRVisitor::Visit_(op);
      trigger_.pop_back();
      in_pipeline_stage_ = false;
    } else if (op->attr_key == attr::channel_read_advance ||
               op->attr_key == attr::channel_write_advance) {
      trigger_.emplace_back(loop_.size(), op);
      IRVisitor::Visit_(op);
      trigger_.pop_back();
    } else if (op->attr_key == attr::channel_read_scope ||
               op->attr_key == attr::channel_write_scope) {
      Channel ch(op->node.node_);
      ChannelEntry& cb = cmap_[ch->handle_var.get()];
      if (cb.node != nullptr) {
        CHECK(cb.node->channel.same_as(ch))
            << "Channel handle is bound to a different channel";
      } else {
        cb.node = std::make_shared<ChannelBlockNode>();
        cb.node->channel = ch;
      }
      if (op->attr_key == attr::channel_read_scope) {
        CHECK_EQ(cb.read_ref_count, 0)
            << "One channel can only be read from one consumer";
        ++cb.read_ref_count;
        CHECK(arith::GetConstInt(op->value, &(cb.node->read_window)))
              << "Only support constant read window";
      } else {
        CHECK_EQ(cb.write_ref_count, 0)
            << "One channel can only be written by one producer";
        ++cb.write_ref_count;
        CHECK(arith::GetConstInt(op->value, &(cb.node->write_window)))
              << "Only support constant write window";
      }
      // Inside the scope, the channel handle is a stage input.
      var_info_[ch->handle_var.get()] =
          StageInputNode::make(ch->handle_var, kChannel);
      IRVisitor::Visit_(op);
      var_info_.erase(ch->handle_var.get());
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const Block* op) final {
    CHECK(!in_pipeline_stage_)
        << "Do not support serial execution inside pipeline";
    IRVisitor::Visit_(op);
  }
  void Visit_(const IfThenElse* op) final {
    LOG(FATAL) << "Not implemented";
  }
  void Visit_(const For* op) final {
    if (in_pipeline_stage_) {
      // Keep only the loop header. The body is replaced by a no-op.
      loop_.push_back(
          For::make(op->loop_var, op->min, op->extent,
                    op->for_type, op->device_api, Evaluate::make(0)));
      var_info_[op->loop_var.get()] =
          StageInputNode::make(Var(op->loop_var.node_), kLoopVar);
      IRVisitor::Visit_(op);
      var_info_.erase(op->loop_var.get());
      loop_.pop_back();
    } else {
      IRVisitor::Visit_(op);
    }
  }
  // Each Store becomes one ComputeBlock of the pipeline.
  void Visit_(const Store* op) final {
    // Check the access pattern
    Channel arg_write =
        CheckArgHandleAccess(op->buffer_var.get(), op->value.type(), false);
    this->Visit(op->value);
    // The replace logic
    StageInputReplacer repl(var_info_);
    // Setup the compute block.
    std::shared_ptr<ComputeBlockNode> compute =
        std::make_shared<ComputeBlockNode>();
    compute->loop = Array<Stmt>(loop_);
    // setup the advance triggers
    for (const auto& e : trigger_) {
      const AttrStmt* attr = e.second;
      Channel ch;
      if (attr->attr_key == attr::pipeline_stage_scope) {
        // A stage only signals completion when it writes to an argument handle.
        ch = arg_write;
        if (!ch.defined()) continue;
      } else {
        ch = Channel(attr->node.node_);
      }
      std::shared_ptr<SignalTriggerNode> trigger
          = std::make_shared<SignalTriggerNode>();
      trigger->channel_var = ch->handle_var;
      // The trigger fires on the last iteration of every loop entered after
      // the triggering attribute. The last value of a loop is min + extent - 1.
      Expr predicate = const_true();
      for (size_t i = e.first; i < loop_.size(); ++i) {
        const For* loop = loop_[i].as<For>();
        predicate = predicate &&
            (loop->loop_var == (loop->min + loop->extent - 1));
      }
      trigger->predicate = ir::Simplify(predicate);
      // Add the signal back to the channels.
      auto cit = cmap_.find(ch->handle_var.get());
      CHECK(cit != cmap_.end())
          << "Signal trigger on a channel without read/write scope";
      ChannelEntry& cb = cit->second;
      trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
      // Grab the advance constant size.
      int trigger_size = 0;
      if (attr->attr_key == attr::pipeline_stage_scope) {
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kComputeFinish, 0));
      } else if (attr->attr_key == attr::channel_read_advance) {
        CHECK(arith::GetConstInt(attr->value, &trigger_size))
            << "Only support constant advance size";
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kReadAdvance, trigger_size));
      } else {
        CHECK(arith::GetConstInt(attr->value, &trigger_size))
            << "Only support constant advance size";
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kWriteAdvance, trigger_size));
      }
      compute->triggers.push_back(SignalTrigger(trigger));
    }

    // Check if we are writing to FIFO.
    const Load* load = op->value.as<Load>();
    if (is_zero(op->index) && load) {
      // Keep the Load intact, with its own predicate, and only substitute
      // stage inputs in its index.
      compute->body = Store::make(
          op->buffer_var,
          Load::make(load->type, load->buffer_var,
                     repl.Mutate(load->index), load->predicate),
          op->index, op->predicate);
    } else {
      compute->body = Store::make(
          op->buffer_var, repl.Mutate(op->value),
          repl.Mutate(op->index), op->predicate);
    }
    compute->inputs = repl.inputs_;
    pipeline_->stages.push_back(ComputeBlock(compute));
  }
  void Visit_(const LetStmt* op) final {
    LOG(FATAL) << "cannot pass through let";
  }
  void Visit_(const Evaluate* op) final {
    LOG(FATAL) << "Not implemented";
  }
  void Visit_(const Allocate* op) final {
    // NOTE(review): the allocation body is not visited, so any stores inside
    // it are not extracted. Confirm this is intended.
    CHECK(!in_pipeline_stage_);
  }
  void Visit_(const AssertStmt* op) final {
    LOG(FATAL) << "Not implemented";
  }
  void Visit_(const Load* op) final {
    CheckArgHandleAccess(op->buffer_var.get(), op->type, true);
  }
  /*!
   * \brief Register a direct access to a handle-typed function argument as a
   *  channel.
   * \param var The buffer variable being accessed.
   * \param dtype The element type of the access.
   * \param read_access Whether the access is a read. Currently unused.
   * \return The new channel, or an undefined Channel if var is not an
   *  argument handle.
   */
  Channel CheckArgHandleAccess(const Variable* var, Type dtype, bool read_access) {
    if (!arg_handle_.count(var)) return Channel();
    CHECK(!cmap_.count(var))
        << "Multiple access to the same handle";
    ChannelEntry& cb = cmap_[var];
    cb.node = std::make_shared<ChannelBlockNode>();
    cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
    return cb.node->channel;
  }

 private:
  // The channel information.
  struct ChannelEntry {
    std::shared_ptr<ChannelBlockNode> node;
    int read_ref_count{0};
    int write_ref_count{0};
  };
  // Whether we are inside the pipeline stage.
  bool in_pipeline_stage_{false};
  // The current loop nest (headers only, bodies replaced by no-ops).
  std::vector<Stmt> loop_;
  // Active trigger attributes, each paired with the loop depth at entry.
  std::vector<std::pair<size_t, const AttrStmt*> > trigger_;
  // Variables currently in scope that act as stage inputs.
  std::unordered_map<const Variable*, StageInput> var_info_;
  // The channel entries, keyed by handle variable.
  std::unordered_map<const Variable*, ChannelEntry> cmap_;
  // The argument handle map
  std::unordered_map<const Variable*, Var> arg_handle_;
  // The result block.
  std::shared_ptr<PipelineNode> pipeline_;
};

/*!
 * \brief Build the Pipeline description of a lowered function.
 * \param f The lowered function to analyze.
 * \return The pipeline produced by PipelineExtractor.
 */
Pipeline MakePipeline(LoweredFunc f) {
  PipelineExtractor extractor;
  Pipeline result = extractor.Extract(f);
  return result;
}
}  // namespace verilog
}  // namespace codegen
}  // namespace tvm