/*!
 *  Copyright (c) 2016 by Contributors
 * \file storage_flatten.cc
 */
// Flattens storage from multi-dimensional array to 1D
// buffer access as in Halide pipeline.
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/target_info.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "arg_binder.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using HalideIR::Internal::Region;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;

class StorageFlattener : public IRMutator {
 public:
  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
                            int cache_line_size, bool create_bound_attributes)
      : create_bound_attributes_(create_bound_attributes) {
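    // Register externally provided buffers; they are bound here but never
    // allocated by this pass (see the external flag checked in Realize).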
    for (auto kv : extern_buffer) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
    }
    cache_line_size_ = cache_line_size;
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
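    // If buffer_bind_scope remapped this buffer variable, rewrite the store
    // to target the remapped variable.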
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Store::make(buf_var, op->value, op->index, op->predicate);
    } else {
      return stmt;
    }
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::realize_scope) {
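      // Remember the storage scope declared for this operation; it is
      // consumed when the corresponding Realize is flattened below.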
      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::double_buffer_scope &&
               op->node.node_->derived_from<OperationNode>()) {
      Operation func(op->node.node_);
      Stmt body = Mutate(op->body);
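      // Re-attach the double_buffer_scope attribute to the flattened
      // buffer variable of each output of the operation.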
      for (int i = 0; i < func->num_outputs(); ++i) {
        TensorKey key{func, i};
        auto it = buf_map_.find(key);
        CHECK(it != buf_map_.end())
            << "Cannot find allocated buffer for " << key.f;
        body = AttrStmt::make(
            it->second.buffer->data, op->attr_key, op->value, body);
      }
      return body;
    } else if (op->attr_key == attr::thread_extent) {
      IterVar iv(op->node.node_);
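      // Track the current thread scope; it determines the default storage
      // rank of allocations realized inside this extent.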
      ThreadScope ts = ThreadScope::make(iv->thread_tag);
      curr_thread_scope_.push_back(ts);
      Stmt stmt = IRMutator::Mutate_(op, s);
      curr_thread_scope_.pop_back();
      return stmt;
    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    } else if (op->attr_key == attr::buffer_dim_align) {
      Tensor tensor(op->node.node_);
      const Call* tuple = op->value.as<Call>();
      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
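      // The value is tvm_tuple(dim, align_factor, align_offset); record it
      // so Realize can pad the strides of that dimension.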
      TensorKey key{tensor->op, tensor->value_index};
      auto& vinfo = dim_align_[key];
      int dim = tuple->args[0].as<IntImm>()->value;
      if (static_cast<size_t>(dim) >= vinfo.size()) {
        vinfo.resize(dim + 1);
      }
      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::opengl_stage_scope) {
      is_opengl_ = true;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    if (create_bound_attributes_)
      shape_collector_.clear();
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Provide>();
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;
    CHECK(!e.released)
        << "Write to a buffer that is already out of scope";
    if (is_opengl_) {
      return Evaluate::make(Call::make(
          Type(),
          Call::glsl_texture_store,
          {e.buffer->data, op->value},
          Call::Intrinsic));
    } else {
      Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        shape_collector_.push_back(
            std::make_pair(e.buffer->data, e.buffer->shape));
      }
      // To create bound attributes, the collector should have at least one item.
      if (create_bound_attributes_ && shape_collector_.size()) {
        for (size_t i = 0; i < shape_collector_.size(); ++i) {
          body = AttrStmt::make(
              shape_collector_[i].first, ir::attr::buffer_bound,
              MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
        }
      }
      return body;
    }
  }

  Stmt Mutate_(const Realize* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
    if (buf_map_.count(key)) {
      CHECK(buf_map_.at(key).external);
      return this->Mutate(op->body);
    } else {
      // create a buffer entry
      BufferEntry e;
      e.bounds = op->bounds;
      Array<Expr> shape;
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      // deduce current storage scope.
      auto it = storage_scope_.find(op->func.get());
      CHECK(it != storage_scope_.end())
          << "Cannot find storage scope of " << op->func
          << " value_index=" << op->value_index;
      StorageScope skey;
      const std::string& strkey = it->second;
      if (strkey.length() == 0) {
        if (curr_thread_scope_.size() != 0) {
          skey.rank = runtime::DefaultStorageRank(
              curr_thread_scope_.back().rank);
        }
      } else {
        skey = StorageScope::make(strkey);
      }

      // use small alignment for small arrays
      int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
      int align = GetTempAllocaAlignment(op->type, const_size);
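      // For tagged (special) memory scopes, derive alignment from the
      // device's SIMD width and check that the allocation fits.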
      if (skey.tag.length() != 0) {
        MemoryInfo info = GetMemoryInfo(skey.to_string());
        if (info.defined()) {
          align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
          CHECK_LE(const_size * op->type.bits(), info->max_num_bits)
              << "Allocation exceeds bound of memory tag " << skey.to_string();
        }
      }
      Array<Expr> strides;
      if (dim_align_.count(key) != 0 && shape.size() != 0) {
        std::vector<Expr> rstrides;
        const std::vector<DimAlignInfo>& avec = dim_align_[key];
        int first_dim = 0;
        Expr stride = make_const(shape[first_dim].type(), 1);
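        // Build strides from the innermost dimension outwards, padding each
        // aligned dimension so that stride % align_factor == align_offset;
        // rstrides is reversed at the end to get row-major order.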
        for (size_t i = shape.size(); i != 0; --i) {
          size_t dim = i - 1;
          if (dim < avec.size() && avec[dim].align_factor != 0) {
            Expr factor = make_const(stride.type(), avec[dim].align_factor);
            Expr offset = make_const(stride.type(), avec[dim].align_offset);
            stride = stride + (factor + offset - stride % factor) % factor;
            stride = ir::Simplify(stride);
          }
          rstrides.push_back(stride);
          stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
        }
        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
      }

      e.buffer = BufferNode::make(
          Var(key.GetName(), Handle()),
          op->type, shape, strides, Expr(),
          key.GetName(), skey.to_string(),
          align, 0);

      buf_map_[key] = e;
      Stmt body = this->Mutate(op->body);
      buf_map_[key].released = true;
      Stmt ret;

      Type storage_type = e.buffer->dtype;
      // specially handle bool, lower its storage
      // type to be Int(8)(byte)
      if (storage_type == Bool()) {
        storage_type = Int(8);
      }
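      // If alignment introduced strides, allocate a flat region of
      // stride[0] * shape[0] elements; otherwise allocate the plain shape
      // (padded to 1-D for scalar tensors).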
      if (strides.size() != 0) {
        int first_dim = 0;
        ret = Allocate::make(
            e.buffer->data, storage_type,
            {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
            make_const(Bool(e.buffer->dtype.lanes()), true), body);
      } else {
        shape = e.buffer->shape;
        if (shape.size() == 0) {
          shape.push_back(make_const(Int(32), 1));
        }
        ret = Allocate::make(
            e.buffer->data, storage_type, shape,
            make_const(Bool(e.buffer->dtype.lanes()), true), body);
      }
      ret = AttrStmt::make(
          e.buffer->data, attr::storage_scope,
          StringImm::make(e.buffer->scope), ret);

      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
                             MakeBound(e.buffer->dtype, e.buffer->shape), ret);
      }
      return ret;
    }
  }

  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Load::make(op->type, buf_var, op->index, op->predicate);
    } else {
      return expr;
    }
  }

  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return e;
    }
  }

  Expr Mutate_(const Call* op, const Expr& olde) final {
    Expr expr = IRMutator::Mutate_(op, olde);
    op = expr.as<Call>();
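    // Halide-style calls are reads of tensor elements; turn them into
    // flat buffer loads relative to the realize bounds.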
    if (op != nullptr && op->call_type == Call::Halide) {
      TensorKey key{op->func, op->value_index};
      auto it = buf_map_.find(key);
      CHECK(it != buf_map_.end())
          << "Cannot find allocated buffer for " << key.f;
      const BufferEntry& e = it->second;
      CHECK(!e.released)
          << "Read a buffer that is already out of scope";

      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        shape_collector_.push_back(
            std::make_pair(e.buffer->data, e.buffer->shape));
      }
      return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
    } else {
      return expr;
    }
  }

  Stmt Mutate_(const Prefetch *op, const Stmt &s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Prefetch>();
    CHECK(op != nullptr);
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;

    CHECK(!e.released)
        << "Read a buffer that is already out of scope";
    CHECK_EQ(e.buffer->shape.size(), op->bounds.size())
      << "Prefetch dim should be the same as buffer dim";

    int block_size = 1,
        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
        shape = 0;

    int starts = op->bounds.size() - 1;
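    // Fuse the trailing dimensions whose total extent fits within one cache
    // line; the innermost remaining loop then prefetches one cache line
    // (stride elements) per iteration.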
    while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
        && elem_cnt >= block_size * shape) {
      block_size *= shape;
      starts--;
    }
    Expr stride(elem_cnt / block_size);

    Array<Expr> args;
    std::vector<VarExpr> vars;

    for (int i = op->bounds.size() - 1; i > starts; --i) {
      args.push_back(op->bounds[i]->min);
    }
    auto &func_name = op->func->func_name();
    vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(starts), Int(32)));
    args.push_back(op->bounds[starts]->min + stride * vars.back());
    for (int i = starts - 1; i >= 0; --i) {
      vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(i), Int(32)));
      args.push_back(vars.back() + op->bounds[i]->min);
    }
    for (int i = starts; i >= 0; --i) {
      if (i < starts) {
        stmt = For::make(
            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
      } else {
        Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
        Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
        Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
        stmt = Evaluate::make(prefetch);
        Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
      }
    }
    return stmt;
  }

 private:
  // The specific tensor data layout is not determined before
  // the StorageFlatten pass. We use buffer_bind_scope
  // to specify beforehand that we want to bind a subregion
  // of a tensor to a symbolic buffer, which gets used in extern calls.
  //
  // Example:
  //
  // realize A in range [i*4, extent=10) {
  //   bind Ab to A in [i*4+1, extent=4) {
  //     call_func(Ab.ptr, Ab.shape[0])
  //   }
  // }
  //
  // After StorageFlatten
  //
  // alloc A[10]
  //   call(A + 1,  4)
  //
  // Buffer is a protocol to declare the specific
  // data layout and shape we expect.
  // So this function needs to check:
  // - If the bind range is within the realize range
  // - If we can match the requirements of the buffer
  // - Remap variables such as Ab.ptr to the actual value.
  //
  // Here are a few possible failure cases:
  // - Buffer is declared to have a constant shape,
  //   but we try to bind it to a different one.
  // - Buffer is declared to be compact (no strides),
  //   but the bound region is a subregion of
  //   a matrix (tensor), which means it requires strides.
  //
  // We do support a few relaxed cases, such as binding a
  // region with shape [1, 1, n, m] to a buffer with shape [n, m].
  Stmt HandleBufferBindScope(const AttrStmt* op) {
    Array<NodeRef> arr(op->node.node_);
    CHECK_EQ(arr.size(), 2U);
    const BufferNode* buffer = arr[0].as<BufferNode>();
    const TensorNode* tensor = arr[1].as<TensorNode>();
    const Call* tuple = op->value.as<Call>();
    CHECK(buffer && tensor);
    CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
    TensorKey key{tensor->op, tensor->value_index};
    CHECK(buf_map_.count(key))
        << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
    const BufferEntry& be = buf_map_.at(key);
    CHECK(!be.released);
    CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
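    // tvm_tuple packs (min, extent) pairs for each dimension; translate them
    // into begins/extents relative to the realize bounds when those exist.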
    Array<Expr> begins, extents;
    if (be.bounds.size() != 0) {
      CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
      for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
        begins.push_back(
            arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
        extents.push_back(tuple->args[2 * i + 1]);
      }
    } else {
      for (size_t i = 0; i < tuple->args.size(); i += 2) {
        begins.push_back(tuple->args[i]);
        extents.push_back(tuple->args[i + 1]);
      }
    }
    Buffer slice = be.buffer.MakeSlice(begins, extents);
    if (buffer->strides.size() == 0) {
      CHECK_EQ(slice->strides.size(), 0U)
          << "Trying to bind compact buffer to strided one strides="
          << slice->strides;
    } else {
      slice = slice.MakeStrideView();
    }
    // start binding
    ArgBinder binder(&var_remap_);
    binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true);
    // Apply the remaps
    Stmt body = MergeNest(binder.asserts(), op->body);
    body = MergeNest(binder.init_nest(), body);
    body = this->Mutate(body);
    // remove the binds
    for (const Var& v : binder.defs()) {
      var_remap_.erase(v.get());
    }
    return body;
  }
  // Per-dimension alignment information set by buffer_dim_align.
  struct DimAlignInfo {
    int align_factor{0};
    int align_offset{0};
  };
  // The buffer entry in the flatten map
  struct BufferEntry {
    // the buffer of storage
    Buffer buffer;
    // The bounds of realization; can be empty, which means everything.
    Region bounds;
    // Whether the buffer is external
    bool external{false};
    // Whether we are out of allocation bounds and the buffer has been released.
    bool released{false};
    // relative index
    inline Array<Expr> RelIndex(Array<Expr> args) const {
      if (bounds.size() != 0) {
        Array<Expr> index;
        CHECK_EQ(bounds.size(), args.size());
        for (size_t i = 0; i < bounds.size(); ++i) {
          index.push_back(args[i] - bounds[i]->min);
        }
        return index;
      } else {
        return args;
      }
    }
  };

  bool ShapeIsValid(const Array<Expr> &shape) {
    // Zero-dimensional tensor does not need boundary check.
    if (!shape.size())
      return false;

    for (size_t i = 0; i < shape.size(); ++i) {
      if (!shape[i].defined() || !shape[i].type().is_scalar() ||
          is_negative_const(shape[i])) {
        return false;
      }
    }
    return true;
  }

  Expr MakeBound(const Type &type, const Array<Expr> &shape) {
    // We have already checked the shape size to be greater than 0.
    Expr bound = Mul::make(make_const(shape[0].type(), type.lanes()), shape[0]);
    for (size_t i = 1; i < shape.size(); ++i) {
      bound = Mul::make(
          bound, Mul::make(make_const(bound.type(), type.lanes()), shape[i]));
    }
    return bound;
  }

  // Variable remap, populated when binding buffers.
  std::unordered_map<const Variable*, Expr> var_remap_;
  // Buffer map: tensor -> flattened buffer entry.
  std::unordered_map<TensorKey, BufferEntry> buf_map_;
  // Dimension alignment
  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
  // Storage scope
  std::unordered_map<const Node*, std::string> storage_scope_;
  // The current thread scope.
  std::vector<ThreadScope> curr_thread_scope_;
  // Collects shapes.
  std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
  // The size of a cache line, in bytes.
  int cache_line_size_;
  // The current stage is an OpenGL shader.
  bool is_opengl_{false};
  // Whether to mark load/store with their bounds.
  bool create_bound_attributes_{false};
};

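// Entry point of the pass: flattens tensor Realize/Provide/Call nodes in
// `stmt` into 1D buffer allocations and accesses. `extern_buffer` supplies
// buffers for externally bound tensors, `cache_line_size` drives Prefetch
// lowering, and `create_bound_attributes` controls whether buffer_bound
// attributes are attached for later bound checking.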
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size, bool create_bound_attributes) {
  stmt =
      StorageFlattener(extern_buffer, cache_line_size, create_bound_attributes)
          .Mutate(stmt);
  return stmt;
}

}  // namespace ir
}  // namespace tvm