/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_rewrite.cc
 * \brief Memory access pattern analysis and optimization.
 *  Re-write data access to enable memory sharing when possible.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/target/target_info.h>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include "ir_util.h"
#include "../arith/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using runtime::StorageRank;
using runtime::StorageScope;

// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// The accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which we need to keep memory
// alive under the same scope as the allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
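// Illustrative example (a sketch of the linearization, not produced verbatim):
//
//   allocate A
//   for (i, 0, n)      -> before_scope entry
//     A[i] = ...       -> leaf entry touching A
//                      -> after_scope entry (carries the touches of the loop body)
//   ... = A[0]         -> leaf entry touching A
//
// The liveness analysis below uses this sequence to place the gen point of A at its
// first touch and the kill point at its last touch, within the scope of the allocate.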
class LinearAccessPatternFinder final : public StmtExprVisitor {
 public:
  /*! \brief Record the touch history of a statement. */
  struct StmtEntry {
    // The statement
    const Object* stmt;
    // The index in linear_seq_ that points to the end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // if offset > 0, this is the begin, the end entry is at current_index + offset
    // if offset < 0, this is the end, the begin entry is at current_index + offset
    int64_t scope_pair_offset{0};
    // The buffer variables this statement touched.
    std::vector<const VarNode*> touched;
  };
  // The scope of each allocation
  struct AllocEntry {
    // Scope used for allocation.
    StorageScope storage_scope;
    // scope level
    size_t level{0};
    // allocation stmt
    const AllocateNode* alloc{nullptr};
  };

  void VisitStmt_(const AllocateNode* op) final {
    size_t level = scope_.size();
    const VarNode* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    CHECK(it != alloc_info_.end());
    CHECK(it->second.alloc == nullptr);
    it->second.alloc = op;
    it->second.level = level;
    StmtExprVisitor::VisitStmt_(op);
  }
  void VisitStmt_(const StoreNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    // Add write access.
    const VarNode* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size());
      scope_[it->second.level].touched.push_back(buf);
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void VisitStmt_(const EvaluateNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void VisitExpr_(const LoadNode* op) final {
    // Add read access.
    StmtExprVisitor::VisitExpr_(op);
    const VarNode* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << "Load memory in places other than store.";
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  void VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const LoadNode* l = op->args[0].as<LoadNode>();
      this->VisitExpr(l->index);
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
  }
  void VisitExpr_(const VarNode* buf) final {
    // A direct reference to the variable counts as a read.
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << " buf=" << buf->name_hint;
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  template<typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
    // before scope.
    linear_seq_.push_back(e);
    StmtExprVisitor::VisitStmt_(op);
    // after scope.
    e.touched = std::move(scope_.back().touched);
    scope_.pop_back();
    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
    CHECK_GT(end_index, begin_index);
    e.scope_pair_offset = begin_index - end_index;
    linear_seq_.push_back(e);
    // record the pointer to the end index.
    CHECK_NE(end_index, 0U);
    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
  }
  void VisitStmt_(const AttrStmtNode* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::extern_scope) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::virtual_thread) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::storage_scope) {
      const VarNode* buf = op->node.as<VarNode>();
      alloc_info_[buf].storage_scope =
          StorageScope::make(op->value.as<StringImmNode>()->value);
      StmtExprVisitor::VisitStmt_(op);
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
  }
  void VisitStmt_(const IfThenElseNode* op) final {
    VisitNewScope(op);
  }

  void VisitStmt_(const ForNode* op) final {
    VisitNewScope(op);
  }

  void VisitStmt_(const AssertStmtNode* op) final {
    VisitNewScope(op);
  }

  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The storage scope of each buffer.
  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

 private:
  // Whether already in thread env.
  bool in_thread_env_{false};
  // The scope stack.
  std::vector<StmtEntry> scope_;
};

// Verify if the statement can be run safely in an inplace fashion.
//
// Detect the pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// when a location in an array is written multiple times.
//
// For example, the following program will pass the check,
// but we cannot make A and B the same array.
//
//  A[0] = B[0] + 1
//  A[0] = B[0] + 1
//
// The high level code generator needs to ensure that the generated
// code only writes each location of the target array once.
//
// This is the case with IR generated by the current compute schedule.
// We explicitly return false if we find there is an extern block,
// which can contain arbitrary IR.
//
// Nevertheless, the inplace detector should be used with care.
// We may also consider introducing a condition checker that checks
// whether every index is only visited once, as an absolutely sufficient condition.
//
// The code after the inplace transformation is no longer idempotent.
//
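// Illustrative examples of what the verifier accepts and rejects (hypothetical IR,
// a sketch rather than an exhaustive specification):
//
//   B[i] = A[i] + 1      // accepted: src is read at exactly the store index of dst
//   B[i] = B[i] + A[i]   // rejected: reads from dst (looks like a reduction)
//   B[i] = A[C[i]]       // rejected: the load of src does not match the store index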
class InplaceOpVerifier : public StmtExprVisitor {
 public:
  bool Check(const Object* stmt,
             const VarNode* dst,
             const VarNode* src) {
    dst_ = dst;
    src_ = src;
    result_ = true;
    if (stmt->IsInstance<AttrStmtNode>()) {
      VisitStmt_(static_cast<const AttrStmtNode*>(stmt));
    } else if (stmt->IsInstance<ForNode>()) {
      VisitStmt_(static_cast<const ForNode*>(stmt));
    } else if (stmt->IsInstance<IfThenElseNode>()) {
      VisitStmt_(static_cast<const IfThenElseNode*>(stmt));
    } else if (stmt->IsInstance<StoreNode>()) {
      VisitStmt_(static_cast<const StoreNode*>(stmt));
    } else {
      return false;
    }
    return result_;
  }

  using StmtExprVisitor::VisitStmt_;

  void VisitStmt(const Stmt& n) final {
    if (!result_) return;
    StmtExprVisitor::VisitStmt(n);
  }
  void VisitExpr(const PrimExpr& n) final {
    if (!result_) return;
    StmtExprVisitor::VisitExpr(n);
  }

  void VisitExpr_(const VarNode* op) final {
    // assume all opaque accesses are unsafe
    if (op == dst_ || op == src_) {
      result_ = false; return;
    }
  }

  void VisitStmt_(const StoreNode* op) final {
    ++mem_nest_;
    this->VisitExpr(op->index);
    --mem_nest_;
    if (op->buffer_var.get() == dst_) {
      store_ = op;
      this->VisitExpr(op->value);
      this->VisitExpr(op->predicate);
      store_ = nullptr;
    } else {
      this->VisitExpr(op->value);
      this->VisitExpr(op->predicate);
    }
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    // always reject extern code
    if (op->attr_key == attr::extern_scope ||
        op->attr_key == attr::volatile_scope) {
      result_ = false; return;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const LoadNode* op) final {
    const VarNode* buf = op->buffer_var.get();
    // cannot read from dst_ (no reduction)
    if (buf == dst_) {
      result_ = false; return;
    }
    // do not allow indirect memory load
    if (mem_nest_ != 0) {
      result_ = false; return;
    }
    if (src_ == buf) {
      if (store_ == nullptr ||
          store_->value.dtype() != op->dtype ||
          !ir::Equal(store_->index, op->index)) {
        result_ = false; return;
      }
    }
    ++mem_nest_;
    StmtExprVisitor::VisitExpr_(op);
    --mem_nest_;
  }


 private:
  // result of the check
  bool result_{true};
  // destination memory
  const VarNode* dst_;
  // source variable
  const VarNode* src_;
  // counter of nested loads;
  // it is not safe to inplace when there is a nested load like A[B[i]]
  int mem_nest_{0};
  // The current store to be inspected
  const StoreNode* store_{nullptr};
};

// Planner to plan and rewrite memory allocation.
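// The rewrite proceeds in three phases (see Rewrite below):
//  1. LinearAccessPatternFinder linearizes the IR and records which buffers each
//     statement touches.
//  2. LivenessAnalysis and PlanMemory turn the touches into gen/kill events and
//     assign every allocation to a (possibly shared) StorageEntry.
//  3. The mutator methods then replace buffer variables and indices with the
//     merged allocations built by PrepareNewAlloc.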
class StoragePlanRewriter : public StmtExprMutator {
 public:
  using StmtEntry = LinearAccessPatternFinder::StmtEntry;
  using AllocEntry = LinearAccessPatternFinder::AllocEntry;

  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
    detect_inplace_ = detect_inplace;
    // plan the rewrite
    LinearAccessPatternFinder finder;
    finder(stmt);
    this->LivenessAnalysis(finder.linear_seq_);
    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
    this->PrepareNewAlloc();
    // start rewrite
    stmt = operator()(std::move(stmt));
    if (attach_map_.count(nullptr)) {
      std::vector<Stmt> nest;
      for (StorageEntry* e : attach_map_.at(nullptr)) {
        // CHECK_EQ(e->scope.rank, 0);
        if (e->new_alloc.defined()) {
          nest.emplace_back(AttrStmtNode::make(
              e->alloc_var, attr::storage_scope,
              StringImmNode::make(e->scope.to_string()),
              EvaluateNode::make(0)));
          nest.push_back(e->new_alloc);
        }
      }
      stmt = MergeNest(nest, stmt);
    }
    return stmt;
  }
  Stmt VisitStmt_(const StoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<StoreNode>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return stmt;
    return StoreNode::make(it->second->alloc_var,
                           op->value,
                           RemapIndex(op->value.dtype(), op->index, it->second),
                           op->predicate);
  }
  PrimExpr VisitExpr_(const LoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<LoadNode>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return expr;
    return LoadNode::make(op->dtype,
                          it->second->alloc_var,
                          RemapIndex(op->dtype, op->index, it->second),
                          op->predicate);
  }
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = alloc_map_.find(op);
    if (it != alloc_map_.end()) {
      if (it->second->bits_offset != 0) {
        LOG(WARNING) << "Use of a merged buffer variable address could cause an error";
      }
      return it->second->alloc_var;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      auto it = alloc_map_.find(buffer);
      if (it == alloc_map_.end()) {
        return StmtExprMutator::VisitExpr_(op);
      }
      const StorageEntry* se = it->second;
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      uint64_t elem_bits = dtype.bits() * dtype.lanes();
      CHECK_EQ(se->bits_offset % elem_bits, 0U);
      if (se->bits_offset != 0) {
        offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
      }
      return CallNode::make(
          op->dtype, op->name,
          {op->args[0], se->alloc_var, offset, extent, op->args[4]},
          op->call_type);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::storage_scope) {
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::thread_extent ||
               op->attr_key == attr::virtual_thread ||
               attr::IsPragmaKey(op->attr_key)) {
      // remake all the allocations at the attach scope.
      if (attach_map_.count(op)) {
        auto& svec = attach_map_[op];
        Stmt stmt = StmtExprMutator::VisitStmt_(op);
        op = stmt.as<AttrStmtNode>();
        return AttrStmtNode::make(
            op->node, op->attr_key, op->value,
            MakeAttach(svec, op->body));
      } else {
        return StmtExprMutator::VisitStmt_(op);
      }
    } else if (op->attr_key == attr::volatile_scope) {
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AttrStmtNode>();
      auto it = alloc_map_.find(op->node.as<VarNode>());
      if (it == alloc_map_.end()) return stmt;
      return AttrStmtNode::make(
          it->second->alloc_var, op->attr_key, op->value, op->body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const ForNode* op) final {
    CHECK(op->for_type != ForType::Vectorized)
        << "VectorizeLoop before LiftStorageAlloc";
    // remake all the allocations at the attach scope.
    if (attach_map_.count(op)) {
      auto& svec = attach_map_[op];
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<ForNode>();
      return ForNode::make(
          op->loop_var, op->min, op->extent, op->for_type, op->device_api,
          MakeAttach(svec, op->body));
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    return this->VisitStmt(op->body);
  }

 private:
  struct StorageEntry {
    // The scope that this alloc attaches to.
    // For shared/local memory it is the beginning of the thread extent.
    // For global memory it is nullptr, meaning the beginning of everything.
    const Object* attach_scope_{nullptr};
    // The constant size of the buffer in bits, only used if it is constant
    uint64_t const_nbits{0};
    // The storage scope.
    StorageScope scope;
    // Allocs that share this entry.
    std::vector<const AllocateNode*> allocs;
    // The children of this entry, not including itself.
    std::vector<StorageEntry*> merged_children;
    // The replacement allocation, if any.
    Stmt new_alloc;
    // The var expr of the new allocation.
    Var alloc_var;
    // The allocation element type.
    DataType elem_type;
    // This is non-zero if this allocate is folded into another one;
    // the address (in bits) becomes alloc_var + bits_offset,
    // which can be effectively converted to the element type.
    // We need to convert bits_offset to an offset of the specific element type later.
    //
    // We use bits (instead of bytes) to support non-conventional indexing in hardware.
    // When we are merging buffers together, the bits_offset is set to be aligned
    // to a certain value given by the max_simd_bits property of the special memory.
    //
    // This allows effective sharing among different types as long as their alignment
    // requirement fits into max_simd_bits.
    uint64_t bits_offset{0};
  };
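  // Illustrative example of the bits_offset scheme (hypothetical sizes): if a tagged
  // scope with max_simd_bits = 128 merges a 1024-bit buffer A and a 512-bit buffer B,
  // B gets bits_offset = 1024, and a load B[i] of a 32-bit element is later remapped
  // by RemapIndex to alloc_var[1024 / 32 + i] = alloc_var[32 + i].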

  // Allocate entry of node.
  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
    std::vector<const VarNode*> gen;
    // variables we kill
    std::vector<const VarNode*> kill;
  };

  Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
                  Stmt body) {
    std::vector<Stmt> nest;
    for (StorageEntry* e : svec) {
      if (e->new_alloc.defined()) {
        nest.emplace_back(AttrStmtNode::make(
            e->alloc_var, attr::storage_scope,
            StringImmNode::make(e->scope.to_string()),
            EvaluateNode::make(0)));
        nest.push_back(e->new_alloc);
      }
    }
    return MergeNest(nest, body);
  }
  // Remap the index
  PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
    if (e->bits_offset == 0) return index;
    uint64_t elem_bits = dtype.bits() * dtype.lanes();
    CHECK_EQ(e->bits_offset % elem_bits, 0U);
    return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
  }
  // Prepare the new allocations
  void PrepareNewAlloc() {
    for (size_t i = 0; i < alloc_vec_.size(); ++i) {
      StorageEntry* e = alloc_vec_[i].get();
      attach_map_[e->attach_scope_].push_back(e);
    }
    // find allocation via attach map.
    for (auto& kv : attach_map_) {
      // find the element with the largest number of bytes.
      std::vector<StorageEntry*>& vec = kv.second;
      // try to find merge, for tagged memory
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        if (e->scope.tag.length() != 0) {
          CHECK_NE(e->const_nbits, 0U)
              << "Special tagged memory must be const size";
          for (size_t j = 0; j < i; ++j) {
            if (e->scope == vec[j]->scope) {
              vec[j]->merged_children.push_back(e);
              break;
            }
          }
        }
      }
      // Start allocation
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        // already merged
        if (e->bits_offset != 0) continue;
        if (e->merged_children.size() != 0) {
          NewAllocTagMerged(e); continue;
        }
        // Get the allocation size.
        e->alloc_var = e->allocs[0]->buffer_var;
        DataType alloc_type = e->allocs[0]->dtype;
        for (const AllocateNode* op : e->allocs) {
          if (op->dtype.lanes() > alloc_type.lanes()) {
            alloc_type = op->dtype;
          }
        }
        if (e->allocs.size() == 1) {
          // simply use the original allocation.
          PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
                                                      make_const(DataType::Int(32), 1));
          e->new_alloc = AllocateNode::make(
              e->alloc_var, alloc_type, {sz},
              e->allocs[0]->condition, EvaluateNode::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds the bound of memory tag " << e->scope.to_string();
          }
        } else {
          // Build a merged allocation
          PrimExpr combo_size;
          for (const AllocateNode* op : e->allocs) {
            PrimExpr sz = arith::ComputeReduce<MulNode>(
                op->extents, make_const(DataType::Int(32), 1));
            auto nbits = op->dtype.bits() * op->dtype.lanes();
            if (const auto* imm = sz.as<IntImmNode>()) {
              if (imm->value > std::numeric_limits<int>::max() / nbits) {
                LOG(WARNING) << "The allocation requires : " << imm->value
                             << " * " << nbits
                             << " bits, which is greater than the maximum of"
                                " int32. The size is cast to int64."
                             << "\n";
                sz = make_const(DataType::Int(64), imm->value);
              }
            }
            // transform to bits
            auto sz_nbits = sz * nbits;
            if (combo_size.defined()) {
              combo_size = max(combo_size, sz_nbits);
            } else {
              combo_size = sz_nbits;
            }
          }
          // transform to the number of elements of the allocation type
          auto type_bits = alloc_type.bits() * alloc_type.lanes();
          bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
          combo_size = indexdiv(combo_size, type_bits);
          // round up when the size cannot be divided evenly
          if (!divided) {
            combo_size = combo_size + make_const(DataType::Int(32), 1);
          }
          combo_size = ir::Simplify(combo_size);
          e->new_alloc = AllocateNode::make(
              e->alloc_var, alloc_type, {combo_size}, const_true(),
              EvaluateNode::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds the bound of memory tag " << e->scope.to_string();
          }
        }
      }
    }
  }
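  // Worked example for the merged-allocation branch above (hypothetical numbers):
  // two float32 allocations of extent 100 (3200 bits) and extent 80 (2560 bits) that
  // share one StorageEntry give combo_size = max(3200, 2560) = 3200 bits, which is
  // then divided by type_bits = 32 to yield a merged allocation of 100 elements.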
  // New allocation for merged data
  void NewAllocTagMerged(StorageEntry* e) {
    CHECK_NE(e->scope.tag.length(), 0U);
    // allocate with element type.
    CHECK_NE(e->const_nbits, 0U);
    MemoryInfo info = GetMemoryInfo(e->scope.to_string());
    uint64_t total_bits = e->const_nbits;
    // By default, align to 32 bits.
    size_t align = 32;
    if (info.defined()) {
      align = info->max_simd_bits;
    }
    // Always align to max_simd_bits
    // so we can remap types by keeping this property
    if (total_bits % align != 0) {
      total_bits += align - (total_bits % align);
    }
    e->alloc_var = e->allocs[0]->buffer_var;
    for (StorageEntry* child : e->merged_children) {
      CHECK_NE(child->const_nbits, 0U);
      CHECK_NE(total_bits, 0U);
      child->bits_offset = total_bits;
      child->alloc_var = e->alloc_var;
      total_bits += child->const_nbits;
      if (total_bits % align != 0) {
        total_bits += align - (total_bits % align);
      }
    }
    uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
    PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
                                     (total_bits + type_bits - 1) / type_bits);
    e->new_alloc = AllocateNode::make(
        e->alloc_var, e->elem_type, {alloc_size}, const_true(),
        EvaluateNode::make(0));
    if (info.defined()) {
      CHECK_LE(total_bits, info->max_num_bits)
          << "Allocation exceeds the bound of memory tag " << e->scope.to_string();
    }
  }
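  // Worked example for the alignment above (hypothetical sizes): with
  // max_simd_bits = 128, a parent entry of 1000 bits is padded to 1024, a merged
  // child of 200 bits then gets bits_offset = 1024, and total_bits grows to 1224,
  // which is padded to 1280 before the next child (if any) is placed.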
  // Liveness analysis to find gen and kill point of each variable.
  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
    // find kill point, do a reverse linear scan.
    std::unordered_set<const VarNode*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].kill.push_back(buffer);
        }
      }
    }
    // find gen point, do forward scan
    touched.clear();
    for (size_t i = 0; i < seq.size(); ++i) {
      int64_t offset = seq[i].scope_pair_offset;
      if (offset < 0) continue;
      const StmtEntry& s = seq[i + offset];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].gen.push_back(buffer);
        }
      }
    }
  }
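  // Illustrative sketch: if a buffer B is last touched inside a loop, the reverse
  // scan records B in the kill set of the loop statement (via its after_scope entry),
  // and the forward scan records B in the gen set of the statement that first touches
  // it; PlanMemory below allocates at gen events and calls Free at kill events.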
  void PlanNewScope(const Object* op) {
    if (thread_scope_ != nullptr) {
      CHECK(thread_scope_ == op);
      // erase all memory attached to this scope.
      for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
        if (it->second->attach_scope_ == op) {
          it = const_free_map_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
        if ((*it)->attach_scope_ == op) {
          it = sym_free_list_.erase(it);
        } else {
          ++it;
        }
      }
      thread_scope_ = nullptr;
    } else {
      thread_scope_ = op;
    }
  }

  // Memory plan algorithm
  void PlanMemory(const std::vector<StmtEntry>& seq,
                  const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
    std::unordered_set<const VarNode*> inplace_flag;

    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      auto it = event_map_.find(seq[i].stmt);

      // scope_pair_offset >= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the beginning of a scope (offset > 0)
      // In both cases, we need to handle the gen event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        // Inplace operation detection
        // specially handle this
        bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);

        for (const VarNode* var : it->second.gen) {
          CHECK(alloc_info.count(var));
          const AllocEntry& ae = alloc_info.at(var);
          StorageEntry* dst_entry = nullptr;
          // inplace detection
          if (detect_inplace) {
            // only one inplace var for s.stmt
            bool inplace_found = false;
            for (const VarNode* src : it->second.kill) {
              if (!inplace_flag.count(src) && alloc_map_.count(src)) {
                InplaceOpVerifier visitor;
                StorageEntry* src_entry = alloc_map_.at(src);
                if (src_entry->scope == ae.storage_scope &&
                    src_entry->attach_scope_ == thread_scope_ &&
                    src_entry->elem_type == ae.alloc->dtype.element_of() &&
                    visitor.Check(s.stmt, var, src)) {
                  uint64_t const_nbits =
                      static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
                      ae.alloc->dtype.bits() *
                      ae.alloc->dtype.lanes();
                  if (src_entry->const_nbits == const_nbits && !inplace_found) {
                    // successfully inplace
                    dst_entry = src_entry;
                    inplace_flag.insert(src);
                    inplace_found = true;
                  }
                }
              }
            }
          }
          if (dst_entry == nullptr) {
            dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
          }
          dst_entry->allocs.emplace_back(ae.alloc);
          alloc_map_[var] = dst_entry;
        }
      }
      // enter/exit new scope
      if (s.stmt->IsInstance<AttrStmtNode>()) {
        const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
        if (op->attr_key == attr::thread_extent ||
            op->attr_key == attr::virtual_thread ||
            attr::IsPragmaKey(op->attr_key)) {
          PlanNewScope(op);
        } else {
          CHECK(op->attr_key == attr::extern_scope);
        }
      } else if (s.stmt->IsInstance<ForNode>()) {
        const auto* op = static_cast<const ForNode*>(s.stmt);
        if (op->for_type == ForType::Parallel) {
          if (thread_scope_ == nullptr || thread_scope_ == op) {
            PlanNewScope(op);
          }
        }
      }
      // scope_pair_offset <= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the end of a scope (offset < 0)
      // In both cases, we need to handle the kill event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const VarNode* var : it->second.kill) {
          // skip spaces that have already been replaced by inplace
          if (!inplace_flag.count(var)) {
            this->Free(var);
          }
        }
      }
    }
  }
  // Allocate new storage entry.
  StorageEntry* NewAlloc(const AllocateNode* op,
                         const Object* attach_scope,
                         const StorageScope& scope,
                         size_t const_nbits) {
    CHECK(op != nullptr);
    // Re-use not successful, allocate a new buffer.
    std::unique_ptr<StorageEntry> entry(new StorageEntry());
    entry->attach_scope_ = attach_scope;
    entry->scope = scope;
    entry->elem_type = op->dtype.element_of();
    entry->const_nbits = const_nbits;
    StorageEntry* e = entry.get();
    alloc_vec_.emplace_back(std::move(entry));
    return e;
  }

  StorageEntry* FindAlloc(const AllocateNode* op,
                          const Object* attach_scope,
                          const StorageScope& scope) {
    CHECK(op != nullptr);
    // skip planning for local variables,
    // the compiler can do a better job with register allocation.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(
        op->constant_allocation_size() * op_elem_bits);
    // disable reuse of small arrays, they will be lowered to registers in LLVM
    // This rule only applies if we are using non-special memory
    if (scope.tag.length() == 0) {
      if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
      if (const_nbits > 0 && const_nbits <= 32) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
    }
    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(const_nbits / match_range);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // start looking at buffers that are bigger than the required size first
      for (auto it = mid; it != end; ++it) {
        StorageEntry* e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        // when not divisible, no reuse, e.g. float4 vs float3
        if (e->bits_offset % op_elem_bits != 0) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // then start looking at smaller buffers.
      for (auto it = mid; it != begin;) {
        --it;
        StorageEntry* e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
    } else {
      // Simple strategy: round robin.
      for (auto it = sym_free_list_.begin();
           it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
        sym_free_list_.erase(it);
        return e;
      }
    }
    return NewAlloc(op, attach_scope, scope, const_nbits);
  }
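  // Illustrative sketch of the match_range window above (hypothetical sizes): for a
  // request of const_nbits = 1024 with match_range = 16, reuse candidates are freed
  // constant-size entries of roughly 64 to 16384 bits, preferring entries of at least
  // 1024 bits before falling back to the smaller ones.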
  // simulated free.
  void Free(const VarNode* var) {
    auto it = alloc_map_.find(var);
    CHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    CHECK_NE(e->allocs.size(), 0U);

    // disable reuse of small arrays, they will be lowered to registers in LLVM
    // This rule only applies if we are using non-special memory
    if (e->scope.tag.length() == 0) {
      // Disable sharing of local memory.
      if (e->scope.rank >= StorageRank::kWarp ||
          e->allocs[0]->dtype.is_handle()) return;
      // disable reuse of small arrays
      if (e->const_nbits > 0 && e->const_nbits <= 32) return;
    }
    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // thread scope.
  const Object* thread_scope_{nullptr};
  // whether to enable inplace detection.
  bool detect_inplace_{false};
  // The gen/kill events of each statement.
  std::unordered_map<const Object*, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non constant items.
  std::list<StorageEntry*> sym_free_list_;
  // The allocation attach map
  std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
  // The allocation assign map
  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
  // The allocations
  std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
  // analyzer
  arith::Analyzer analyzer_;
};

// Turn an alloc into a vector alloc
// if all its accesses are of the same vector type.
class VectorAllocRewriter : public StmtExprMutator {
 public:
  PrimExpr VisitExpr_(const LoadNode* op) final {
    UpdateTypeMap(op->buffer_var.get(), op->dtype);
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
    return StmtExprMutator::VisitStmt_(op);
  }
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      UpdateTypeMap(buffer, dtype);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    const auto& tvec = acc_map_[op->buffer_var.get()];

    if (tvec.size() == 1 &&
        tvec[0].element_of() == op->dtype.element_of() &&
        tvec[0].lanes() % op->dtype.lanes() == 0 &&
        tvec[0].lanes() != op->dtype.lanes()) {
      int factor = tvec[0].lanes() / op->dtype.lanes();
      Array<PrimExpr> extents = op->extents;
      arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
      if (me->base % factor == 0 && me->coeff % factor == 0) {
        extents.Set(extents.size() - 1,
                    extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
        return AllocateNode::make(
            op->buffer_var, tvec[0], extents,
            op->condition, op->body);
      }
    }
    return stmt;
  }
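  // Illustrative sketch (hypothetical buffer): an allocate of float32 with extent 64
  // whose only recorded accesses are float32x4 loads/stores is rewritten to an
  // allocate of float32x4 with extent 16, provided the last extent is provably
  // divisible by the lane factor 4.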

  void UpdateTypeMap(const VarNode* buffer, DataType t) {
    auto& tvec = acc_map_[buffer];
    if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
      tvec.push_back(t);
    }
  }

  // Internal access map
  std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
  // internal analyzer
  arith::Analyzer analyzer_;
};


LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
  auto n = make_object<LoweredFuncNode>(*f.operator->());
  VectorAllocRewriter rewriter;
  n->body = rewriter(n->body);
  for (Var arg : f->args) {
    if (arg.dtype().is_handle()) {
      const auto& tvec = rewriter.acc_map_[arg.get()];
      if (tvec.size() == 1) {
        PrimExpr dtype = make_const(tvec[0], 0);
        n->handle_data_type.Set(arg, dtype);
      } else {
        // always set the data type to be non-vectorized so
        // load/store can still work via scalarization
        if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
          PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
          n->handle_data_type.Set(arg, dtype);
        }
      }
    }
  }
  return LoweredFunc(n);
}

Stmt StorageRewrite(Stmt stmt) {
  stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
  return VectorAllocRewriter()(std::move(stmt));
}
}  // namespace ir
}  // namespace tvm