/*!
 *  Copyright (c) 2017 by Contributors
 *  Lower allreduce to device-implementable IR.
 * \file lower_thread_allreduce.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

class ThreadAllreduceBuilder final : public IRMutator {
 public:
  explicit ThreadAllreduceBuilder(int warp_size)
      : warp_size_(warp_size) {}

  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
    if (op->attr_key == attr::thread_extent) {
      thread_extents_.push_back(op);
      Stmt ret = IRMutator::Mutate_(op, s);
      thread_extents_.pop_back();
      return ret;
    } else if (op->attr_key == attr::storage_scope) {
      Stmt ret = IRMutator::Mutate_(op, s);
      op = ret.as<AttrStmt>();
      const Variable* v = op->node.as<Variable>();
      if (alloc_remap_.count(v)) {
        return op->body;
      } else {
        return ret;
      }
    } else if (op->attr_key == attr::reduce_scope) {
      const CommReducerNode *combiner = op->node.as<CommReducerNode>();
      CHECK(combiner);
      reduce_combiner_.push_back(combiner);
      Stmt ret = IRMutator::Mutate_(op, s);
      reduce_combiner_.pop_back();
      return ret;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
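  // Intercept evaluations of the tvm_thread_allreduce intrinsic and
  // expand them into an explicit shared-memory reduction.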
  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Evaluate>();
    const Call* call = op->value.as<Call>();
    if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
      return MakeAllreduce(call);
    } else {
      return stmt;
    }
  }
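  // Replace the original allocation of a reduction result buffer with the
  // shared-memory buffer recorded in alloc_remap_, and mark accesses to it
  // as volatile.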
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      const Allocate* repl = it->second.as<Allocate>();
      // use volatile access to shared buffer.
      stmt = AttrStmt::make(
          repl->buffer_var, attr::volatile_scope, 1, op->body);
      stmt = Allocate::make(
          repl->buffer_var, repl->type,
          repl->extents, repl->condition, stmt);
      stmt = AttrStmt::make(
          repl->buffer_var, attr::storage_scope,
          StringImm::make("shared"), stmt);
      return stmt;
    } else {
      return stmt;
    }
  }
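  // Loads of a rewritten reduction result are redirected to the
  // corresponding shared-buffer element recorded in load_remap_.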
  Expr Mutate_(const Load* op, const Expr& e) final {
    auto it = load_remap_.find(op->buffer_var.get());
    if (it != load_remap_.end()) {
      CHECK(is_zero(op->index));
      return it->second;
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }

 private:
  // Thread entry
  struct ThreadEntry {
    runtime::ThreadScope scope;
    IterVar iv;
    int extent;
    // comparator
    bool operator<(const ThreadEntry& other) const {
      return scope.dim_index < other.scope.dim_index;
    }
  };
  // make allreduce.
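  // The tvm_thread_allreduce call is laid out as
  //   args[0]                        : number of reduction values (size)
  //   args[1 .. size]                : the values to be reduced
  //   args[size + 1]                 : the reduction predicate
  //   args[size + 2 .. 2 * size + 1] : the result buffer variables
  //   args[2 * size + 2 ..]          : the reduction thread itervars
  // which matches how the arguments are parsed below.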
  Stmt MakeAllreduce(const Call* call) {
    CHECK(!reduce_combiner_.empty());
    const CommReducerNode *combiner = reduce_combiner_.back();
    size_t size = combiner->result.size();

    const UIntImm *size_of_args = call->args[0].as<UIntImm>();
    CHECK(size_of_args) << call->args[0]->type_key();
    CHECK_EQ(size, size_of_args->value);
    Array<Expr> inits = combiner->identity_element;
    std::vector<Expr> values(size);
    std::vector<Type> types(size);
    Expr cond = call->args[size+1];
    for (size_t idx = 0; idx < size; ++idx) {
      values[idx] = call->args[1+idx];
      if (!is_one(cond)) {
        values[idx] = Select::make(cond, values[idx], inits[idx]);
      }
      types[idx] = values[idx].type();
    }
    std::vector<const Variable*> buffers(size);
    for (size_t idx = 0; idx < size; ++idx) {
      const Variable* buffer = call->args[2+size+idx].as<Variable>();
      CHECK(buffer);
      buffers[idx] = buffer;
    }

    std::unordered_set<const Variable*> reduce_set;
    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
      const Variable* v = call->args[i].as<Variable>();
      CHECK(v);
      reduce_set.insert(v);
    }
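    // Classify the surrounding thread itervars: threads listed in the call
    // form the reduction set (vred); the remaining threads group independent
    // reductions (vpar).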
    size_t nmatch = 0;
    std::vector<ThreadEntry> vred, vpar;
    for (const AttrStmt* attr : thread_extents_) {
      ThreadEntry e;
      IterVar iv(attr->node.node_);
      e.scope = runtime::ThreadScope::make(iv->thread_tag);
      e.iv = iv;
      CHECK_LE(e.scope.rank, 1);
      CHECK_GE(e.scope.dim_index, 0)
          << "vthread do not work with cross thread reduction";
      if (e.scope.rank == 1) {
        CHECK(arith::GetConstInt(attr->value, &(e.extent)))
            << "Need constant extent for reduce set " << iv;
        if (reduce_set.count(iv->var.get())) {
          vred.push_back(e);
          ++nmatch;
        } else {
          vpar.push_back(e);
        }
      }
    }
    CHECK_EQ(nmatch, reduce_set.size())
        << "Not all reduce index are presented in the context";
    std::sort(vred.begin(), vred.end());
    std::sort(vpar.begin(), vpar.end());
    // The extents of the flattened reduce and group indices.
    int reduce_extent, group_extent;
    int threadx_extent = 1;
    Expr reduce_index = FlattenThread(vred, &reduce_extent);
    Expr group_index = FlattenThread(vpar, &group_extent);
    if (reduce_extent == 1) {
      // special case, no reduction is needed.
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        Expr pred = const_true(types[i].lanes());
        Var buffer_var(call->args[2+size+i].node_);
        stores[i] = Store::make(buffer_var, values[i], 0, pred);
      }
      return Block::make(stores);
    }
    // Check whether threadIdx.x is involved in the reduction.
    if (vred[0].scope.dim_index == 0) {
      threadx_extent = vred[0].extent;
    }
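    // General case: store every value into its slot of a shared buffer,
    // synchronize, run the tree reduction in shared memory, and remap later
    // loads/allocations of the result buffers to the shared buffers.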
    std::vector<Stmt> seq;
    std::vector<Var> shared_bufs(size);
    for (size_t idx = 0; idx < size; ++idx) {
      shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
      Expr pred = const_true(types[idx].lanes());
      seq.emplace_back(Store::make(
          shared_bufs[idx], values[idx],
          BufIndex(reduce_index, group_index, reduce_extent), pred));
    }
    seq.emplace_back(SyncThread("shared"));
    seq.emplace_back(MakeBufAllreduce(
        combiner, types, shared_bufs,
        reduce_index, group_index, reduce_extent, threadx_extent));
    for (size_t idx = 0; idx < size; ++idx) {
      CHECK(!load_remap_.count(buffers[idx]));
      Expr pred = const_true(types[idx].lanes());
      load_remap_[buffers[idx]] = Load::make(
        types[idx], shared_bufs[idx],
        BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
      alloc_remap_[buffers[idx]] = Allocate::make(
        shared_bufs[idx], types[idx],
        {Expr(group_extent), Expr(reduce_extent)},
        pred, Evaluate::make(0));
    }
    return MergeSeq(seq);
  }
  // Make the tree allreduce inside the shared buffer.
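  // The buffer holds group_extent * reduce_extent elements. The reduction is
  // a binary tree: round reduce_extent up to a power of two (reduce_align),
  // fold the tail threads once if the extent is not a power of two, then
  // repeatedly halve reduce_align, synchronizing between steps, until the
  // active threads fit inside a single warp.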
  Stmt MakeBufAllreduce(const CommReducerNode *combiner,
                        const std::vector<Type>& types,
                        const Array<Var>& shared_bufs,
                        Expr reduce_index,
                        Expr group_index,
                        int reduce_extent,
                        int threadx_extent) {
    // Get next power of two
    int reduce_align = 1;
    while (reduce_extent > reduce_align) {
      reduce_align = reduce_align << 1;
    }
    CHECK_GT(reduce_align, 1);
    std::vector<Stmt> seq;

    size_t size = shared_bufs.size();
    Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
    // make reduction
    auto freduce = [&](int offset) {
      Array<Expr> a, b;
      for (size_t i = 0; i < size; ++i) {
        b.push_back(Load::make(types[i], shared_bufs[i],
          BufIndex(reduce_index + offset, group_index, reduce_extent),
          const_true()));
        a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
      }
      Array<Expr> ret = (*combiner)(a, b);
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
      }
      return Block::make(stores);
    };
    // Step one: reduction with the boundary condition.
    // When reduce_extent is not a power of two, fold the tail
    // threads into the aligned lower half first.
    if (reduce_align > reduce_extent) {
      reduce_align = reduce_align >> 1;
      Expr cond = reduce_index < (reduce_extent - reduce_align);
      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }
    CHECK(threadx_extent >= 1 && warp_size_ >= 1);
    // Reduction steps that still need a full shared-memory barrier:
    // keep halving until the active range fits within one warp
    // (and within the threadIdx.x extent).
    while (reduce_align > threadx_extent ||
           reduce_align > warp_size_) {
      reduce_align = reduce_align >> 1;
      Expr cond = reduce_index < reduce_align;
      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread("shared"));
    }
    // In-warp reduction: the remaining steps are executed by threads
    // within a single warp.
    std::vector<Stmt> in_warp_seq;
    Expr in_warp_cond = reduce_index < (reduce_align >> 1);
    while (reduce_align > 1) {
      reduce_align = reduce_align >> 1;
      in_warp_seq.emplace_back(freduce(reduce_align));
      seq.emplace_back(SyncThread("warp"));
    }
    if (in_warp_seq.size() != 0) {
      Stmt warp_body = MergeSeq(in_warp_seq);
      seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
      seq.emplace_back(SyncThread("shared"));
    }
    return MergeSeq(seq);
  }
  // Flatten the thread indices into one linear index.
  // Also return the total extent of the flattened threads.
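  // For example, threadIdx.x (extent Ex) and threadIdx.y (extent Ey) flatten
  // to x + y * Ex with a total extent of Ex * Ey.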
  Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
                     int* out_total_extent) {
    int& total_extent = *out_total_extent;
    total_extent = 1;
    if (tvec.size() == 0) {
      return make_zero(Int(32));
    }

    Expr ret;
    for (const ThreadEntry& e : tvec) {
      if (ret.defined()) {
        ret = ret + e.iv->var * total_extent;
      } else {
        CHECK_EQ(total_extent, 1);
        ret = e.iv->var;
      }
      total_extent *= e.extent;
    }
    return ret;
  }
  // sync thread op.
  static Stmt SyncThread(const std::string& sync) {
    return Evaluate::make(
        Call::make(Int(32), intrinsic::tvm_storage_sync,
                   {StringImm::make(sync)},
                   Call::Intrinsic));
  }
  // The local buffer index.
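  // Each group owns a contiguous slab of reduce_extent elements, so the slot
  // for a thread is group_index * reduce_extent + reduce_index.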
  static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
    if (!is_zero(group_index)) {
      return ir::Simplify(group_index * reduce_extent + reduce_index);
    } else {
      return reduce_index;
    }
  }
  // The warp size of the device.
  int warp_size_{1};

  // surrounding scope of thread extent.
  std::vector<const AttrStmt*> thread_extents_;
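  // surrounding scope of reduce combiners.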
  std::vector<const CommReducerNode*> reduce_combiner_;
  // The load remap
  std::unordered_map<const Variable *, Expr> load_remap_;
  // Allocate remap
  std::unordered_map<const Variable *, Stmt> alloc_remap_;
};

LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
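  // Allreduce lowering only applies to device code, so reject host functions
  // and rewrite the body on a copy of the function node.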
  CHECK_NE(f->func_type, kHostFunc);
  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
  n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
  return LoweredFunc(n);
}
}  // namespace ir
}  // namespace tvm