/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file split_pipeline.cc
 * \brief Split a pipelined statement into stage modules connected by channels.
 */
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/channel.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir_util.h"

namespace tvm {
namespace ir {

class MarkChannelAccess : public IRMutator {
 public:
  MarkChannelAccess(
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
      const std::unordered_map<const Variable*, Channel>& cmap,
      const std::unordered_map<const Variable*, Channel>& fifo_map)
      : cmap_(cmap), fifo_map_(fifo_map) {}
  using IRMutator::Mutate;
  Stmt Mutate(Stmt stmt) final {
    Stmt ret = IRMutator::Mutate(stmt);
    if (read_fifos_.size() != 0) {
      for (const Variable* v : read_fifos_) {
        Channel ch = fifo_map_.at(v);
        ret = ReadChannel(ch, 1, ret);
      }
      read_fifos_.clear();
    }
    if (write_fifos_.size() != 0) {
      for (const Variable* v : write_fifos_) {
        Channel ch = fifo_map_.at(v);
        ret = WriteChannel(ch, 1, ret);
      }
      write_fifos_.clear();
    }
    return ret;
  }
Tianqi Chen committed
63 64 65 66 67 68

  Expr Mutate_(const Load *op, const Expr& e) final {
    auto it = rmap_.find(op->buffer_var.get());
    if (it != rmap_.end()) {
      ++it->second.read_count;
    }
69 70 71 72
    if (fifo_map_.count(op->buffer_var.get())) {
      read_fifos_.insert(op->buffer_var.get());
      CHECK(!write_fifos_.count(op->buffer_var.get()));
    }
Tianqi Chen committed
73 74 75 76 77 78 79
    return IRMutator::Mutate_(op, e);
  }
  Stmt Mutate_(const Store *op, const Stmt& s) final {
    auto it = rmap_.find(op->buffer_var.get());
    if (it != rmap_.end()) {
      ++it->second.write_count;
    }
80 81 82 83
    if (fifo_map_.count(op->buffer_var.get())) {
      write_fifos_.insert(op->buffer_var.get());
      CHECK(!read_fifos_.count(op->buffer_var.get()));
    }
Tianqi Chen committed
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    if (cmap_.count(op->buffer_var.get())) {
      CHECK(!rmap_.count(op->buffer_var.get()));
      rmap_[op->buffer_var.get()] = Entry();
      Stmt body = Mutate(op->body);
      body = CreateChannelAccess(op, body);
      rmap_.erase(op->buffer_var.get());
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
99
    if (op->attr_key == ir::attr::storage_scope) {
100
      Var buf_var = Downcast<Var>(op->node);
Tianqi Chen committed
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
      if (cmap_.count(buf_var.get())) return Mutate(op->body);
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  // Create channel access wrap
  Stmt CreateChannelAccess(const Allocate* op, Stmt body) {
    const Entry& rw = rmap_.at(op->buffer_var.get());
    CHECK(rw.write_count == 0 || rw.read_count == 0)
        << "Cannot read/write to the same channel " << op->buffer_var
        <<  " body:" << body;
    if (rw.write_count == 0 && rw.read_count == 0) {
      return body;
    }
    const Channel& ch = cmap_.at(op->buffer_var.get());
    int32_t csize = op->constant_allocation_size();
    Expr alloc_size;
    if (csize > 0) {
      alloc_size = IntImm::make(Int(32), csize);
    } else {
      alloc_size = op->extents[0];
      for (size_t i = 1; i < op->extents.size(); ++i) {
124
        alloc_size = alloc_size * op->extents[i];
Tianqi Chen committed
125 126 127 128
      }
    }

    if (rw.write_count) {
129
      return WriteChannel(ch, alloc_size, body);
Tianqi Chen committed
130 131
    } else {
      CHECK(rw.read_count);
132
      return ReadChannel(ch, alloc_size, body);
Tianqi Chen committed
133 134
    }
  }
135 136 137 138 139 140 141 142 143 144 145 146
  Stmt ReadChannel(Channel ch, Expr size, Stmt body) {
    return AttrStmt::make(
        ch, ir::attr::channel_read_scope, size,
        AttrStmt::make(ch, ir::attr::channel_read_advance, size,
                       body));
  }
  Stmt WriteChannel(Channel ch, Expr size, Stmt body) {
    return AttrStmt::make(
        ch, ir::attr::channel_write_scope, size,
        AttrStmt::make(ch, ir::attr::channel_write_advance, size,
                       body));
  }
Tianqi Chen committed
147 148 149 150 151 152
  struct Entry {
    int read_count{0};
    int write_count{0};
  };
  // The channels of each allocation.
  const std::unordered_map<const Variable*, Channel>& cmap_;
153 154
  // FIFO map.
  const std::unordered_map<const Variable*, Channel>& fifo_map_;
Tianqi Chen committed
155 156
  // the result.
  std::unordered_map<const Variable*, Entry> rmap_;
157 158
  // Accessed FIFOs
  std::unordered_set<const Variable*> read_fifos_, write_fifos_;
Tianqi Chen committed
159 160 161 162 163
};

// Mark the statment of each stage.
class StageSplitter : public IRMutator {
 public:
164 165 166 167
  using IRMutator::Mutate;
  explicit StageSplitter(bool split_load)
      : split_load_(split_load) {}

Tianqi Chen committed
168 169 170 171 172 173
  Stmt Mutate(Stmt stmt) final {
    nest_.push_back(stmt);
    Stmt ret = IRMutator::Mutate(stmt);
    nest_.pop_back();
    return ret;
  }
174 175 176 177
  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
    if (!op->is_producer) {
      return Mutate(op->body);
    }
Tianqi Chen committed
178 179 180 181
    Stmt body = Mutate(op->body);
    stages_.emplace_back(BuildStage(body, op->func));
    return Evaluate::make(0);
  }
182 183 184 185 186 187 188 189 190
  Expr Mutate_(const Load* op, const Expr& e) final {
    if (!split_load_) return IRMutator::Mutate_(op, e);
    std::ostringstream cname;
    cname << "fifo." << temp_fifo_count_++;
    // Create FIFO channel for load.
    Channel ch = ChannelNode::make(Var(cname.str(), Handle()), op->type);
    Expr index = Mutate(op->index);
    Stmt provide = Store::make(
        ch->handle_var,
191 192
        Load::make(op->type, op->buffer_var, index, op->predicate),
        0, op->predicate);
193 194 195 196
    Stmt temp = nest_.back(); nest_.pop_back();
    stages_.emplace_back(BuildStage(provide, ch));
    nest_.push_back(temp);
    fifo_map_[ch->handle_var.get()] = ch;
197
    return Load::make(op->type, ch->handle_var, 0, op->predicate);
198
  }
Tianqi Chen committed
199

200
  Stmt Split(Stmt stmt, const ProducerConsumer* env) {
Tianqi Chen committed
201
    stmt = Mutate(stmt);
202 203 204 205 206 207
    if (env) {
      stages_.emplace_back(BuildStage(stmt, env->func));
    } else {
      stmt = RemoveNoOp(stmt);
      CHECK(is_no_op(stmt));
    }
Tianqi Chen committed
208 209 210 211 212
    CHECK_NE(stages_.size(), 0);
    stmt = stages_.back();
    for (size_t i = stages_.size() - 1; i != 0; --i) {
      stmt = Block::make(stages_[i - 1], stmt);
    }
213
    stmt = MarkChannelAccess(cmap_, fifo_map_).Mutate(stmt);
Tianqi Chen committed
214 215 216 217 218 219
    return RemoveNoOp(stmt);
  }

 private:
  // Build the stage.
  Stmt BuildStage(Stmt body, NodeRef target) {
220
    int stage_index = static_cast<int>(stages_.size());
Tianqi Chen committed
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
    std::string stage_suffix = "." + std::to_string(stage_index);
    // The Substitute
    Map<Var, Expr> subst;
    std::vector<Stmt> nest;
    Stmt no_op = Evaluate::make(0);

    for (const Stmt& s : nest_) {
      if (const For* op = s.as<For>()) {
        Var loop_var(op->loop_var);
        Var new_var = loop_var.copy_with_suffix(stage_suffix);
        subst.Set(loop_var, new_var);
        nest.emplace_back(For::make(
            new_var, op->min, op->extent,
            op->for_type, op->device_api, no_op));
      } else if (const LetStmt* op = s.as<LetStmt>()) {
        Var var(op->var);
        Var new_var = var.copy_with_suffix(stage_suffix);
        subst.Set(var, new_var);
        nest.emplace_back(LetStmt::make(new_var, op->value, no_op));
      } else if (const IfThenElse* op = s.as<IfThenElse>()) {
        CHECK(!op->else_case.defined());
        nest.emplace_back(IfThenElse::make(op->condition, no_op));
      } else if (const AttrStmt* op = s.as<AttrStmt>()) {
        nest.emplace_back(AttrStmt::make(
245
            op->node, op->attr_key, op->value, no_op));
Tianqi Chen committed
246 247 248 249 250 251 252 253
      } else if (s.as<ProducerConsumer>()) {
      } else if (s.as<Block>()) {
      } else if (const Allocate* op = s.as<Allocate>()) {
        nest.emplace_back(Allocate::make(
            op->buffer_var, op->type, op->extents,
            op->condition, no_op, op->new_expr, op->free_function));
        MarkChannel(op);
      } else {
254
        LOG(FATAL) << "not supported nest type " << s->GetTypeKey();
Tianqi Chen committed
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
      }
    }
    body = Substitute(MergeNest(nest, body), subst);
    return AttrStmt::make(
        target, ir::attr::pipeline_stage_scope,
        make_const(Int(32), stage_index), body);
  }
  void MarkChannel(const Allocate* op) {
    if (!cmap_.count(op->buffer_var.get())) {
      Channel ch = ChannelNode::make(Var(op->buffer_var), op->type);
      cmap_[op->buffer_var.get()] = ch;
    }
  }
  // The stack
  std::vector<Stmt> nest_;
  // The stages
  std::vector<Stmt> stages_;
  // channel map
  std::unordered_map<const Variable*, Channel> cmap_;
274 275 276 277 278 279 280 281 282 283 284 285 286 287
  // Whether split load into a temp fifo.
  bool split_load_{true};
  // Counter for temp FIFOs.
  size_t temp_fifo_count_{0};
  // fifo map
  std::unordered_map<const Variable*, Channel> fifo_map_;
};

// Split the body of every pipeline_exec_scope attribute into stages.
// If the attribute sits inside a ProducerConsumer, that producer is passed to
// StageSplitter::Split, and the ProducerConsumer wrapper is removed from the
// result.
class PipelineSplitter : public IRMutator {
 public:
  explicit PipelineSplitter(bool split_load)
      : split_load_(split_load) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key != ir::attr::pipeline_exec_scope) {
      return IRMutator::Mutate_(op, s);
    }
    CHECK_LE(env_.size(), 1U);
    const ProducerConsumer* env = nullptr;
    if (env_.size() == 1) {
      // Take the enclosing producer and leave nullptr in its slot. That tells
      // Mutate_(ProducerConsumer) to strip its wrapper.
      std::swap(env_[0], env);
    }
    Stmt body = StageSplitter(split_load_).Split(op->body, env);
    if (body.same_as(op->body)) return s;
    return AttrStmt::make(op->node, op->attr_key, op->value, body);
  }
  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
    env_.push_back(op);
    Stmt ret = IRMutator::Mutate_(op, s);
    if (env_.back() == nullptr) {
      const ProducerConsumer* pc = ret.as<ProducerConsumer>();
      CHECK(pc != nullptr)
          << "Expected a ProducerConsumer after mutating a ProducerConsumer";
      ret = pc->body;
    }
    env_.pop_back();
    return ret;
  }

 private:
  bool split_load_;
  // Stack of enclosing ProducerConsumer nodes. A slot is reset to nullptr
  // once a pipeline_exec_scope has claimed that producer.
  std::vector<const ProducerConsumer *> env_;
};

318 319
Stmt SplitPipeline(Stmt stmt, bool split_load) {
  return PipelineSplitter(split_load).Mutate(stmt);
Tianqi Chen committed
320 321 322 323
}

}  // namespace ir
}  // namespace tvm