/*!
 *  Copyright (c) 2017 by Contributors
 * \file split_pipeline.cc
 * \brief Split a statement into pipeline stage modules connected by channels.
 */
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/channel.h>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "./ir_util.h"

namespace tvm {
namespace ir {

class MarkChannelAccess : public IRMutator {
 public:
  MarkChannelAccess(
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
      const std::unordered_map<const Variable*, Channel>& cmap,
      const std::unordered_map<const Variable*, Channel>& fifo_map)
      : cmap_(cmap), fifo_map_(fifo_map) {}
  using IRMutator::Mutate;
  Stmt Mutate(Stmt stmt) final {
    Stmt ret = IRMutator::Mutate(stmt);
    if (read_fifos_.size() != 0) {
      for (const Variable* v : read_fifos_) {
        Channel ch = fifo_map_.at(v);
        ret = ReadChannel(ch, 1, ret);
      }
      read_fifos_.clear();
    }
    if (write_fifos_.size() != 0) {
      for (const Variable* v : write_fifos_) {
        Channel ch = fifo_map_.at(v);
        ret = WriteChannel(ch, 1, ret);
      }
      write_fifos_.clear();
    }
    return ret;
  }
Tianqi Chen committed
44 45 46 47 48 49

  Expr Mutate_(const Load *op, const Expr& e) final {
    auto it = rmap_.find(op->buffer_var.get());
    if (it != rmap_.end()) {
      ++it->second.read_count;
    }
50 51 52 53
    if (fifo_map_.count(op->buffer_var.get())) {
      read_fifos_.insert(op->buffer_var.get());
      CHECK(!write_fifos_.count(op->buffer_var.get()));
    }
Tianqi Chen committed
54 55 56 57 58 59 60
    return IRMutator::Mutate_(op, e);
  }
  Stmt Mutate_(const Store *op, const Stmt& s) final {
    auto it = rmap_.find(op->buffer_var.get());
    if (it != rmap_.end()) {
      ++it->second.write_count;
    }
61 62 63 64
    if (fifo_map_.count(op->buffer_var.get())) {
      write_fifos_.insert(op->buffer_var.get());
      CHECK(!read_fifos_.count(op->buffer_var.get()));
    }
Tianqi Chen committed
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    if (cmap_.count(op->buffer_var.get())) {
      CHECK(!rmap_.count(op->buffer_var.get()));
      rmap_[op->buffer_var.get()] = Entry();
      Stmt body = Mutate(op->body);
      body = CreateChannelAccess(op, body);
      rmap_.erase(op->buffer_var.get());
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
80
    if (op->attr_key == ir::attr::storage_scope) {
Tianqi Chen committed
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
      Var buf_var(op->node.node_);
      if (cmap_.count(buf_var.get())) return Mutate(op->body);
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  // Create channel access wrap
  Stmt CreateChannelAccess(const Allocate* op, Stmt body) {
    const Entry& rw = rmap_.at(op->buffer_var.get());
    CHECK(rw.write_count == 0 || rw.read_count == 0)
        << "Cannot read/write to the same channel " << op->buffer_var
        <<  " body:" << body;
    if (rw.write_count == 0 && rw.read_count == 0) {
      return body;
    }
    const Channel& ch = cmap_.at(op->buffer_var.get());
    int32_t csize = op->constant_allocation_size();
    Expr alloc_size;
    if (csize > 0) {
      alloc_size = IntImm::make(Int(32), csize);
    } else {
      alloc_size = op->extents[0];
      for (size_t i = 1; i < op->extents.size(); ++i) {
        alloc_size *= op->extents[i];
      }
      alloc_size = ir::Simplify(alloc_size);
    }

    if (rw.write_count) {
111
      return WriteChannel(ch, alloc_size, body);
Tianqi Chen committed
112 113
    } else {
      CHECK(rw.read_count);
114
      return ReadChannel(ch, alloc_size, body);
Tianqi Chen committed
115 116
    }
  }
117 118 119 120 121 122 123 124 125 126 127 128
  Stmt ReadChannel(Channel ch, Expr size, Stmt body) {
    return AttrStmt::make(
        ch, ir::attr::channel_read_scope, size,
        AttrStmt::make(ch, ir::attr::channel_read_advance, size,
                       body));
  }
  Stmt WriteChannel(Channel ch, Expr size, Stmt body) {
    return AttrStmt::make(
        ch, ir::attr::channel_write_scope, size,
        AttrStmt::make(ch, ir::attr::channel_write_advance, size,
                       body));
  }
Tianqi Chen committed
129 130 131 132 133 134
  struct Entry {
    int read_count{0};
    int write_count{0};
  };
  // The channels of each allocation.
  const std::unordered_map<const Variable*, Channel>& cmap_;
135 136
  // FIFO map.
  const std::unordered_map<const Variable*, Channel>& fifo_map_;
Tianqi Chen committed
137 138
  // the result.
  std::unordered_map<const Variable*, Entry> rmap_;
139 140
  // Accessed FIFOs
  std::unordered_set<const Variable*> read_fifos_, write_fifos_;
Tianqi Chen committed
141 142 143 144 145
};

// Mark the statment of each stage.
class StageSplitter : public IRMutator {
 public:
146 147 148 149
  using IRMutator::Mutate;
  explicit StageSplitter(bool split_load)
      : split_load_(split_load) {}

Tianqi Chen committed
150 151 152 153 154 155
  Stmt Mutate(Stmt stmt) final {
    nest_.push_back(stmt);
    Stmt ret = IRMutator::Mutate(stmt);
    nest_.pop_back();
    return ret;
  }
156 157 158 159
  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
    if (!op->is_producer) {
      return Mutate(op->body);
    }
Tianqi Chen committed
160 161 162 163
    Stmt body = Mutate(op->body);
    stages_.emplace_back(BuildStage(body, op->func));
    return Evaluate::make(0);
  }
164 165 166 167 168 169 170 171 172
  Expr Mutate_(const Load* op, const Expr& e) final {
    if (!split_load_) return IRMutator::Mutate_(op, e);
    std::ostringstream cname;
    cname << "fifo." << temp_fifo_count_++;
    // Create FIFO channel for load.
    Channel ch = ChannelNode::make(Var(cname.str(), Handle()), op->type);
    Expr index = Mutate(op->index);
    Stmt provide = Store::make(
        ch->handle_var,
173 174
        Load::make(op->type, op->buffer_var, index, op->predicate),
        0, op->predicate);
175 176 177 178
    Stmt temp = nest_.back(); nest_.pop_back();
    stages_.emplace_back(BuildStage(provide, ch));
    nest_.push_back(temp);
    fifo_map_[ch->handle_var.get()] = ch;
179
    return Load::make(op->type, ch->handle_var, 0, op->predicate);
180
  }
Tianqi Chen committed
181

182
  Stmt Split(Stmt stmt, const ProducerConsumer* env) {
Tianqi Chen committed
183
    stmt = Mutate(stmt);
184 185 186 187 188 189
    if (env) {
      stages_.emplace_back(BuildStage(stmt, env->func));
    } else {
      stmt = RemoveNoOp(stmt);
      CHECK(is_no_op(stmt));
    }
Tianqi Chen committed
190 191 192 193 194
    CHECK_NE(stages_.size(), 0);
    stmt = stages_.back();
    for (size_t i = stages_.size() - 1; i != 0; --i) {
      stmt = Block::make(stages_[i - 1], stmt);
    }
195
    stmt = MarkChannelAccess(cmap_, fifo_map_).Mutate(stmt);
Tianqi Chen committed
196 197 198 199 200 201
    return RemoveNoOp(stmt);
  }

 private:
  // Build the stage.
  Stmt BuildStage(Stmt body, NodeRef target) {
202
    int stage_index = static_cast<int>(stages_.size());
Tianqi Chen committed
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
    std::string stage_suffix = "." + std::to_string(stage_index);
    // The Substitute
    Map<Var, Expr> subst;
    std::vector<Stmt> nest;
    Stmt no_op = Evaluate::make(0);

    for (const Stmt& s : nest_) {
      if (const For* op = s.as<For>()) {
        Var loop_var(op->loop_var);
        Var new_var = loop_var.copy_with_suffix(stage_suffix);
        subst.Set(loop_var, new_var);
        nest.emplace_back(For::make(
            new_var, op->min, op->extent,
            op->for_type, op->device_api, no_op));
      } else if (const LetStmt* op = s.as<LetStmt>()) {
        Var var(op->var);
        Var new_var = var.copy_with_suffix(stage_suffix);
        subst.Set(var, new_var);
        nest.emplace_back(LetStmt::make(new_var, op->value, no_op));
      } else if (const IfThenElse* op = s.as<IfThenElse>()) {
        CHECK(!op->else_case.defined());
        nest.emplace_back(IfThenElse::make(op->condition, no_op));
      } else if (const AttrStmt* op = s.as<AttrStmt>()) {
        nest.emplace_back(AttrStmt::make(
227
            op->node, op->attr_key, op->value, no_op));
Tianqi Chen committed
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
      } else if (s.as<ProducerConsumer>()) {
      } else if (s.as<Block>()) {
      } else if (const Allocate* op = s.as<Allocate>()) {
        nest.emplace_back(Allocate::make(
            op->buffer_var, op->type, op->extents,
            op->condition, no_op, op->new_expr, op->free_function));
        MarkChannel(op);
      } else {
        LOG(FATAL) << "not supported nest type " << s->type_key();
      }
    }
    body = Substitute(MergeNest(nest, body), subst);
    return AttrStmt::make(
        target, ir::attr::pipeline_stage_scope,
        make_const(Int(32), stage_index), body);
  }
  void MarkChannel(const Allocate* op) {
    if (!cmap_.count(op->buffer_var.get())) {
      Channel ch = ChannelNode::make(Var(op->buffer_var), op->type);
      cmap_[op->buffer_var.get()] = ch;
    }
  }
  // The stack
  std::vector<Stmt> nest_;
  // The stages
  std::vector<Stmt> stages_;
  // channel map
  std::unordered_map<const Variable*, Channel> cmap_;
256 257 258 259 260 261 262 263 264 265 266 267 268 269
  // Whether split load into a temp fifo.
  bool split_load_{true};
  // Counter for temp FIFOs.
  size_t temp_fifo_count_{0};
  // fifo map
  std::unordered_map<const Variable*, Channel> fifo_map_;
};

// Find pipeline_exec_scope regions and split each body into stages with
// StageSplitter. The directly enclosing ProducerConsumer (at most one) is
// handed to the splitter and removed from the output.
class PipelineSplitter : public IRMutator {
 public:
  explicit PipelineSplitter(bool split_load)
      : split_load_(split_load) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == ir::attr::pipeline_exec_scope) {
      CHECK_LE(env_.size(), 1U);
      const ProducerConsumer* env = nullptr;
      if (env_.size() == 1) {
        // Take ownership of the enclosing producer. Its slot is left as
        // nullptr so Mutate_(ProducerConsumer) knows to unwrap it.
        std::swap(env_[0], env);
      }
      Stmt body = StageSplitter(split_load_).Split(
          op->body, env);
      if (body.same_as(op->body)) return s;
      return AttrStmt::make(
          op->node, op->attr_key, op->value, body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
    env_.push_back(op);
    Stmt ret = IRMutator::Mutate_(op, s);
    // The slot was consumed by a pipeline scope: its producer is already
    // represented by the last stage, so drop the wrapper.
    if (env_.back() == nullptr) {
      ret = ret.as<ProducerConsumer>()->body;
    }
    env_.pop_back();
    return ret;
  }

 private:
  bool split_load_;
  // Stack of enclosing ProducerConsumer nodes; nullptr once consumed.
  std::vector<const ProducerConsumer *> env_;
};

300 301
Stmt SplitPipeline(Stmt stmt, bool split_load) {
  return PipelineSplitter(split_load).Mutate(stmt);
Tianqi Chen committed
302 303 304 305
}

}  // namespace ir
}  // namespace tvm