/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_sync.cc
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <unordered_set>

#include "ir_util.h"
#include "storage_access.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

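// Plans where thread barriers are needed: it scans each statement sequence,
// tracks unsynchronized reads/writes to buffers in sync_scope_, and records
// the statements that must be preceded by a sync in syncs_inserted_.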
class ThreadSyncPlanner : public StorageAccessVisitor {
 public:
  explicit ThreadSyncPlanner(StorageScope sync_scope)
      : sync_scope_(sync_scope) {}

  // The syncs inserted before each statement
  std::unordered_set<const Object*> syncs_inserted_;

 protected:
  bool Enabled(const VarNode* buf,
               const StorageScope& scope) const final {
    return in_device_env() && scope == sync_scope_;
  }
  // Plan the sync
  std::vector<AccessEntry> Summarize(
      std::vector<StmtEntry> seq, const ForNode* loop) final {
    // Unsynced reads and writes
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    // if it is a loop, rotate two times to consider the effect of the loop.
    // simulation-based approach to find dependencies
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      // check if sync before statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc, false)) {
            sync_before_stmt = true; break;
          }
        } else if (acc.type == kWrite) {
          if (FindConflict(reads, acc, false)) {
            sync_before_stmt = true; break;
          }
        } else if (acc.type == kSync) {
          reads.clear(); writes.clear();
        }
      }
      // If a sync is inserted, remove the now irrelevant entries.
      if (sync_before_stmt) {
        reads.clear(); writes.clear();
      }
      // Add the read/write of current statement
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear(); writes.clear();
        }
      }
      if (sync_before_stmt) {
        CHECK_EQ(condition_counter(), 0)
            << "Cannot insert syncs inside condition";
        syncs_inserted_.insert(s.stmt);
      }
    }
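    // For a loop body, do a second pass to catch loop-carried conflicts:
    // accesses left exposed at the end of one iteration may conflict with
    // accesses at the start of the next, so a sync may be needed up front.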
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0) break;
        if (reads.empty() && writes.empty()) break;
        bool sync_before_stmt = false;
        for (const AccessEntry& acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true; break;
            }
          } else if (acc.type == kWrite) {
            if (FindConflict(reads, acc, true)) {
              sync_before_stmt = true; break;
            }
          } else if (acc.type == kSync) {
            reads.clear(); writes.clear();
          }
        }
        if (sync_before_stmt) {
          CHECK_EQ(condition_counter(), 0)
              << "Cannot insert syncs inside condition";
          syncs_inserted_.insert(s.stmt);
          break;
        }
      }
    }
    // return the exposed entries, removing unnecessary ones.
    int sync_count = 0;
    // head holds accesses before the first sync, tail those after the last sync
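    // Only these remain exposed to the enclosing scope; accesses that fall
    // between two syncs are already ordered and can be dropped here.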
    std::vector<AccessEntry> head, tail;
    AccessEntry esync;
    esync.threads = this->env_threads();
    esync.type = kSync;
    esync.scope = sync_scope_;

    for (const StmtEntry& s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
          head.push_back(esync);
        }
        ++sync_count;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
            head.push_back(esync);
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
    if (loop != nullptr) {
      // clear double buffer flag after a loop is finished.
      for (AccessEntry& e : head) {
        e.double_buffer_write = false;
      }
    }
    return head;
  }

 private:
  // find conflicting entry in vec.
  bool FindConflict(const std::vector<AccessEntry>& vec,
                    const AccessEntry& e,
                    bool loop_carry) {
    for (const AccessEntry& x : vec) {
      if (x.buffer.same_as(e.buffer)) {
        // Assumes no race between threads
        // Same index value means no conflicts
        // TODO(tqchen) more standard set based testing.
        if (e.touched.is_single_point() &&
            x.touched.is_single_point()) {
          if (Equal(e.touched.point_value(),
                    x.touched.point_value())) continue;
        }
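        // A double-buffered write goes to the half not being read in this
        // iteration, so it does not conflict with a read in the same
        // iteration; it only matters for loop-carried checks.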
        if (x.double_buffer_write &&
            e.type == kRead &&
            !loop_carry) continue;
        return true;
      }
    }
    return false;
  }

 private:
  // synchronization scope
  StorageScope sync_scope_;
};

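// Inserts the planned barriers into the statement tree.  For the "global"
// sync scope it also gathers per-buffer read/write statistics and sets up
// the global barrier.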
class ThreadSyncInserter : public StmtExprMutator {
 public:
  ThreadSyncInserter(StorageScope sync_scope,
                     const std::unordered_set<const Object*>& syncs)
      : sync_scope_(sync_scope), syncs_(syncs) {}

  Stmt VisitStmt(const Stmt& stmt) final {
    if (syncs_.size() == 0) return stmt;
    if (syncs_.count(stmt.get())) {
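      // Scopes other than "global" lower directly to a tvm_storage_sync
      // intrinsic; the global case needs the special global barrier.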
      Stmt barrier;
      if (sync_scope_.rank == StorageRank::kGlobal) {
        barrier = MakeGlobalBarrier();
      } else {
        barrier = EvaluateNode::make(
                CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
                           {StringImmNode::make(sync_scope_.to_string())},
                           CallNode::Intrinsic));
      }
      // Mutate after the query, since mutation would change the stmt.
      auto ret = StmtExprMutator::VisitStmt(stmt);
      ret = SeqStmt({barrier, ret});
      return ret;
    } else {
      return StmtExprMutator::VisitStmt(stmt);
    }
  }
  PrimExpr VisitExpr_(const LoadNode* op) final {
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
      ++rw_stats_[op->buffer_var].read_count;
    }
    return StmtExprMutator::VisitExpr_(op);
  }
  Stmt VisitStmt_(const StoreNode* op) final {
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
      ++rw_stats_[op->buffer_var].write_count;
    }
    return StmtExprMutator::VisitStmt_(op);
  }
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      bool temp = true;
      std::swap(temp, in_thread_env_);
      thread_extents_.push_back(op);
      Stmt ret = StmtExprMutator::VisitStmt_(op);
      thread_extents_.pop_back();
      std::swap(temp, in_thread_env_);
      // Back at the outermost (first) thread scope.
      if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
        ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
        num_blocks_ = PrimExpr();
        is_lead_ = PrimExpr();
      }
      return ret;
    } else if (op->attr_key == attr::storage_scope) {
      const VarNode* buf = op->node.as<VarNode>();
      storage_scope_[buf] =
          StorageScope::make(op->value.as<StringImmNode>()->value);
      return StmtExprMutator::VisitStmt_(op);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      PrimExpr expr = StmtExprMutator::VisitExpr_(op);
      op = expr.as<CallNode>();
      CHECK_EQ(op->args.size(), 5U);
      const VarNode* buffer_var = op->args[1].as<VarNode>();
      Var var(GetRef<Var>(buffer_var));
      const IntImmNode* flag = op->args[4].as<IntImmNode>();
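      // args[4] is the access mask: bit 0 marks a read, bit 1 marks a write.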
      if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
          GetScope(buffer_var).rank == StorageRank::kGlobal) {
        ++rw_stats_[var].read_count;
      }
      if ((flag->value & 2) && sync_scope_.rank == StorageRank::kGlobal &&
          GetScope(buffer_var).rank == StorageRank::kGlobal) {
        ++rw_stats_[var].write_count;
      }
      return expr;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

 private:
  // RW statistics about data
  struct Entry {
    int read_count{0};
    int write_count{0};
  };
  // Get current storage scope.
  StorageScope GetScope(const VarNode* buf) const {
    auto it = storage_scope_.find(buf);
    StorageScope s;
    s.rank = StorageRank::kGlobal;
    if (it == storage_scope_.end()) return s;
    return it->second;
  }
  // private functions.
  Stmt InitGlobalBarrier(const AttrStmtNode* op) {
    CHECK(op != nullptr);
    Array<PrimExpr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
    Stmt prep = EvaluateNode::make(
        CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
    Stmt body = op->body;
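    // Buffers that are both read and written across the barrier are marked
    // volatile so their values are re-read rather than cached across it.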
    for (const auto& kv : rw_stats_) {
      const auto& e = kv.second;
      if (e.read_count != 0 && e.write_count != 0) {
        body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body);
      }
    }
    rw_stats_.clear();
    Stmt kinit = EvaluateNode::make(
        CallNode::make(
            DataType::Int(32),
            intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
    body = SeqStmt({kinit, body});
    body = AttrStmtNode::make(
        op->node, op->attr_key, op->value, body);
    return SeqStmt({prep, body});
  }
  Stmt MakeGlobalBarrier() {
    CHECK(sync_scope_.rank == StorageRank::kGlobal);
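    // On first use, derive num_blocks_ as the product of all rank-0
    // (block index) extents and is_lead_ as the predicate that every
    // rank-1 (thread index) variable equals zero.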
    if (!num_blocks_.defined()) {
      CHECK(!is_lead_.defined());
      num_work_dim_ = thread_extents_.size();
      for (const AttrStmtNode* attr : thread_extents_) {
        IterVar iv = Downcast<IterVar>(attr->node);
        runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag);
        if (s.rank == 0) {
          num_blocks_ = (num_blocks_.defined() ?
                         attr->value * num_blocks_ : attr->value);
        } else if (s.rank == 1) {
          PrimExpr cond = iv->var == make_zero(iv->var.dtype());
          is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
        }
      }
    } else {
      CHECK_EQ(num_work_dim_, thread_extents_.size());
    }
    return EvaluateNode::make(
        CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
                   {StringImmNode::make(sync_scope_.to_string()),
                    is_lead_, num_blocks_},
                   CallNode::Intrinsic));
  }
  // data structure.
  StorageScope sync_scope_;
  const std::unordered_set<const Object*>& syncs_;
  // The storage scope of each buffer
  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
  // The read/write statistics of each storage buffer
  std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> rw_stats_;
  // The statistics for global barrier
  bool in_thread_env_{false};
  // stack of enclosing thread_extent attributes
  std::vector<const AttrStmtNode*> thread_extents_;
  size_t num_work_dim_{0};
  PrimExpr num_blocks_;
  PrimExpr is_lead_;
};

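/*!
 * \brief Insert synchronization barriers for accesses to buffers in the
 *        given storage scope (e.g. "shared" or "global").
 *
 * Illustrative use (surrounding variable name assumed):
 *   Stmt synced = ThreadSync(kernel_body, "shared");
 */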
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
  StorageScope sync_scope = StorageScope::make(storage_scope);
  ThreadSyncPlanner planner(sync_scope);
  planner(stmt);
  return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
}

LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
  CHECK_NE(f->func_type, kHostFunc);
  auto n = make_object<LoweredFuncNode>(*f.operator->());
  n->body = ThreadSync(f->body, storage_scope);
  return LoweredFunc(n);
}

}  // namespace tir
}  // namespace tvm