storage_sync.cc 10.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/*!
 *  Copyright (c) 2017 by Contributors
 * \file storage_sync.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include "./ir_util.h"
12
#include "./storage_access.h"
13 14 15 16 17
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

18
class ThreadSyncPlanner : public StorageAccessVisitor {
19
 public:
20 21
  explicit ThreadSyncPlanner(StorageScope sync_scope)
      : sync_scope_(sync_scope) {}
22

23
    // The syncs inserted before each statement
24 25
  std::unordered_set<const Node*> syncs_inserted_;

26 27 28 29
 protected:
  bool Enabled(const Variable* buf,
               const StorageScope& scope) const final {
    return in_device_env() && scope == sync_scope_;
30 31
  }
  // Plan the sync
32 33 34
  std::vector<AccessEntry> Summarize(
      std::vector<StmtEntry> seq, const For* loop) final {
    // Unsynced reads and writes
35 36 37 38
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    // if it is a loop, rotate two times to consider effect of loop.
    // simulation based approach to find dependenceies
39 40
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
41 42 43 44 45 46 47 48 49
      // check if sync before statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
50
          if (FindConflict(writes, acc, false)) {
51 52 53
            sync_before_stmt = true; break;
          }
        } else if (acc.type == kWrite) {
54
          if (FindConflict(reads, acc, false)) {
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
            sync_before_stmt = true; break;
          }
        } else if (acc.type == kSync) {
          reads.clear(); writes.clear();
        }
      }
      // If sync is inserted. remove the irrelevant things.
      if (sync_before_stmt) {
        reads.clear(); writes.clear();
      }
      // Add the read/write of current statement
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear(); writes.clear();
        }
      }
      if (sync_before_stmt) {
76
        CHECK_EQ(condition_counter(), 0)
77 78 79 80
            << "Cannot insert syncs inside condition";
        syncs_inserted_.insert(s.stmt);
      }
    }
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0) break;
        if (reads.empty() && writes.empty()) break;
        bool sync_before_stmt = false;
        for (const AccessEntry& acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true; break;
            }
          } else if (acc.type == kWrite) {
            if (FindConflict(reads, acc, true)) {
              sync_before_stmt = true; break;
            }
          } else if (acc.type == kSync) {
            reads.clear(); writes.clear();
          }
        }
        if (sync_before_stmt) {
          CHECK_EQ(condition_counter(), 0)
              << "Cannot insert syncs inside condition";
          syncs_inserted_.insert(s.stmt);
          break;
        }
      }
    }
108 109 110 111
    // return the exposed entries, remove unecessary ones.
    int sync_count = 0;
    // head are before first sync, tail are after last sync
    std::vector<AccessEntry> head, tail;
112 113 114 115 116
    AccessEntry esync;
    esync.threads = this->env_threads();
    esync.type = kSync;
    esync.scope = sync_scope_;

117 118 119 120 121
    for (const StmtEntry& s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
122
          head.push_back(esync);
123 124 125 126 127 128 129 130
        }
        ++sync_count;
      }
      for (const AccessEntry& acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
131
            head.push_back(esync);
132 133 134 135 136 137 138 139 140 141 142 143
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
144 145 146 147 148 149
    if (loop != nullptr) {
      // clear double buffer flag after a loop is finished.
      for (AccessEntry& e : head) {
        e.double_buffer_write = false;
      }
    }
150 151
    return head;
  }
152 153

 private:
154
  // find conflicting entry in vec.
155
  bool FindConflict(const std::vector<AccessEntry>& vec,
156 157
                    const AccessEntry& e,
                    bool loop_carry) {
158
    for (const AccessEntry& x : vec) {
159
      if (x.buffer.same_as(e.buffer)) {
160 161 162 163 164 165 166 167
        // Assumes no race between threads
        // Same index value means no conflicts
        // TODO(tqchen) more standard set based testing.
        if (e.touched.is_single_point() &&
            x.touched.is_single_point()) {
          if (Equal(e.touched.point_value(),
                    x.touched.point_value())) continue;
        }
168 169 170
        if (x.double_buffer_write &&
            e.type == kRead &&
            !loop_carry) continue;
171 172
        return true;
      }
173 174 175
    }
    return false;
  }
176 177 178

 private:
  // synchronization scope
179 180 181
  StorageScope sync_scope_;
};

182
class ThreadSyncInserter : public IRMutator {
183
 public:
184 185
  ThreadSyncInserter(StorageScope sync_scope,
                     const std::unordered_set<const Node*>& syncs)
186 187 188
      : sync_scope_(sync_scope), syncs_(syncs) {}

  Stmt Mutate(Stmt stmt) final {
189
    if (syncs_.size() == 0) return stmt;
190
    if (syncs_.count(stmt.get())) {
191
      Stmt barrier;
192
      if (sync_scope_.rank == StorageRank::kGlobal) {
193 194 195 196 197 198 199
        barrier = MakeGlobalBarrier();
      } else {
        barrier = Evaluate::make(
                Call::make(Int(32), intrinsic::tvm_storage_sync,
                           {StringImm::make(sync_scope_.to_string())},
                           Call::Intrinsic));
      }
200 201
      // Mutate after query, to avoid stmt change.
      stmt = IRMutator::Mutate(stmt);
202
      stmt = Block::make(barrier, stmt);
203 204
    } else {
      stmt = IRMutator::Mutate(stmt);
205 206 207
    }
    return stmt;
  }
208
  Expr Mutate_(const Load* op, const Expr& e) final {
209 210
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
211 212 213 214 215
      ++rw_stats_[op->buffer_var].read_count;
    }
    return IRMutator::Mutate_(op, e);
  }
  Stmt Mutate_(const Store* op, const Stmt& s) final {
216 217
    if (sync_scope_.rank == StorageRank::kGlobal &&
        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
218 219 220 221 222 223 224 225 226 227 228 229 230
      ++rw_stats_[op->buffer_var].write_count;
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent) {
      bool temp = true;
      std::swap(temp, in_thread_env_);
      thread_extents_.push_back(op);
      Stmt ret = IRMutator::Mutate_(op, s);
      thread_extents_.pop_back();
      std::swap(temp, in_thread_env_);
      // first thread scope.
231
      if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
232 233 234 235 236 237 238 239 240 241 242 243 244 245
        ret = InitGlobalBarrier(ret.as<AttrStmt>());
        num_blocks_ = Expr();
        is_lead_ = Expr();
      }
      return ret;
    } else if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      storage_scope_[buf] =
          StorageScope::make(op->value.as<StringImm>()->value);
      return IRMutator::Mutate_(op, s);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
246

247 248 249 250 251 252 253 254 255
 private:
  // RW statistics about data
  struct Entry {
    int read_count{0};
    int write_count{0};
  };
  // Get current storage scope.
  StorageScope GetScope(const Variable* buf) const {
    auto it = storage_scope_.find(buf);
256 257
    StorageScope s;
    s.rank = StorageRank::kGlobal;
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
    if (it == storage_scope_.end()) return s;
    return it->second;
  }
  // private functions.
  Stmt InitGlobalBarrier(const AttrStmt* op) {
    CHECK(op != nullptr);
    Array<Expr> pargs = {StringImm::make(runtime::symbol::tvm_prepare_global_barrier)};
    Stmt prep = Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_call_packed, pargs, Call::Intrinsic));
    Stmt body = op->body;
    for (const auto& kv : rw_stats_) {
      const auto& e = kv.second;
      if (e.read_count != 0 && e.write_count != 0) {
        body = AttrStmt::make(kv.first, attr::volatile_scope, 1, body);
      }
    }
    rw_stats_.clear();
    Stmt kinit = Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic));
    body = Block::make(kinit, body);
    body = AttrStmt::make(
        op->node, op->attr_key, op->value, body);
    return Block::make(prep, body);
  }
  Stmt MakeGlobalBarrier() {
283
    CHECK(sync_scope_.rank == StorageRank::kGlobal);
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
    if (!num_blocks_.defined()) {
      CHECK(!is_lead_.defined());
      num_work_dim_ = thread_extents_.size();
      for (const AttrStmt* attr : thread_extents_) {
        IterVar iv(attr->node.node_);
        runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag);
        if (s.rank == 0) {
          num_blocks_ = (num_blocks_.defined() ?
                         attr->value * num_blocks_ : attr->value);
        } else if (s.rank == 1) {
          Expr cond = iv->var == make_zero(iv->var.type());
          is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
        }
      }
    } else {
      CHECK_EQ(num_work_dim_, thread_extents_.size());
    }
    return Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_storage_sync,
                   {StringImm::make(sync_scope_.to_string()),
                    is_lead_, num_blocks_},
                   Call::Intrinsic));
  }
  // data structure.
308
  StorageScope sync_scope_;
309 310 311 312 313 314 315 316 317 318 319 320
  const std::unordered_set<const Node*>& syncs_;
  // The storage scope of each buffer
  std::unordered_map<const Variable*, StorageScope> storage_scope_;
  // The read write statistics of storage
  std::unordered_map<VarExpr, Entry, NodeHash, NodeEqual> rw_stats_;
  // The statistics for global barrier
  bool in_thread_env_{false};
  // memorized results
  std::vector<const AttrStmt*> thread_extents_;
  size_t num_work_dim_{0};
  Expr num_blocks_;
  Expr is_lead_;
321 322
};

323
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
  const StorageScope scope = StorageScope::make(storage_scope);
  // Phase 1: decide which statements need a barrier in front of them.
  ThreadSyncPlanner plan(scope);
  plan.Visit(stmt);
  // Phase 2: materialize those barriers. The inserter borrows the planned
  // set, which stays alive because `plan` outlives the mutation below.
  ThreadSyncInserter inserter(scope, plan.syncs_inserted_);
  return inserter.Mutate(stmt);
}

330
LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
  // Only device-side functions are handled here.
  CHECK_NE(f->func_type, kHostFunc);
  // Build a new node copied from f, with only the body replaced; f itself
  // is left unmodified.
  std::shared_ptr<LoweredFuncNode> node =
      std::make_shared<LoweredFuncNode>(*f.operator->());
  node->body = ThreadSync(f->body, storage_scope);
  return LoweredFunc(node);
}

}  // namespace ir
}  // namespace tvm