/*!
 *  Copyright (c) 2016 by Contributors
 * \file storage_flatten.cc
 */
// Flattens storage from multi-dimensional arrays to 1D
// buffer access, as in the Halide pipeline.
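//
// For example (an informal sketch, not the actual IR syntax): a realize
// of a 2-D tensor A with shape [n, m] becomes a 1-D allocation A[n * m],
// and an access A(i, j) becomes A[i * m + j] (with a padded stride when
// a buffer_dim_align hint is present).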
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/target_info.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using HalideIR::Internal::Region;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;

class StorageFlattener : public IRMutator {
 public:
  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
                            int cache_line_size) {
    for (auto kv : extern_buffer) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
    }
    cache_line_size_ = cache_line_size;
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
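    // If the buffer variable was remapped by an enclosing
    // buffer_bind_scope (see HandleBufferBindScope below), rewrite the
    // store to use the remapped variable.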
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Store::make(buf_var, op->value, op->index, op->predicate);
    } else {
      return stmt;
    }
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::realize_scope) {
      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::double_buffer_scope) {
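      // Re-key the double_buffer_scope attribute on the flattened
      // buffer's data variable, once for each output of the operation.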
      Operation func(op->node.node_);
      Stmt body = Mutate(op->body);
      for (int i = 0; i < func->num_outputs(); ++i) {
        TensorKey key{func, i};
        auto it = buf_map_.find(key);
        CHECK(it != buf_map_.end())
            << "Cannot find allocated buffer for " << key.f;
        body = AttrStmt::make(
            it->second.buffer->data, op->attr_key, op->value, body);
      }
      return body;
    } else if (op->attr_key == attr::thread_extent) {
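      // Track the enclosing thread scope; it determines the default
      // storage rank for realizations that carry no explicit scope.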
      IterVar iv(op->node.node_);
      ThreadScope ts = ThreadScope::make(iv->thread_tag);
      curr_thread_scope_.push_back(ts);
      Stmt stmt = IRMutator::Mutate_(op, s);
      curr_thread_scope_.pop_back();
      return stmt;
    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    } else if (op->attr_key == attr::buffer_dim_align) {
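      // op->value is a tvm_tuple(dim, align_factor, align_offset).
      // Record it so that the Realize visitor can pad the stride of
      // dimension `dim` to satisfy stride % align_factor == align_offset.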
      Tensor tensor(op->node.node_);
      const Call* tuple = op->value.as<Call>();
      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
      TensorKey key{tensor->op, tensor->value_index};
      auto& vinfo = dim_align_[key];
      int dim = tuple->args[0].as<IntImm>()->value;
      if (static_cast<size_t>(dim) >= vinfo.size()) {
        vinfo.resize(dim + 1);
      }
      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::opengl_stage_scope) {
      is_opengl_ = true;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Provide>();
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;
    CHECK(!e.released)
        << "Write to a buffer that is already out of scope";
    if (is_opengl_) {
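      // OpenGL stages write their output through the glsl_texture_store
      // intrinsic instead of a flattened buffer store.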
      return Evaluate::make(Call::make(
          Type(),
          Call::glsl_texture_store,
          {e.buffer->data, op->value},
          Call::Intrinsic));
    } else {
      return e.buffer.vstore(e.RelIndex(op->args), op->value);
    }
  }

  Stmt Mutate_(const Realize* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
    if (buf_map_.count(key)) {
      CHECK(buf_map_.at(key).external);
      return this->Mutate(op->body);
    } else {
      // create a buffer entry
      BufferEntry e;
      e.bounds = op->bounds;
      Array<Expr> shape;
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      // deduce current storage scope.
      auto it = storage_scope_.find(op->func.get());
      CHECK(it != storage_scope_.end())
          << "Cannot find storage scope of " << op->func
          << " value_index=" << op->value_index;
      StorageScope skey;
      const std::string& strkey = it->second;
      if (strkey.length() == 0) {
        if (curr_thread_scope_.size() != 0) {
          skey.rank = runtime::DefaultStorageRank(
              curr_thread_scope_.back().rank);
        }
      } else {
        skey = StorageScope::make(strkey);
      }

      // use small alignment for small arrays
      int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
      int align = GetTempAllocaAlignment(op->type, const_size);
      if (skey.tag.length() != 0) {
        MemoryInfo info = GetMemoryInfo(skey.to_string());
        if (info.defined()) {
          align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
          CHECK_LE(const_size * op->type.bits(), info->max_num_bits)
              << "Allocation exceeds bound of memory tag " << skey.to_string();
        }
      }
      Array<Expr> strides;
      if (dim_align_.count(key) != 0 && shape.size() != 0) {
        std::vector<Expr> rstrides;
        const std::vector<DimAlignInfo>& avec = dim_align_[key];
        int first_dim = 0;
        Expr stride = make_const(shape[first_dim].type(), 1);
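        // Walk the dimensions from innermost to outermost, padding the
        // running stride so that stride % align_factor == align_offset
        // wherever a buffer_dim_align hint was recorded.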
        for (size_t i = shape.size(); i != 0; --i) {
          size_t dim = i - 1;
          if (dim < avec.size() && avec[dim].align_factor != 0) {
            Expr factor = make_const(stride.type(), avec[dim].align_factor);
            Expr offset = make_const(stride.type(), avec[dim].align_offset);
            stride = stride + (factor + offset - stride % factor) % factor;
            stride = ir::Simplify(stride);
          }
          rstrides.push_back(stride);
          stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
        }
        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
      }

      e.buffer = BufferNode::make(
          Var(key.GetName(), Handle()),
          op->type, shape, strides, Expr(),
          key.GetName(), skey.to_string(),
          align, 0);

      buf_map_[key] = e;
      Stmt body = this->Mutate(op->body);
      buf_map_[key].released = true;
      Stmt ret;

      if (strides.size() != 0) {
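        // With padded strides the flat extent is strides[0] * shape[0]
        // elements of the outermost dimension.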
        int first_dim = 0;
        ret = Allocate::make(
            e.buffer->data, e.buffer->dtype,
            {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
            make_const(Bool(e.buffer->dtype.lanes()), true), body);
      } else {
        shape = e.buffer->shape;
        if (shape.size() == 0) {
          shape.push_back(make_const(Int(32), 1));
        }
        ret = Allocate::make(
            e.buffer->data, e.buffer->dtype, shape,
            make_const(Bool(e.buffer->dtype.lanes()), true), body);
      }
      ret = AttrStmt::make(
          e.buffer->data, attr::storage_scope,
          StringImm::make(e.buffer->scope), ret);
      return ret;
    }
  }

  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Load::make(op->type, buf_var, op->index, op->predicate);
    } else {
      return expr;
    }
  }

  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return e;
    }
  }

  Expr Mutate_(const Call* op, const Expr& olde) final {
    Expr expr = IRMutator::Mutate_(op, olde);
    op = expr.as<Call>();
    if (op != nullptr && op->call_type == Call::Halide) {
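      // A Halide call is a read from a tensor; replace it with a load
      // from the flattened buffer, using indices relative to the
      // realize bounds.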
      TensorKey key{op->func, op->value_index};
      auto it = buf_map_.find(key);
      CHECK(it != buf_map_.end())
          << "Cannot find allocated buffer for " << key.f;
      const BufferEntry& e = it->second;
      CHECK(!e.released)
          << "Read a buffer that is already out of scope";
      return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
    } else {
      return expr;
    }
  }

  Stmt Mutate_(const Prefetch *op, const Stmt &s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Prefetch>();
    CHECK(op != nullptr);
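    // Lower the prefetch region into a serial loop nest that issues one
    // builtin prefetch call per cache line: the innermost dimensions that
    // fit within a single cache line are folded into one call.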
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;

    CHECK(!e.released)
        << "Read a buffer that is already out of scope";
    CHECK_EQ(e.buffer->shape.size(), op->bounds.size())
      << "Prefetch dim should be the same as buffer dim";

    int block_size = 1,
        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
        shape = 0;

    int starts = op->bounds.size() - 1;
    while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
        && elem_cnt >= block_size * shape) {
      block_size *= shape;
      starts--;
    }
    Expr stride(elem_cnt / block_size);

    Array<Expr> args;
    std::vector<VarExpr> vars;

    for (int i = op->bounds.size() - 1; i > starts; --i) {
      args.push_back(op->bounds[i]->min);
    }
    auto &func_name = op->func->func_name();
    vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(starts), Int(32)));
    args.push_back(op->bounds[starts]->min + stride * vars.back());
    for (int i = starts - 1; i >= 0; --i) {
      vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(i), Int(32)));
      args.push_back(vars.back() + op->bounds[i]->min);
    }
    for (int i = starts; i >= 0; --i) {
      if (i < starts) {
        stmt = For::make(
            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
      } else {
        Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
        Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
        Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
        stmt = Evaluate::make(prefetch);
        Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
      }
    }
    return stmt;
  }

 private:
  // The specific tensor data layout is not determined before
  // the StorageFlatten pass. We use buffer_bind_scope to specify,
  // ahead of time, that we want to bind a subregion of a tensor to
  // a symbolic buffer, which gets used by extern ops.
  //
  // Example:
  //
  // realize A in range [i*4, extent=10) {
  //   bind Ab to A in [i*4+1, extent=4) {
  //     call_func(Ab.ptr, Ab.shape[0])
  //   }
  // }
  //
  // After StorageFlatten:
  //
  // alloc A[10]
  //   call(A + 1,  4)
  //
  // Buffer is a protocol for declaring the specific data layout and
  // shape we expect. So this function needs to check:
  // - whether the bind range is within the realize range;
  // - whether we can match the requirements of the buffer;
  // - how to remap variables such as Ab.ptr to the actual values.
  //
  // Here are a few possible failure cases:
  // - The buffer is declared with a constant shape,
  //   but we try to bind it to a different one.
  // - The buffer is declared to be compact (no strides),
  //   but the bound region is a subregion of
  //   a matrix (tensor), which means it requires strides.
  //
  // We do support a few relaxed cases, such as binding a
  // region with shape [1, 1, n, m] to a buffer with shape [n, m].
  Stmt HandleBufferBindScope(const AttrStmt* op) {
    Array<NodeRef> arr(op->node.node_);
    CHECK_EQ(arr.size(), 2U);
    const BufferNode* buffer = arr[0].as<BufferNode>();
    const TensorNode* tensor = arr[1].as<TensorNode>();
    const Call* tuple = op->value.as<Call>();
    CHECK(buffer && tensor);
    CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
    TensorKey key{tensor->op, tensor->value_index};
    CHECK(buf_map_.count(key))
        << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
    const BufferEntry& be = buf_map_.at(key);
    CHECK(!be.released);
    CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
    Array<Expr> begins, extents;
    if (be.bounds.size() != 0) {
      CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
      for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
        begins.push_back(
            arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
        extents.push_back(tuple->args[2 * i + 1]);
      }
    } else {
      for (size_t i = 0; i < tuple->args.size(); i += 2) {
        begins.push_back(tuple->args[i]);
        extents.push_back(tuple->args[i + 1]);
      }
    }
    Buffer slice = be.buffer.MakeSlice(begins, extents);
    if (buffer->strides.size() == 0) {
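      // The target buffer was declared compact, so the bound slice must
      // also be expressible without explicit strides.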
      CHECK_EQ(slice->strides.size(), 0U)
          << "Trying to bind compact buffer to strided one strides="
          << slice->strides;
    } else {
      slice = slice.MakeStrideView();
    }
    // start binding
    ArgBinder binder(&var_remap_);
    binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true);
    // Apply the remaps
    Stmt body = MergeNest(binder.asserts(), op->body);
    body = MergeNest(binder.init_nest(), body);
    body = this->Mutate(body);
    // remove the binds
    for (const Var& v : binder.defs()) {
      var_remap_.erase(v.get());
    }
    return body;
  }
  // Per-dimension alignment requirement recorded from buffer_dim_align
  struct DimAlignInfo {
    int align_factor{0};
    int align_offset{0};
  };
  // The buffer entry in the flatten map
  struct BufferEntry {
    // the buffer of storage
    Buffer buffer;
    // The bounds of the realization; empty means the full region.
    Region bounds;
    // Whether the buffer is external
    bool external{false};
    // Whether we are out of the allocation bounds and the buffer has been released.
    bool released{false};
    // Compute indices relative to the realization bounds.
    inline Array<Expr> RelIndex(Array<Expr> args) const {
      if (bounds.size() != 0) {
        Array<Expr> index;
        CHECK_EQ(bounds.size(), args.size());
        for (size_t i = 0; i < bounds.size(); ++i) {
          index.push_back(args[i] - bounds[i]->min);
        }
        return index;
      } else {
        return args;
      }
    }
  };
  // Variable remap
  std::unordered_map<const Variable*, Expr> var_remap_;
  // Buffer map
  std::unordered_map<TensorKey, BufferEntry> buf_map_;
  // Dimension alignment
  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
  // Storage scope
  std::unordered_map<const Node*, std::string> storage_scope_;
  // The current thread scope.
  std::vector<ThreadScope> curr_thread_scope_;
  // The size of a cache line, in bytes
  int cache_line_size_;
  // Whether the current stage is an OpenGL shader.
  bool is_opengl_{false};
};

Stmt StorageFlatten(Stmt stmt,
                    Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size) {
  stmt = StorageFlattener(extern_buffer, cache_line_size).Mutate(stmt);
  return stmt;
}

}  // namespace ir
}  // namespace tvm