// verilog_ir.cc
/*!
 *  Copyright (c) 2017 by Contributors
 * \file verilog_ir.cc
 */
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./verilog_ir.h"
#include "../../arithmetic/compute_expr.h"

namespace tvm {
namespace codegen {
namespace verilog {

using namespace ir;

/*!
 * \brief Allocate a control signal node and wrap it in a reference.
 * \param type The kind of control signal, stored in ctrl_type.
 * \param advance_size The advance size, stored in advance_size.
 * \return A ControlSignal referring to the newly created node.
 */
ControlSignal ControlSignalNode::make(
    ControlSignalType type, int advance_size) {
  std::shared_ptr<ControlSignalNode> node =
      std::make_shared<ControlSignalNode>();
  node->advance_size = advance_size;
  node->ctrl_type = type;
  return ControlSignal(node);
}

/*!
 * \brief Allocate a stage input node and wrap it in a reference.
 * \param var The variable the input refers to, stored in var.
 * \param input_type The kind of the input, stored in input_type.
 * \return A StageInput referring to the newly created node.
 */
StageInput StageInputNode::make(Var var, StageInputType input_type) {
  auto node = std::make_shared<StageInputNode>();
  node->input_type = input_type;
  node->var = var;
  return StageInput(node);
}

/*!
 * \brief Mutator that swaps stage inputs for freshly created placeholder
 *  variables and records every substitution it performs.
 *
 *  Each original variable is mapped to exactly one placeholder: the first
 *  occurrence creates it, later occurrences reuse the entry in replace_.
 */
class StageInputReplacer : public IRMutator {
 public:
  /*!
   * \param var_info Map from variable to the stage input describing it.
   *  It is held by reference, so it must outlive this object.
   */
  explicit StageInputReplacer(
      const std::unordered_map<const Variable*, StageInput>& var_info)
      : var_info_(var_info) {}

  /*!
   * \brief Replace a variable known to var_info_ with a ".sync" placeholder.
   *  Variables that are not listed in var_info_ are returned unchanged.
   */
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto cached = replace_.find(op);
    if (cached != replace_.end()) return cached->second;

    auto info = var_info_.find(op);
    if (info == var_info_.end()) return e;

    Var placeholder(info->second->var->name_hint + ".sync", op->type);
    inputs_.Set(placeholder, info->second);
    replace_[op] = placeholder;
    return placeholder;
  }

  /*!
   * \brief Replace a load at index zero with a ".load.sync" placeholder.
   *  The load must use a zero index and its buffer must be in var_info_;
   *  both conditions are enforced with CHECK.
   */
  Expr Mutate_(const Load* op, const Expr& e) final {
    CHECK(is_zero(op->index))
        << "Load should be in its own stage.";
    const Variable* buffer = op->buffer_var.get();

    auto cached = replace_.find(buffer);
    if (cached != replace_.end()) return cached->second;

    auto info = var_info_.find(buffer);
    CHECK(info != var_info_.end())
        << "Load from unknown channel";

    Var placeholder(info->second->var->name_hint + ".load.sync", op->type);
    inputs_.Set(placeholder, info->second);
    replace_[buffer] = placeholder;
    return placeholder;
  }

  // Placeholder variable -> stage input it stands for.
  Map<Var, StageInput> inputs_;
  // Original variable -> placeholder that replaced it.
  std::unordered_map<const Variable*, Var> replace_;
  // Which variables may be replaced, and by which stage input.
  const std::unordered_map<const Variable*, StageInput>& var_info_;
};

/*!
 * \brief Extract module block
 *
 *  Walks the body of a lowered function and assembles a Pipeline from it:
 *  every Store becomes one ComputeBlock stage, and every channel seen in a
 *  channel read/write scope or reached through a handle argument becomes
 *  a ChannelBlock.
 */
class PipelineExtractor: public IRVisitor {
 public:
  /*!
   * \brief Run the extraction on a lowered function.
   * \param f The function whose body is visited.
   * \return Pipeline holding the collected stages, the channels and f's args.
   */
  Pipeline Extract(LoweredFunc f) {
    // Initialize the memory map channels
    // TODO(tqchen) move the logic to explicit specification.
    for (const auto& arg : f->args) {
      if (arg.type().is_handle()) {
        arg_handle_[arg.get()] = arg;
      }
    }
    pipeline_ = std::make_shared<PipelineNode>();
    this->Visit(f->body);
    // setup channels
    for (const auto &kv : cmap_) {
      pipeline_->channels.Set(
          kv.second.node->channel->handle_var,
          ChannelBlock(kv.second.node));
    }
    pipeline_->args = f->args;
    return Pipeline(pipeline_);
  }

  /*!
   * \brief Track pipeline stage scopes, channel advances and channel scopes.
   *
   *  Stage scopes and advance attributes are pushed on trigger_ while their
   *  body is visited. Channel scopes create or validate the channel entry
   *  and expose the channel handle as a stage input while their body is
   *  visited. Any other attribute is visited without extra bookkeeping.
   */
  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == attr::pipeline_stage_scope) {
      // Pipeline stages cannot be nested.
      CHECK(!in_pipeline_stage_);
      in_pipeline_stage_ = true;
      trigger_.emplace_back(loop_.size(), op);
      IRVisitor::Visit_(op);
      trigger_.pop_back();
      in_pipeline_stage_ = false;
    } else if (op->attr_key == attr::channel_read_advance ||
               op->attr_key == attr::channel_write_advance) {
      // Remember the loop depth at which the advance was declared.
      trigger_.emplace_back(loop_.size(), op);
      IRVisitor::Visit_(op);
      trigger_.pop_back();
    } else if (op->attr_key == attr::channel_read_scope ||
               op->attr_key == attr::channel_write_scope) {
      Channel ch(op->node.node_);
      ChannelEntry& cb = cmap_[ch->handle_var.get()];
      if (cb.node != nullptr) {
        // The handle is already registered: it must be the same channel.
        CHECK(cb.node->channel.same_as(ch));
      } else {
        cb.node = std::make_shared<ChannelBlockNode>();
        cb.node->channel = ch;
      }
      if (op->attr_key == attr::channel_read_scope) {
        CHECK_EQ(cb.read_ref_count, 0)
            << "One channel can only be read from one consumer";
        ++cb.read_ref_count;
        CHECK(arith::GetConstInt(op->value, &(cb.node->read_window)))
              << "Only supprt constant read window";
      } else {
        CHECK_EQ(cb.write_ref_count, 0)
            << "One channel can only be write by one producer";
        ++cb.write_ref_count;
        CHECK(arith::GetConstInt(op->value, &(cb.node->write_window)))
              << "Only supprt constant write window";
      }
      // The channel handle is a valid stage input only inside this scope.
      var_info_[ch->handle_var.get()] =
          StageInputNode::make(ch->handle_var, kChannel);
      IRVisitor::Visit_(op);
      var_info_.erase(ch->handle_var.get());
    } else {
      IRVisitor::Visit_(op);
    }
  }
  /*! \brief Blocks are only accepted outside of a pipeline stage. */
  void Visit_(const Block* op) final {
    CHECK(!in_pipeline_stage_)
        << "Do not support serial execution inside pipeline";
    IRVisitor::Visit_(op);
  }
  /*! \brief Conditionals are rejected. */
  void Visit_(const IfThenElse* op) final {
    LOG(FATAL) << "Not implemeneted";
  }
  /*!
   * \brief Record loops that appear inside a pipeline stage.
   *
   *  Inside a stage, a copy of the loop with a placeholder body is pushed
   *  on loop_ and the loop variable is exposed as a stage input while the
   *  loop is visited. Outside a stage the loop is visited unchanged.
   */
  void Visit_(const For* op) final {
    if (in_pipeline_stage_) {
      loop_.push_back(
          For::make(op->loop_var, op->min, op->extent,
                    op->for_type, op->device_api, Evaluate::make(0)));
      var_info_[op->loop_var.get()] =
          StageInputNode::make(Var(op->loop_var.node_), kLoopVar);
      IRVisitor::Visit_(op);
      var_info_.erase(op->loop_var.get());
      loop_.pop_back();
    } else {
      IRVisitor::Visit_(op);
    }
  }
  /*!
   * \brief Turn a store into a compute block and append it to the pipeline.
   *
   *  The block captures the current loop nest, one signal trigger per
   *  active entry of trigger_, the store with its inputs replaced by
   *  placeholders, and the map of those placeholders.
   */
  void Visit_(const Store* op) final {
    // Check the access pattern
    Channel arg_write =
        CheckArgHandleAccess(op->buffer_var.get(), op->value.type(), false);
    this->Visit(op->value);
    // The replace logic
    StageInputReplacer repl(var_info_);
    // Setup the compute block.
    std::shared_ptr<ComputeBlockNode> compute =
        std::make_shared<ComputeBlockNode>();
    compute->loop = Array<Stmt>(loop_);
    // setup the advance triggers
    for (const auto& e : trigger_) {
      const AttrStmt* scope_attr = e.second;
      Channel ch;
      if (scope_attr->attr_key == attr::pipeline_stage_scope) {
        // A stage only signals when the store writes to a handle argument.
        ch = arg_write;
        if (!ch.defined()) continue;
      } else {
        ch = Channel(scope_attr->node.node_);
      }
      std::shared_ptr<SignalTriggerNode> trigger
          = std::make_shared<SignalTriggerNode>();
      trigger->channel_var = ch->handle_var;
      // predicate for the trigger: every loop entered after the attribute
      // was declared is at its last iteration.
      Expr predicate = const_true();
      for (size_t i = e.first; i < loop_.size(); ++i) {
        const For* loop = loop_[i].as<For>();
        predicate = predicate &&
            (loop->loop_var == (loop->extent - 1));
      }
      trigger->predicate = ir::Simplify(predicate);
      // Add the signal back to the channels.
      ChannelEntry& cb = cmap_.at(ch->handle_var.get());
      trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
      // Grab the advance constant size.
      int trigger_size = 0;
      if (scope_attr->attr_key == attr::pipeline_stage_scope) {
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kComputeFinish, 0));
      } else if (scope_attr->attr_key == attr::channel_read_advance) {
        CHECK(arith::GetConstInt(scope_attr->value, &trigger_size))
            << "Only support constant advance size";
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kReadAdvance, trigger_size));
      } else {
        CHECK(arith::GetConstInt(scope_attr->value, &trigger_size))
            << "Only support constant advance size";
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kWriteAdvance, trigger_size));
      }
      compute->triggers.push_back(SignalTrigger(trigger));
    }

    // Check if we are writing to FIFO.
    const Load* load = op->value.as<Load>();
    if (is_zero(op->index) && load) {
      // Only the index of the load is rewritten; the Load node is rebuilt
      // rather than replaced by a placeholder. The rebuilt load keeps its
      // own predicate instead of borrowing the one of the store.
      compute->body = Store::make(
          op->buffer_var,
          Load::make(load->type, load->buffer_var,
                     repl.Mutate(load->index), load->predicate),
          op->index, op->predicate);
    } else {
      compute->body = Store::make(
          op->buffer_var, repl.Mutate(op->value),
          repl.Mutate(op->index), op->predicate);
    }
    compute->inputs = repl.inputs_;
    pipeline_->stages.push_back(ComputeBlock(compute));
  }
  /*! \brief Let statements are rejected. */
  void Visit_(const LetStmt* op) final {
    LOG(FATAL) << "cannot pass through let";
  }
  /*! \brief Evaluate statements are rejected. */
  void Visit_(const Evaluate* op) final {
    LOG(FATAL) << "Not implemeneted";
  }
  /*!
   * \brief Allocations are only accepted outside of a pipeline stage.
   * \note The allocation is not visited any further by this method.
   */
  void Visit_(const Allocate* op) final {
    CHECK(!in_pipeline_stage_);
  }
  /*! \brief Assertions are rejected. */
  void Visit_(const AssertStmt* op) final {
    LOG(FATAL) << "Not implemeneted";
  }
  /*!
   * \brief Register the channel of a load from a handle argument.
   * \note The index of the load is not visited by this method.
   */
  void Visit_(const Load* op) final {
    CheckArgHandleAccess(op->buffer_var.get(), op->type, true);
  }
  /*!
   * \brief Create the channel entry for an access to a handle argument.
   * \param var The buffer variable being accessed.
   * \param dtype The data type given to the created channel.
   * \param read_access Whether the access is a read; currently unused.
   * \return The created channel, or an undefined Channel when var is not
   *  one of the handle arguments collected by Extract.
   */
  Channel CheckArgHandleAccess(const Variable* var, Type dtype, bool read_access) {
    if (!arg_handle_.count(var)) return Channel();
    // Each handle argument may be accessed at a single place only.
    CHECK(!cmap_.count(var))
        << "Multiple access to the same handle";
    ChannelEntry& cb = cmap_[var];
    cb.node = std::make_shared<ChannelBlockNode>();
    cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
    return cb.node->channel;
  }

 private:
  // The channel information.
  struct ChannelEntry {
    std::shared_ptr<ChannelBlockNode> node;
    // Number of read scopes seen for the channel (at most one is allowed).
    int read_ref_count{0};
    // Number of write scopes seen for the channel (at most one is allowed).
    int write_ref_count{0};
  };
  // Whether we are inside the pipeline stage.
  bool in_pipeline_stage_{false};
  // The current loop nest
  std::vector<Stmt> loop_;
  // Advance signal trigger: (loop depth at declaration, attribute).
  std::vector<std::pair<size_t, const AttrStmt*> > trigger_;
  // Read write scope (not referenced by any method of this class).
  std::vector<const AttrStmt*> channel_scope_;
  // The loop index.
  std::unordered_map<const Variable*, StageInput> var_info_;
  // The channel entry;
  std::unordered_map<const Variable*, ChannelEntry> cmap_;
  // The argument handle map
  std::unordered_map<const Variable*, Var> arg_handle_;
  // The result block.
  std::shared_ptr<PipelineNode> pipeline_;
};

/*!
 * \brief Build the pipeline of a lowered function.
 * \param f The function to convert.
 * \return The result of running a fresh PipelineExtractor on f.
 */
Pipeline MakePipeline(LoweredFunc f) {
  PipelineExtractor extractor;
  return extractor.Extract(f);
}
}  // namespace verilog
}  // namespace codegen
}  // namespace tvm