/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_flatten.cc
 */
// Flattens storage from multi-dimensional array to 1D
// buffer access as in Halide pipeline.
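//
// For example (an illustrative sketch, not the literal output of this pass):
//
// realize A in range [0, extent=n), [0, extent=m) {
//   ... A(i, j) ...
// }
//
// becomes a flat allocation accessed through a linearized index:
//
// alloc A[n * m]
//   ... A[i * m + j] ...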
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/buffer.h>
#include <tvm/target/target_info.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "arg_binder.h"
#include "../../arith/compute_expr.h"
#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;

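// Mutator that performs the flattening: Realize regions become flat
// Allocate nodes, and Provide stores / Halide calls on tensors become
// Buffer vstore/vload accesses with linearized indices.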
class StorageFlattener : public StmtExprMutator {
 public:
  explicit StorageFlattener(Map<te::Tensor, Buffer> extern_buffer,
                            int cache_line_size, bool create_bound_attributes,
                            IRVisitorWithAnalyzer* bounded_analyzer)
      : bounded_analyzer_(bounded_analyzer),
        create_bound_attributes_(create_bound_attributes) {
    for (auto kv : extern_buffer) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
    }
    cache_line_size_ = cache_line_size;
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<StoreNode>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<VarNode>());
      Var buf_var = Downcast<Var>(it->second);
      return StoreNode::make(buf_var, op->value, op->index, op->predicate);
    } else {
      return stmt;
    }
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::realize_scope) {
      storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::double_buffer_scope &&
               op->node->IsInstance<te::OperationNode>()) {
      auto func = Downcast<te::Operation>(op->node);
      Stmt body = this->VisitStmt(op->body);
      for (int i = 0; i < func->num_outputs(); ++i) {
        TensorKey key{func, i};
        auto it = buf_map_.find(key);
        CHECK(it != buf_map_.end())
            << "Cannot find allocated buffer for " << key.f;
        body = AttrStmtNode::make(
            it->second.buffer->data, op->attr_key, op->value, body);
      }
      return body;
    } else if (op->attr_key == attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ThreadScope ts = ThreadScope::make(iv->thread_tag);
      curr_thread_scope_.push_back(ts);
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      curr_thread_scope_.pop_back();
      return stmt;
    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    } else if (op->attr_key == attr::buffer_dim_align) {
      auto tensor = Downcast<te::Tensor>(op->node);
      const CallNode* tuple = op->value.as<CallNode>();
      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
      TensorKey key{tensor->op, tensor->value_index};
      auto& vinfo = dim_align_[key];
      int dim = tuple->args[0].as<IntImmNode>()->value;
      if (static_cast<size_t>(dim) >= vinfo.size()) {
        vinfo.resize(dim + 1);
      }
      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::opengl_stage_scope) {
      is_opengl_ = true;
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const ProvideNode* op) final {
    if (create_bound_attributes_)
      shape_collector_.clear();
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<ProvideNode>();
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;
    CHECK(!e.released)
        << "Read a buffer that is already out of scope";
    if (is_opengl_) {
      return EvaluateNode::make(CallNode::make(
          DataType(),
          CallNode::glsl_texture_store,
          {e.buffer->data, op->value},
          CallNode::Intrinsic));
    } else {
      Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        shape_collector_.push_back(
            std::make_pair(e.buffer->data, e.buffer->shape));
      }
      // To create bound attributes, the collector should have at least one item.
      if (create_bound_attributes_ && shape_collector_.size()) {
        for (size_t i = 0; i < shape_collector_.size(); ++i) {
          body = AttrStmtNode::make(
              shape_collector_[i].first, tir::attr::buffer_bound,
              MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
        }
      }
      return body;
    }
  }

  Stmt VisitStmt_(const RealizeNode* op) final {
    TensorKey key{op->func, op->value_index};
    if (buf_map_.count(key)) {
      CHECK(buf_map_.at(key).external);
      return this->VisitStmt(op->body);
    } else {
      // create a buffer entry
      BufferEntry e;
      e.bounds = op->bounds;
      Array<PrimExpr> shape;
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      // deduce current storage scope.
      auto it = storage_scope_.find(op->func.get());
      CHECK(it != storage_scope_.end())
          << "Cannot find storage scope of " << op->func
          << " value_index=" << op->value_index;
      StorageScope skey;
      const std::string& strkey = it->second;
      if (strkey.length() == 0) {
        if (curr_thread_scope_.size() != 0) {
          skey.rank = runtime::DefaultStorageRank(
              curr_thread_scope_.back().rank);
        }
      } else {
        skey = StorageScope::make(strkey);
      }

      // use small alignment for small arrays
      int32_t const_size = AllocateNode::constant_allocation_size(shape);
      int align = GetTempAllocaAlignment(op->dtype, const_size);
      if (skey.tag.length() != 0) {
        MemoryInfo info = GetMemoryInfo(skey.to_string());
        if (info.defined()) {
          align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits();
          CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits)
              << "Allocation exceeds bound of memory tag " << skey.to_string();
        }
      }
      Array<PrimExpr> strides;
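      // When buffer_dim_align annotations exist for this tensor, build
      // explicit strides from the innermost dimension outwards, padding each
      // annotated dimension so that its stride is congruent to align_offset
      // modulo align_factor.  Illustrative example: for shape (n, 6) with
      // align_factor = 8 and align_offset = 0 on dimension 0, the natural
      // stride 6 is padded to 8, so consecutive rows start 8 elements apart.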
      if (dim_align_.count(key) != 0 && shape.size() != 0) {
        std::vector<PrimExpr> rstrides;
        const std::vector<DimAlignInfo>& avec = dim_align_[key];
        int first_dim = 0;
        PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
        for (size_t i = shape.size(); i != 0; --i) {
          size_t dim = i - 1;
          if (dim < avec.size() && avec[dim].align_factor != 0) {
            PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
            PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
            stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
            stride = tir::Simplify(stride);
          }
          rstrides.push_back(stride);
          stride = stride * shape[dim];
        }
        strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
      }

      e.buffer = BufferNode::make(
          Var(key.GetName(), DataType::Handle()),
          op->dtype, shape, strides, PrimExpr(),
          key.GetName(), skey.to_string(),
          align, 0, kDefault);

      buf_map_[key] = e;
      Stmt body = this->VisitStmt(op->body);
      buf_map_[key].released = true;
      Stmt ret;

      DataType storage_type = e.buffer->dtype;
      // specially handle bool, lower its storage
      // type to DataType::Int(8) (byte)
      if (storage_type == DataType::Bool()) {
        storage_type = DataType::Int(8);
      }
      if (strides.size() != 0) {
        int first_dim = 0;
        ret = AllocateNode::make(
            e.buffer->data, storage_type,
            {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
            make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
      } else {
        shape = e.buffer->shape;
        if (shape.size() == 0) {
          shape.push_back(make_const(DataType::Int(32), 1));
        }
        ret = AllocateNode::make(
            e.buffer->data, storage_type, shape,
            make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
      }
      ret = AttrStmtNode::make(
          e.buffer->data, attr::storage_scope,
          StringImmNode::make(e.buffer->scope), ret);

      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound,
                             MakeBound(e.buffer->dtype, e.buffer->shape), ret);
      }
      return ret;
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<LoadNode>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<VarNode>());
      Var buf_var = Downcast<Var>(it->second);
      return LoadNode::make(op->dtype, buf_var, op->index, op->predicate);
    } else {
      return expr;
    }
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  PrimExpr VisitExpr_(const CallNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    if (op != nullptr && op->call_type == CallNode::Halide) {
      TensorKey key{op->func, op->value_index};
      auto it = buf_map_.find(key);
      CHECK(it != buf_map_.end())
          << "Cannot find allocated buffer for " << key.f;
      const BufferEntry& e = it->second;
      CHECK(!e.released)
          << "Read a buffer that is already out of scope";

      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        shape_collector_.push_back(
            std::make_pair(e.buffer->data, e.buffer->shape));
      }
      return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
    } else {
      return expr;
    }
  }

  Stmt VisitStmt_(const PrefetchNode *op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<PrefetchNode>();
    CHECK(op != nullptr);
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;

    CHECK(!e.released)
        << "Read a buffer that is already out of scope";
    CHECK_EQ(e.buffer->shape.size(), op->bounds.size())
      << "Prefetch dim should be the same as buffer dim";

    int block_size = 1,
        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
        shape = 0;

    int starts = op->bounds.size() - 1;
    while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
        && elem_cnt >= block_size * shape) {
      block_size *= shape;
      starts--;
    }
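    // Here `starts` is the innermost dimension that is not folded into a
    // single cache-line block and `block_size` is the element count covered
    // by the dimensions inner to it; the stride below is how many iterations
    // of dimension `starts` span one cache line.  Illustrative numbers only:
    // a 64-byte cache line and float32 give elem_cnt = 16; with an innermost
    // extent of 4, block_size = 4 and stride = 4, so one prefetch is emitted
    // per 4 iterations of dimension `starts`.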
    PrimExpr stride(elem_cnt / block_size);

    Array<PrimExpr> args;
    std::vector<Var> vars;

    for (int i = op->bounds.size() - 1; i > starts; --i) {
      args.push_back(op->bounds[i]->min);
    }
    auto &func_name = op->func->func_name();
    vars.push_back(Var(
        "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
    args.push_back(op->bounds[starts]->min + stride * vars.back());
    for (int i = starts - 1; i >= 0; --i) {
      vars.push_back(Var(
          "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
      args.push_back(vars.back() + op->bounds[i]->min);
    }
    for (int i = starts; i >= 0; --i) {
      if (i < starts) {
        stmt = ForNode::make(
            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
      } else {
        PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
        PrimExpr address = CallNode::make(
            DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
        PrimExpr prefetch = CallNode::make(
            op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
        stmt = EvaluateNode::make(prefetch);
        PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
        stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
      }
    }
    return stmt;
  }

 private:
  // The specific tensor data layout is not determined before
  // the StorageFlatten pass. We use buffer_bind_scope
  // to specify beforehand that we want to bind a subregion
  // of a tensor to a symbolic buffer, which gets used in extern.
  //
  // Example:
  //
  // realize A in range [i*4, extent=10) {
  //   bind Ab to A in [i*4+1, extent=4) {
  //     call_func(Ab.ptr, Ab.shape[0])
  //   }
  // }
  //
  // After StorageFlatten
  //
  // alloc A[10]
  //   call(A + 1,  4)
  //
  // Buffer is a protocol to declare the specific
  // data layout and shape we expect.
  // So this function needs to check:
  // - If the bind range is within the realize range
  // - If we can match the requirement of buffer
  // - Remap variables such as Ab.ptr to the actual value.
  //
  // Here are a few possible failure cases:
  // - Buffer is declared to have constant shape,
  //   but we try to bind it to a different one.
  // - Buffer is declared to be compact (no strides)
  //   but the bound region is a subregion of
  //   a matrix (tensor), which means it requires strides.
  //
  // We do support a few relaxed cases, such as binding a
  // region with shape [1, 1, n, m] to a buffer with shape [n, m].
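  //
  // In the attribute handled below, op->node holds [buffer, tensor] and
  // op->value is a tvm_tuple(min_0, extent_0, min_1, extent_1, ...) call
  // describing the bound region, one (min, extent) pair per dimension; the
  // loops over tuple->args rely on exactly this layout.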
  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
    Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
    CHECK_EQ(arr.size(), 2U);
    const BufferNode* buffer = arr[0].as<BufferNode>();
    const te::TensorNode* tensor = arr[1].as<te::TensorNode>();
    const CallNode* tuple = op->value.as<CallNode>();
    CHECK(buffer && tensor);
    CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
    TensorKey key{tensor->op, tensor->value_index};
    CHECK(buf_map_.count(key))
        << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
    const BufferEntry& be = buf_map_.at(key);
    CHECK(!be.released);
    CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
    Array<PrimExpr> begins, extents;
    if (be.bounds.size() != 0) {
      CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
      for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
        begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
        extents.push_back(tuple->args[2 * i + 1]);
      }
    } else {
      for (size_t i = 0; i < tuple->args.size(); i += 2) {
        begins.push_back(tuple->args[i]);
        auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]);
        extents.push_back(new_extent);
      }
    }
    Buffer slice = be.buffer.MakeSlice(begins, extents);
    if (buffer->strides.size() == 0) {
      CHECK_EQ(slice->strides.size(), 0U)
          << "Trying to bind a compact buffer to a strided one, strides="
          << slice->strides;
    } else {
      slice = slice.MakeStrideView();
    }
    // start binding
    ArgBinder binder(&var_remap_);
    binder.BindBuffer(Downcast<Buffer>(arr[0]), slice, buffer->name, true);
    // Apply the remaps
    Stmt body = MergeNest(binder.asserts(), op->body);
    body = MergeNest(binder.init_nest(), body);
    body = this->VisitStmt(body);
    // remove the bindings
    for (const Var& v : binder.defs()) {
      var_remap_.erase(v.get());
    }
    return body;
  }
  // Dimension alignment information for a tensor dimension.
  struct DimAlignInfo {
    int align_factor{0};
    int align_offset{0};
  };
  // The buffer entry in the flatten map
  struct BufferEntry {
    // the buffer of storage
    Buffer buffer;
    // the bounds of realization; can be null, which means everything
    Region bounds;
    // Whether the buffer is external
    bool external{false};
    // Whether we are out of allocation bounds and the buffer gets released.
    bool released{false};
    // relative index
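    // (subtract the realize bounds' minimum from each coordinate; e.g. a
    // realization over [2, 10) accessed at coordinate 5 maps to buffer
    // index 3)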
    inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
      if (bounds.size() != 0) {
        Array<PrimExpr> index;
        CHECK_EQ(bounds.size(), args.size());
        for (size_t i = 0; i < bounds.size(); ++i) {
          index.push_back(args[i] - bounds[i]->min);
        }
        return index;
      } else {
        return args;
      }
    }
  };

  bool ShapeIsValid(const Array<PrimExpr> &shape) {
    // Zero-dimensional tensor does not need boundary check.
    if (!shape.size())
      return false;

    for (size_t i = 0; i < shape.size(); ++i) {
      if (!shape[i].defined() || !shape[i].dtype().is_scalar() ||
          is_negative_const(shape[i])) {
        return false;
      }
    }
    return true;
  }

  PrimExpr MakeBound(const DataType &type, const Array<PrimExpr> &shape) {
    // We have already checked the shape size to be greater than 0.
    PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
    for (size_t i = 1; i < shape.size(); ++i) {
      bound = MulNode::make(
          bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
    }
    return bound;
  }

  // Variable remap
  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
  // Buffer map
  std::unordered_map<TensorKey, BufferEntry> buf_map_;
  // Dimension alignment
  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
  // Storage scope
  std::unordered_map<const Object*, std::string> storage_scope_;
  // The current thread scope.
  std::vector<ThreadScope> curr_thread_scope_;
  // Collects shapes.
  std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
  // Bounds populator. We really need the analyzer from it.
  IRVisitorWithAnalyzer* bounded_analyzer_;
  // The size of a cache line.
  int cache_line_size_;
  // Whether the current stage is an OpenGL shader.
  bool is_opengl_{false};
  // Whether to mark loads/stores with their bounds.
  bool create_bound_attributes_{false};
};

Stmt StorageFlatten(Stmt stmt, Map<te::Tensor, Buffer> extern_buffer,
                    int cache_line_size, bool create_bound_attributes) {
  IRVisitorWithAnalyzer bounded_analyzer;
  bounded_analyzer(stmt);
  stmt =
      StorageFlattener(extern_buffer, cache_line_size,
                       create_bound_attributes, &bounded_analyzer)(std::move(stmt));
  return stmt;
}

}  // namespace tir
}  // namespace tvm