/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file storage_flatten.cc
 */
// Flattens storage from multi-dimensional array to 1D
// buffer access as in Halide pipeline.
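//
// Schematically (assuming a compact layout with no alignment padding),
// a provide into a 2-D realization such as
//   realize A([0, n), [0, m)) { A(i, j) = ... }
// is lowered to a store on a flat allocation:
//   allocate A[n * m]        { A[i*m + j] = ... }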
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/target_info.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "arg_binder.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/ir_visitor_with_analyzer.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;

class StorageFlattener : public IRMutator {
 public:
  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
                            int cache_line_size, bool create_bound_attributes,
                            IRVisitorWithAnalyzer* bounded_analyzer)
      : bounded_analyzer_(bounded_analyzer),
        create_bound_attributes_(create_bound_attributes) {
    for (auto kv : extern_buffer) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
    }
    cache_line_size_ = cache_line_size;
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Store::make(buf_var, op->value, op->index, op->predicate);
    } else {
      return stmt;
    }
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::realize_scope) {
      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::double_buffer_scope &&
               op->node.node_->derived_from<OperationNode>()) {
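      // Re-attach the double-buffer marker to the flattened buffer variable of
      // each output of this operation, so that later passes consuming
      // double_buffer_scope see it on the actual allocation.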
      Operation func(op->node.node_);
      Stmt body = Mutate(op->body);
      for (int i = 0; i < func->num_outputs(); ++i) {
        TensorKey key{func, i};
        auto it = buf_map_.find(key);
        CHECK(it != buf_map_.end())
            << "Cannot find allocated buffer for " << key.f;
        body = AttrStmt::make(
            it->second.buffer->data, op->attr_key, op->value, body);
      }
      return body;
    } else if (op->attr_key == attr::thread_extent) {
      IterVar iv(op->node.node_);
      ThreadScope ts = ThreadScope::make(iv->thread_tag);
      curr_thread_scope_.push_back(ts);
      Stmt stmt = IRMutator::Mutate_(op, s);
      curr_thread_scope_.pop_back();
      return stmt;
    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    } else if (op->attr_key == attr::buffer_dim_align) {
      Tensor tensor(op->node.node_);
      const Call* tuple = op->value.as<Call>();
      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
      TensorKey key{tensor->op, tensor->value_index};
      auto& vinfo = dim_align_[key];
      int dim = tuple->args[0].as<IntImm>()->value;
      if (static_cast<size_t>(dim) >= vinfo.size()) {
        vinfo.resize(dim + 1);
      }
      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::opengl_stage_scope) {
      is_opengl_ = true;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    if (create_bound_attributes_)
      shape_collector_.clear();
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Provide>();
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;
    CHECK(!e.released)
        << "Read a buffer that is already out of scope";
    if (is_opengl_) {
      return Evaluate::make(Call::make(
          Type(),
          Call::glsl_texture_store,
          {e.buffer->data, op->value},
          Call::Intrinsic));
    } else {
      Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        shape_collector_.push_back(
            std::make_pair(e.buffer->data, e.buffer->shape));
      }
      // To create bound attributes, the collector must have at least one item.
      if (create_bound_attributes_ && shape_collector_.size()) {
        for (size_t i = 0; i < shape_collector_.size(); ++i) {
          body = AttrStmt::make(
              shape_collector_[i].first, ir::attr::buffer_bound,
              MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
        }
      }
      return body;
    }
  }

  Stmt Mutate_(const Realize* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
    if (buf_map_.count(key)) {
      CHECK(buf_map_.at(key).external);
      return this->Mutate(op->body);
    } else {
      // create a buffer entry
      BufferEntry e;
      e.bounds = op->bounds;
      Array<Expr> shape;
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      // deduce current storage scope.
      auto it = storage_scope_.find(op->func.get());
      CHECK(it != storage_scope_.end())
          << "Cannot find storage scope of " << op->func
          << " value_index=" << op->value_index;
      StorageScope skey;
      const std::string& strkey = it->second;
      if (strkey.length() == 0) {
        if (curr_thread_scope_.size() != 0) {
          skey.rank = runtime::DefaultStorageRank(
              curr_thread_scope_.back().rank);
        }
      } else {
        skey = StorageScope::make(strkey);
      }

      // Use a smaller alignment for small arrays.
      int32_t const_size = Allocate::constant_allocation_size(shape);
      int align = GetTempAllocaAlignment(op->type, const_size);
      if (skey.tag.length() != 0) {
        MemoryInfo info = GetMemoryInfo(skey.to_string());
        if (info.defined()) {
          align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
          CHECK_LE(const_size * op->type.bits(), info->max_num_bits)
              << "Allocation exceeds the bound of memory tag " << skey.to_string();
        }
      }
      Array<Expr> strides;
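      // If an alignment was requested via attr::buffer_dim_align, pad each
      // stride so that stride % align_factor == align_offset; e.g. factor=16,
      // offset=1 turns a stride of 32 into 33 (a layout trick commonly used
      // to avoid shared-memory bank conflicts).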
      if (dim_align_.count(key) != 0 && shape.size() != 0) {
        std::vector<Expr> rstrides;
        const std::vector<DimAlignInfo>& avec = dim_align_[key];
        int first_dim = 0;
        Expr stride = make_const(shape[first_dim].type(), 1);
        for (size_t i = shape.size(); i != 0; --i) {
          size_t dim = i - 1;
          if (dim < avec.size() && avec[dim].align_factor != 0) {
            Expr factor = make_const(stride.type(), avec[dim].align_factor);
            Expr offset = make_const(stride.type(), avec[dim].align_offset);
            stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
            stride = ir::Simplify(stride);
          }
          rstrides.push_back(stride);
          stride = stride * shape[dim];
        }
        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
      }

      e.buffer = BufferNode::make(
          Var(key.GetName(), Handle()),
          op->type, shape, strides, Expr(),
          key.GetName(), skey.to_string(),
          align, 0, kDefault);

      buf_map_[key] = e;
      Stmt body = this->Mutate(op->body);
      buf_map_[key].released = true;
      Stmt ret;

      Type storage_type = e.buffer->dtype;
      // Specially handle bool: lower its storage type to Int(8) (one byte).
      if (storage_type == Bool()) {
        storage_type = Int(8);
      }
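      // When strides are present, the flat extent is strides[0] * shape[0]:
      // the outermost stride already accounts for any padding added to the
      // inner dimensions, so the plain product of the shape could be too small.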
      if (strides.size() != 0) {
        int first_dim = 0;
        ret = Allocate::make(
            e.buffer->data, storage_type,
            {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
            make_const(Bool(e.buffer->dtype.lanes()), true), body);
      } else {
        shape = e.buffer->shape;
        if (shape.size() == 0) {
          shape.push_back(make_const(Int(32), 1));
        }
        ret = Allocate::make(
            e.buffer->data, storage_type, shape,
            make_const(Bool(e.buffer->dtype.lanes()), true), body);
      }
      ret = AttrStmt::make(
          e.buffer->data, attr::storage_scope,
          StringImm::make(e.buffer->scope), ret);

      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
                             MakeBound(e.buffer->dtype, e.buffer->shape), ret);
      }
      return ret;
    }
  }

  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Load::make(op->type, buf_var, op->index, op->predicate);
    } else {
      return expr;
    }
  }

  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return e;
    }
  }

  Expr Mutate_(const Call* op, const Expr& olde) final {
    Expr expr = IRMutator::Mutate_(op, olde);
    op = expr.as<Call>();
    if (op != nullptr && op->call_type == Call::Halide) {
      TensorKey key{op->func, op->value_index};
      auto it = buf_map_.find(key);
      CHECK(it != buf_map_.end())
          << "Cannot find allocated buffer for " << key.f;
      const BufferEntry& e = it->second;
      CHECK(!e.released)
          << "Read a buffer that is already out of scope";

      if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
        shape_collector_.push_back(
            std::make_pair(e.buffer->data, e.buffer->shape));
      }
      return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
    } else {
      return expr;
    }
  }

  Stmt Mutate_(const Prefetch *op, const Stmt &s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Prefetch>();
    CHECK(op != nullptr);
    TensorKey key{op->func, op->value_index};
    auto it = buf_map_.find(key);
    CHECK(it != buf_map_.end())
        << "Cannot find allocated buffer for " << key.f;
    const BufferEntry& e = it->second;

    CHECK(!e.released)
        << "Read a buffer that is already out of scope";
    CHECK_EQ(e.buffer->shape.size(), op->bounds.size())
      << "Prefetch dim should be the same as buffer dim";

    int block_size = 1,
        elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
        shape = 0;
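    // Fold the innermost dimensions whose combined extent fits within one
    // cache line into a single prefetch; `starts` ends up as the innermost
    // dimension that still gets its own loop, and that loop advances in
    // cache-line-sized strides.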

    int starts = op->bounds.size() - 1;
    while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
        && elem_cnt >= block_size * shape) {
      block_size *= shape;
      starts--;
    }
    Expr stride(elem_cnt / block_size);

    Array<Expr> args;
    std::vector<VarExpr> vars;

    for (int i = op->bounds.size() - 1; i > starts; --i) {
      args.push_back(op->bounds[i]->min);
    }
    auto &func_name = op->func->func_name();
    vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(starts), Int(32)));
    args.push_back(op->bounds[starts]->min + stride * vars.back());
    for (int i = starts - 1; i >= 0; --i) {
      vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(i), Int(32)));
      args.push_back(vars.back() + op->bounds[i]->min);
    }
    for (int i = starts; i >= 0; --i) {
      if (i < starts) {
        stmt = For::make(
            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
      } else {
        Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
        Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
        Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
        stmt = Evaluate::make(prefetch);
        Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
      }
    }
    return stmt;
  }

 private:
  // The specific tensor data layout is not determined before
  // the StorageFlatten pass. We use buffer_bind_scope to specify,
  // beforehand, that we want to bind a subregion of a tensor to a
  // symbolic buffer, which gets used by an extern op.
  //
  // Example:
  //
  // realize A in range [i*4, extent=10) {
  //   bind Ab to A in [i*4+1, extent=4) {
  //     call_func(Ab.ptr, Ab.shape[0])
  //   }
  // }
  //
  // After StorageFlatten
  //
  // alloc A[10]
  //   call(A + 1,  4)
  //
  // A Buffer is a protocol declaring the specific data layout and
  // shape we expect.
  // So this function needs to:
  // - Check that the bind range is within the realize range.
  // - Check that we can match the requirements of the buffer.
  // - Remap variables such as Ab.ptr to their actual values.
  //
  // Here are a few possible failure cases:
  // - The buffer is declared to have a constant shape,
  //   but we try to bind it to a different one.
  // - The buffer is declared to be compact (no strides),
  //   but the bound region is a subregion of
  //   a matrix (tensor), which means it requires strides.
  //
  // We do support a few relaxed cases, such as binding a
  // region with shape [1, 1, n, m] to a buffer with shape [n, m].
  Stmt HandleBufferBindScope(const AttrStmt* op) {
    Array<NodeRef> arr(op->node.node_);
    CHECK_EQ(arr.size(), 2U);
    const BufferNode* buffer = arr[0].as<BufferNode>();
    const TensorNode* tensor = arr[1].as<TensorNode>();
    const Call* tuple = op->value.as<Call>();
    CHECK(buffer && tensor);
    CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
    TensorKey key{tensor->op, tensor->value_index};
    CHECK(buf_map_.count(key))
        << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
    const BufferEntry& be = buf_map_.at(key);
    CHECK(!be.released);
    CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
    Array<Expr> begins, extents;
    if (be.bounds.size() != 0) {
      CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
      for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
        begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
        extents.push_back(tuple->args[2 * i + 1]);
      }
    } else {
      for (size_t i = 0; i < tuple->args.size(); i += 2) {
        begins.push_back(tuple->args[i]);
        auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]);
        extents.push_back(new_extent);
      }
    }
    Buffer slice = be.buffer.MakeSlice(begins, extents);
    if (buffer->strides.size() == 0) {
      CHECK_EQ(slice->strides.size(), 0U)
          << "Trying to bind a compact buffer to a strided one, strides="
          << slice->strides;
    } else {
      slice = slice.MakeStrideView();
    }
    // Start binding.
    ArgBinder binder(&var_remap_);
    binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true);
    // Apply the remaps.
    Stmt body = MergeNest(binder.asserts(), op->body);
    body = MergeNest(binder.init_nest(), body);
    body = this->Mutate(body);
    // Remove the binds.
    for (const Var& v : binder.defs()) {
      var_remap_.erase(v.get());
    }
    return body;
  }
  // Dimension alignment information.
  struct DimAlignInfo {
    int align_factor{0};
    int align_offset{0};
  };
  // The buffer entry in the flatten map.
  struct BufferEntry {
    // The buffer of storage.
    Buffer buffer;
    // The bounds of realization; can be empty, which means the full region.
    Region bounds;
    // Whether the buffer is external.
    bool external{false};
    // Whether we are out of the allocation bounds and the buffer got released.
    bool released{false};
    // Compute the index relative to the realization bounds.
    inline Array<Expr> RelIndex(Array<Expr> args) const {
      if (bounds.size() != 0) {
        Array<Expr> index;
        CHECK_EQ(bounds.size(), args.size());
        for (size_t i = 0; i < bounds.size(); ++i) {
          index.push_back(args[i] - bounds[i]->min);
        }
        return index;
      } else {
        return args;
      }
    }
  };

  bool ShapeIsValid(const Array<Expr> &shape) {
    // Zero-dimensional tensor does not need boundary check.
    if (!shape.size())
      return false;

    for (size_t i = 0; i < shape.size(); ++i) {
      if (!shape[i].defined() || !shape[i].type().is_scalar() ||
          is_negative_const(shape[i])) {
        return false;
      }
    }
    return true;
  }

  Expr MakeBound(const Type &type, const Array<Expr> &shape) {
    // We have already checked that the shape size is greater than 0.
    Expr bound = Mul::make(make_const(shape[0].type(), type.lanes()), shape[0]);
    for (size_t i = 1; i < shape.size(); ++i) {
      bound = Mul::make(
          bound, Mul::make(make_const(bound.type(), type.lanes()), shape[i]));
    }
    return bound;
  }

  // Variable remap for the buffer assignment.
  std::unordered_map<const Variable*, Expr> var_remap_;
  // Buffer map.
  std::unordered_map<TensorKey, BufferEntry> buf_map_;
  // Dimension alignment requests, keyed by tensor.
  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
  // Storage scope.
  std::unordered_map<const Node*, std::string> storage_scope_;
  // The current thread scope.
  std::vector<ThreadScope> curr_thread_scope_;
  // Collects the shapes touched while flattening a Provide, for buffer_bound attributes.
  std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
  // Bounds populator; we only need the analyzer it carries.
  IRVisitorWithAnalyzer* bounded_analyzer_;
  // The size of a cache line.
  int cache_line_size_;
  // Whether the current stage is an OpenGL shader.
  bool is_opengl_{false};
  // Whether to mark loads/stores with their bounds.
  bool create_bound_attributes_{false};
};

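// Entry point of the pass: flattens Realize/Provide nodes and Halide calls in
// stmt into Allocate/Store/Load on flat buffers; extern_buffer supplies the
// buffers bound to externally provided tensors.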
Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size, bool create_bound_attributes) {
  IRVisitorWithAnalyzer bounded_analyzer;
  bounded_analyzer.Visit(stmt);
  stmt =
      StorageFlattener(extern_buffer, cache_line_size,
          create_bound_attributes, &bounded_analyzer).Mutate(stmt);
  return stmt;
}

}  // namespace ir
}  // namespace tvm