/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_virtual_thread.cc
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../../arith/compute_expr.h"

namespace tvm {
30
namespace tir {
31 32

// If expression is touched by var.
33
class ExprTouched final : public StmtExprVisitor {
34
 public:
35
  explicit ExprTouched(const std::unordered_set<const VarNode*> &touched,
36 37
                       bool check_write)
      : touched_var_(touched), check_write_(check_write) {}
38

39
  void VisitExpr(const PrimExpr& n) final {
40
    // early stopping
41
    if (expr_touched_ && !check_write_) return;
42
    StmtExprVisitor::VisitExpr(n);
43
  }
44 45 46 47 48
    void VisitStmt(const Stmt& n) final {
    // early stopping
    if (expr_touched_ && !check_write_) return;
    StmtExprVisitor::VisitStmt(n);
  }
49
  void VisitExpr_(const LoadNode *op) final {
50
    HandleUseVar(op->buffer_var.get());
51
    StmtExprVisitor::VisitExpr_(op);
52
  }
53
  void VisitExpr_(const VarNode *op) final {
54 55
    HandleUseVar(op);
  }
56
  void VisitExpr_(const CallNode *op) final {
57
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
58
      int rw_mask = 0;
59
      CHECK(arith::GetConstInt(op->args[4], &rw_mask));
60
      const VarNode* buffer_var = op->args[1].as<VarNode>();
61 62 63 64 65 66 67 68
      CHECK(buffer_var);
      // read
      if (rw_mask & 1) {
        HandleUseVar(buffer_var);
      }
      if (rw_mask & 2) {
        HandleWriteVar(buffer_var);
      }
69
      this->VisitExpr(op->args[2]);
70
    } else {
71
      StmtExprVisitor::VisitExpr_(op);
72 73
    }
  }
74
  void HandleUseVar(const VarNode* var) {
75 76 77 78 79 80 81 82 83 84
    auto it = touched_var_.find(var);
    if (it != touched_var_.end()) {
      expr_touched_ = true;
    }
    // rember the used vars
    // in case the var get touched later in a loop.
    if (!expr_touched_) {
      used_vars_.push_back(var);
    }
  }
85
  void HandleWriteVar(const VarNode* var) {
86 87
    write_vars_.push_back(var);
  }
88 89
  // the fields.
  bool expr_touched_{false};
90 91 92
  std::vector<const VarNode*> used_vars_;
  std::vector<const VarNode*> write_vars_;
  const std::unordered_set<const VarNode*>& touched_var_;
93
  bool check_write_;
94 95 96
};

// Analyze if the buffers are invariant to value of var
97
class VarTouchedAnalysis : public StmtVisitor {
98
 public:
99
  void VisitStmt_(const LetStmtNode* op) final {
100
    ExprTouched tc(touched_var_, false);
101
    tc(op->value);
102
    Record(op->var.get(), tc);
103
    this->VisitStmt(op->body);
104
  }
105
  void VisitStmt_(const StoreNode* op) final {
106
    ExprTouched tc(touched_var_, false);
107 108
    tc(op->value);
    tc(op->index);
109 110
    Record(op->buffer_var.get(), tc);
  }
111
  void VisitStmt_(const ForNode* op) final {
112
    ExprTouched tc(touched_var_, false);
113 114
    tc(op->min);
    tc(op->extent);
115
    Record(op->loop_var.get(), tc);
116
    this->VisitStmt(op->body);
117
  }
118
  // external function call
119
  void VisitStmt_(const EvaluateNode* op) final {
120
    ExprTouched tc(touched_var_, true);
121
    tc(op->value);
122
    for (const VarNode* var : tc.write_vars_) {
123 124 125
      Record(var, tc);
    }
  }
126
  void VisitStmt_(const AllocateNode* op) final {
127
    ExprTouched tc(touched_var_, false);
128
    for (size_t i = 0; i < op->extents.size(); ++i) {
129
      tc(op->extents[i]);
130
    }
131
    tc.VisitExpr(op->condition);
132
    if (op->new_expr.defined()) {
133
      tc(op->new_expr);
134 135
    }
    Record(op->buffer_var.get(), tc);
136
    this->VisitStmt(op->body);
137
  }
138
  void Record(const VarNode* var,
139 140 141 142 143
              const ExprTouched& tc) {
    if (touched_var_.count(var)) return;
    if (tc.expr_touched_) {
      touched_var_.insert(var);
    } else {
144
      for (const VarNode* r : tc.used_vars_) {
145 146 147
        if (r != var) {
          affect_[r].push_back(var);
        }
148 149 150 151
      }
    }
  }

152
  std::unordered_set<const VarNode*>
153
  TouchedVar(const Stmt& stmt,
154
             const VarNode* var) {
155
    touched_var_.insert(var);
156
    this->VisitStmt(stmt);
157
    // do a DFS to push affect around dependency.
158
    std::vector<const VarNode*> pending(
159 160
        touched_var_.begin(), touched_var_.end());
    while (!pending.empty()) {
161
      const VarNode* v = pending.back();
162
      pending.pop_back();
163
      for (const VarNode* r : affect_[v]) {
164 165 166 167 168 169 170 171 172 173 174
        if (!touched_var_.count(r)) {
          touched_var_.insert(r);
          pending.push_back(r);
        }
      }
    }
    return std::move(touched_var_);
  }

 private:
  // Whether variable is touched by the thread variable.
175
  std::unordered_set<const VarNode*> touched_var_;
176
  // x -> all the buffers x read from
177 178
  std::unordered_map<const VarNode*,
                     std::vector<const VarNode*> > affect_;
179 180 181 182 183
};


// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
184
class VTInjector : public StmtExprMutator {
185 186 187 188
 public:
  // constructor
  VTInjector(Var var,
             int num_threads,
189
             const std::unordered_set<const VarNode*>& touched_var,
190 191 192
             bool allow_share)
      : var_(var), num_threads_(num_threads),
        touched_var_(touched_var), allow_share_(allow_share) {
193 194
  }
  // Inject VTLoop when needed.
195
  Stmt VisitStmt(const Stmt& s) final {
196
    CHECK(!visit_touched_var_);
197
    auto stmt = StmtExprMutator::VisitStmt(s);
198 199 200 201
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_)  {
        return InjectVTLoop(stmt, false);
      }
202
      visit_touched_var_ = false;
203
      trigger_base_inject_ = false;
204 205 206 207
    }
    return stmt;
  }
  // Variable
208
  PrimExpr VisitExpr_(const VarNode* op) final {
209 210
    CHECK(!alloc_remap_.count(op))
        << "Buffer address may get rewritten in virtual thread";
211 212 213
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
214
    return GetRef<PrimExpr>(op);
215
  }
216
  PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
217
    return index + var_ * alloc_extent;
218 219
  }
  // Load
220 221
  PrimExpr VisitExpr_(const LoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
222
    op = expr.as<LoadNode>();
223 224 225
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
226 227
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
228
      return LoadNode::make(op->dtype, op->buffer_var,
229 230
                        RewriteIndex(op->index, it->second),
                        op->predicate);
231 232 233 234
    } else {
      return expr;
    }
  }
235
  // Expression.
236
  PrimExpr VisitExpr_(const CallNode* op) final {
237 238
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
239
      DataType dtype = op->args[0].dtype();
240
      const VarNode* buffer = op->args[1].as<VarNode>();
241
      auto it = alloc_remap_.find(buffer);
242
      if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
243
      visit_touched_var_ = true;
244 245 246
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      PrimExpr stride =
247
          it->second / make_const(offset.dtype(), dtype.lanes());
248
      offset = stride * var_ + offset;
249
      return CallNode::make(
250
          op->dtype, op->name,
251 252 253
          {op->args[0], op->args[1], offset, extent, op->args[4]},
          op->call_type);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
254
      return allow_share_ ? GetRef<PrimExpr>(op) : var_;
255
    } else {
256
      return StmtExprMutator::VisitExpr_(op);
257 258
    }
  }
259
  Stmt VisitStmt_(const EvaluateNode* op) final {
260
    trigger_base_inject_ = !allow_share_;
261
    return StmtExprMutator::VisitStmt_(op);
262
  }
263
  // Store
264
  Stmt VisitStmt_(const StoreNode* op) final {
265
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
266
    op = stmt.as<StoreNode>();
267 268 269
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
270 271 272
    trigger_base_inject_ = !allow_share_;
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
273
      return StoreNode::make(op->buffer_var,
274
                         op->value,
275 276
                         RewriteIndex(op->index, it->second),
                         op->predicate);
277 278 279 280 281
    } else {
      return stmt;
    }
  }
  // Attribute
282
  Stmt VisitStmt_(const AttrStmtNode* op) final {
283
    PrimExpr value = this->VisitExpr(op->value);
284
    if (visit_touched_var_ && !vt_loop_injected_) {
285
      return InjectVTLoop(GetRef<Stmt>(op), true);
286
    } else if (!allow_share_ && !vt_loop_injected_ &&
287 288
               (op->attr_key == attr::coproc_uop_scope ||
                op->attr_key == attr::coproc_scope)) {
289
      return InjectVTLoop(GetRef<Stmt>(op), true);
290
    } else {
291
      Stmt body = this->VisitStmt(op->body);
292 293
      if (value.same_as(op->value) &&
          body.same_as(op->body)) {
294
        return GetRef<Stmt>(op);
295
      } else {
296
        return AttrStmtNode::make(op->node, op->attr_key, value, body);
297 298 299 300
      }
    }
  }
  // LetStmt
301
  Stmt VisitStmt_(const LetStmtNode* op) final {
302
    PrimExpr value = this->VisitExpr(op->value);
303
    if (visit_touched_var_ && !vt_loop_injected_) {
304
      return InjectVTLoop(GetRef<Stmt>(op), true);
305 306
    }
    visit_touched_var_ = false;
307
    Stmt body = this->VisitStmt(op->body);
308 309
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
310
      return GetRef<Stmt>(op);
311
    } else {
312
      return LetStmtNode::make(op->var, value, body);
313 314 315
    }
  }
  // For
316
  Stmt VisitStmt_(const ForNode* op) final {
317
    CHECK(is_zero(op->min));
318
    PrimExpr extent = this->VisitExpr(op->extent);
319
    if (visit_touched_var_ && !vt_loop_injected_) {
320
      Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
321 322 323 324
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
325
    Stmt body = this->VisitStmt(op->body);
326 327 328
    ++max_loop_depth_;
    if (extent.same_as(op->extent) &&
        body.same_as(op->body)) {
329
      return GetRef<Stmt>(op);
330
    } else {
331
      return ForNode::make(
332 333 334 335
          op->loop_var, op->min, extent, op->for_type, op->device_api, body);
    }
  }
  // IfThenElse
336
  Stmt VisitStmt_(const IfThenElseNode* op) final {
337
    PrimExpr condition = this->VisitExpr(op->condition);
338
    if (visit_touched_var_ && !vt_loop_injected_) {
339
      return InjectVTLoop(GetRef<Stmt>(op), true);
340 341 342
    }
    visit_touched_var_ = false;
    CHECK_EQ(max_loop_depth_, 0);
343
    Stmt then_case = this->VisitStmt(op->then_case);
344
    Stmt else_case;
345
    if (op->else_case.defined()) {
346 347
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
348
      else_case = this->VisitStmt(op->else_case);
349 350 351 352 353
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
354
      return GetRef<Stmt>(op);
355
    } else {
356
      return IfThenElseNode::make(condition, then_case, else_case);
357 358
    }
  }
359 360 361

  // Seq
  Stmt VisitStmt_(const SeqStmtNode* op) final {
362
    CHECK_EQ(max_loop_depth_, 0);
363 364 365 366 367 368 369 370
    auto fmutate = [this](const Stmt& s) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      max_loop_depth_ = std::max(max_loop_depth_, temp);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
371 372
  }
  // Allocate
373
  Stmt VisitStmt_(const AllocateNode* op) final {
374
    if (op->new_expr.defined() && !vt_loop_injected_) {
375
      return InjectVTLoop(GetRef<Stmt>(op), true);
376
    }
377
    PrimExpr condition = this->VisitExpr(op->condition);
378
    if (visit_touched_var_ && !vt_loop_injected_) {
379
      return InjectVTLoop(GetRef<Stmt>(op), true);
380 381 382
    }

    bool changed = false;
383
    Array<PrimExpr> extents;
384
    for (size_t i = 0; i < op->extents.size(); i++) {
385
      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
386
      if (visit_touched_var_ && !vt_loop_injected_) {
387
        return InjectVTLoop(GetRef<Stmt>(op), true);
388 389 390 391 392 393 394
      }
      if (!new_ext.same_as(op->extents[i])) changed = true;
      extents.push_back(new_ext);
    }
    visit_touched_var_ = false;

    Stmt body;
395 396
    // always rewrite if not allow sharing.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
397
      // place v on highest dimension.
398 399 400
      PrimExpr stride = arith::ComputeReduce<MulNode>(
          op->extents, PrimExpr()) * op->dtype.lanes();
      Array<PrimExpr> other;
401
      other.push_back(make_const(op->extents[0].dtype(), num_threads_));
402
      for (PrimExpr e : extents) {
403 404 405 406 407
        other.push_back(e);
      }
      extents = other;
      changed = true;
      // mark this buffer get touched.
408
      alloc_remap_[op->buffer_var.get()] = stride;
409
      // Mutate the body.
410
      body = this->VisitStmt(op->body);
411 412
    } else {
      // Mutate the body.
413
      body = this->VisitStmt(op->body);
414 415 416 417
    }
    if (!changed &&
        body.same_as(op->body) &&
        condition.same_as(op->condition)) {
418
      return GetRef<Stmt>(op);
419
    } else {
420
      return AllocateNode::make(
421
          op->buffer_var, op->dtype,
422 423 424 425 426 427 428 429 430 431
          extents, condition, body,
          op->new_expr, op->free_function);
    }
  }

  // inject vthread loop
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    CHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
432
    trigger_base_inject_ = false;
433 434
    vt_loop_injected_ = true;
    if (before_mutation) {
435
      stmt = this->VisitStmt(stmt);
436 437 438 439
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
440 441
    // only unroll if number of vthreads are small
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
442
      // do unrolling if it is inside innermost content.
443 444 445
      Array<Stmt> seq;
      for (int i = 0; i < num_threads_; ++i) {
        seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
446
      }
447
      return SeqStmt::Flatten(seq);
448 449
    } else {
      // insert a for loop
450
      Var idx(var_->name_hint + ".s", var_->dtype);
451
      Map<Var, PrimExpr> values{{var_, idx}};
Changming Sun committed
452
      stmt = Substitute(stmt, values);
453
      return ForNode::make(idx, make_zero(idx.dtype()),
454
                       make_const(idx.dtype(), num_threads_),
455 456 457 458 459 460 461 462 463 464 465 466 467
                       ForType::Serial, DeviceAPI::None, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the threads/lanes
  int num_threads_;
  // whethe the loop is already injected.
  bool vt_loop_injected_{false};
  // whether current expression get touched.
  bool visit_touched_var_{false};
468 469
  // Trigger base stmt
  bool trigger_base_inject_{false};
470 471 472
  // the counter of loops in after mutation.
  int max_loop_depth_{0};
  // The variables that get touched.
473
  const std::unordered_set<const VarNode*>& touched_var_;
474 475
  // Whether allow shareding.
  bool allow_share_;
476
  // The allocations that get touched -> extent
477
  std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
478 479 480
};


481
class VirtualThreadInjector : public StmtMutator {
482
 public:
483
  Stmt VisitStmt_(const AttrStmtNode* op) final {
484
    Stmt stmt = StmtMutator::VisitStmt_(op);
485
    op = stmt.as<AttrStmtNode>();
486
    if (op->attr_key == attr::virtual_thread) {
487
      IterVar iv = Downcast<IterVar>(op->node);
488
      bool allow_share = iv->thread_tag == "vthread";
489
      int nthread = static_cast<int>(op->value.as<IntImmNode>()->value);
490 491
      VarTouchedAnalysis vs;
      auto touched = vs.TouchedVar(op->body, iv->var.get());
492
      VTInjector injecter(iv->var, nthread, touched, allow_share);
493
      return injecter(op->body);
494 495 496 497 498
    } else {
      return stmt;
    }
  }

499
  Stmt VisitStmt_(const ProvideNode* op) final {
500
    LOG(FATAL) << "Need to call StorageFlatten first";
501
    return GetRef<Stmt>(op);
502 503 504 505
  }
};

// Entry point of the pass: run VirtualThreadInjector over stmt and pass the
// result through ConvertSSA before returning it.
Stmt InjectVirtualThread(Stmt stmt) {
  stmt = VirtualThreadInjector()(std::move(stmt));
  return ConvertSSA(std::move(stmt));
}

510
}  // namespace tir
511
}  // namespace tvm