/*!
 *  Copyright (c) 2016 by Contributors
 * \file buffer.cc
 */
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include <list>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {

// Run ir::Simplify over every element of an expression array.
// The array is received by value; each slot is overwritten with its
// simplified form and the resulting array is returned.
Array<Expr> SimplifyArray(Array<Expr> array) {
  size_t pos = 0;
  while (pos < array.size()) {
    Expr simplified = ir::Simplify(array[pos]);
    array.Set(pos, simplified);
    ++pos;
  }
  return array;
}

21 22 23 24 25 26 27 28 29
Buffer decl_buffer(Array<Expr> shape,
                   Type dtype,
                   std::string name) {
  return BufferNode::make(
      Var(name, Handle()),
      dtype,
      shape,
      Array<Expr>(),
      Expr(),
30 31 32
      name,
      "",
      0, 0);
33 34
}

// Split the given expression w.r.t the add operator
inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
  using namespace ir;
  std::vector<const Expr*> ret;
  std::stack<const Expr*> split_buffer;
  split_buffer.push(&expr);
  while (!split_buffer.empty()) {
    const Expr* top_ele = split_buffer.top();
    split_buffer.pop();
    auto expr_add_match = top_ele->as<Add>();
    if (expr_add_match) {
      split_buffer.push(&expr_add_match->b);
      split_buffer.push(&expr_add_match->a);
    } else {
      ret.emplace_back(top_ele);
    }
  }
  return ret;
}


// Try to merge one multiplication term with one modulo term.
//
// Searches for the following types of expr:
//   mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   mod_l_expr = c
//   mod_r_expr = k1 * k2 * ... * ki
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c).
// Otherwise returns (false, Expr()).
// Currently we will not search the add/mult combinations exhaustively
//   as it will take too much computation: only the left-nested chain of
//   Mul nodes and the right-most operand of each Add are followed.
inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
                                              const Expr &mod_l_expr,
                                              const Expr &mod_r_expr) {
  using namespace ir;
  const Mul* mult_ptr = mult_expr.as<Mul>();
  // Only a top-level multiplication can match the pattern.
  if (!mult_ptr) return std::make_pair(false, Expr());
  Expr mult_outer = mult_ptr->b;
  const Expr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier
  //    Walk down the left operands while they are Mul nodes, folding every
  //    right operand into mult_outer (kt * ... * ki in the pattern above).
  while (true) {
    mult_ptr = inner->as<Mul>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  //   If Add is found, we need to continue our search for the rhs
  //   If Mult is found, we will expand the inner multiplication factor
  //   If Div is found, we will go on testing whether lhs matches the lhs of mod expr
  //      and returns the optimization result.
  const Expr* search_ptr = inner;
  Expr mult_inner;  // The inner multiplication factor
  Expr no_opt_sum;  // Sum of the exprs that cannot be optimized
  while (true) {
    auto inner_div_ptr = search_ptr->as<Div>();
    auto inner_mult_ptr = search_ptr->as<Mul>();
    auto inner_add_ptr = search_ptr->as<Add>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      // Reached a node that is none of Add/Mul/Div: the pattern is absent.
      return std::make_pair(false, Expr());
    } else if (inner_div_ptr) {
      // The product of all multipliers seen so far must equal both the
      // divisor and the modulus, and the dividend must equal the mod lhs.
      Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (Equal(overall_mult, inner_div_ptr->b)
          && Equal(overall_mult, mod_r_expr)
          && Equal(inner_div_ptr->a, mod_l_expr)) {
        // Found!
        Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, Expr());
      }
    } else if (inner_mult_ptr) {
      // Accumulate the inner multiplier and keep descending on the left.
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      // An Add below an inner multiplication does not fit the pattern.
      if (mult_inner.get()) {
        return std::make_pair(false, Expr());
      }
      // Keep the left operand as a non-optimizable term and continue the
      // search in the right operand.
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      // Not reachable: the first branch already covers "none matched".
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, Expr());
}

// Distribute the given addends into the multiplication / modulo work lists.
//
//   eles        addends to classify (pointers are only read).
//   mult_exprs  every element that is a Mul node is appended here.
//   mod_exprs   every element that is a Mod node is appended here as the
//               pair (lhs, rhs) of the modulo.
//   no_opt_sum  every other element is added to this running sum (the sum
//               is initialized from the first such element if undefined).
//   has_mult    set to true iff at least one Mul was appended by this call.
//   has_mod     set to true iff at least one Mod was appended by this call.
//
// Both flags are reset at the start of each call, so they describe this
// call only and not the accumulated content of the lists.
inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
                                      std::list<Expr>* mult_exprs,
                                      std::list<std::pair<Expr, Expr> >* mod_exprs,
                                      Expr* no_opt_sum,
                                      bool* has_mult,
                                      bool* has_mod) {
  using namespace ir;
  *has_mult = false;
  *has_mod = false;
  for (const Expr* ele : eles) {
    if (const Mod* mod_ptr = ele->as<Mod>()) {
      *has_mod = true;
      // The node is reached through a pointer-to-const, so its operands can
      // only be copied; construct the pair in place from the two operands.
      mod_exprs->emplace_back(mod_ptr->a, mod_ptr->b);
    } else if (ele->as<Mul>()) {
      *has_mult = true;
      mult_exprs->emplace_back(*ele);
    } else {
      *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele;
    }
  }
}

// Searches for this type of expr:
//   (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   + c % (k1 * k2 * ... * ki)
// and simplifies to (a1 + a2 + ... + aj) * kt * ... * ki + c
// The search will be performed repeatedly until no pattern is found.
// Return: Simplify(base) unchanged if no pattern was matched; otherwise the
//         sum rebuilt from the merged results and all terms left over.
inline Expr MergeMulMod(const Expr &base) {
  using namespace ir;
  // 1. Prepare the lists.
  // We store two lists, a list that contain all the elements that match Mul and
  //                     a list that contain all the elements that match Mod.
  // The elements in the Mod will be used to match against the elements in Mul.
  // The result will then be split and pushed back to these two lists.
  Expr simplified_base = Simplify(base);
  // The pointers in eles refer into simplified_base, which outlives them.
  std::vector<const Expr*> eles = ExprSplitAddition(simplified_base);
  std::list<Expr> mult_exprs;
  std::list<std::pair<Expr, Expr> > mod_exprs;
  Expr no_opt_sum;
  bool has_mult;
  bool has_mod;
  MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
                            &no_opt_sum, &has_mult, &has_mod);
  bool find_opt = false;
  std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin();
  // 2. Exhaustive Search
  //    For every Mod term, try each Mul term until one of them merges.
  while (search_mod_it != mod_exprs.end()) {
    std::list<Expr>::iterator mult_it = mult_exprs.begin();
    bool inner_find_opt = false;
    while (mult_it != mult_exprs.end()) {
      std::pair<bool, Expr> ret = MergeMulModInner(*mult_it,
                                                   search_mod_it->first,
                                                   search_mod_it->second);
      if (ret.first) {
        inner_find_opt = true;
        // Step past the matched Mod before erasing it so that
        // search_mod_it remains a valid iterator.
        auto temp_mod_it = search_mod_it;
        ++search_mod_it;
        mod_exprs.erase(temp_mod_it);
        mult_exprs.erase(mult_it);
        // Split the merged result and append its terms to the back of the
        // work lists; has_mult / has_mod now describe only these new terms.
        std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second);
        MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
                                  &no_opt_sum, &has_mult, &has_mod);
        if (has_mult) {
          // A new Mul term may match a Mod that was already visited:
          // restart from the first Mod.
          search_mod_it = mod_exprs.begin();
        } else if (has_mod && search_mod_it == mod_exprs.end()) {
          // New Mod terms were appended after the position we stepped to;
          // move back onto the last element so it gets examined.
          // NOTE(review): if several Mod terms were appended, only the last
          // one is revisited from here — confirm this is intended.
          search_mod_it--;
        }
        break;
      } else {
        ++mult_it;
      }
    }
    find_opt = find_opt || inner_find_opt;
    // On a match the iterator was already repositioned above.
    if (!inner_find_opt) {
      ++search_mod_it;
    }
  }
  if (!find_opt) {
    return simplified_base;
  }
  // 3. Rebuild the expression: non-optimizable terms, then the remaining
  //    Mul terms, then the remaining Mod terms.
  for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
  }
  for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
                                                   it != mod_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
  }
  return no_opt_sum;
}

222 223
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
224
// We also perform optimization to simplify the indexing expression.
225 226
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
  Expr base = n->elem_offset;
227 228
  if (n->strides.size() == 0) {
    CHECK_EQ(n->shape.size(), index.size());
229 230 231 232
    if (index.size() > 0) {
      Expr offset = index[0];
      for (size_t i = 1; i < index.size(); ++i) {
        offset = MergeMulMod(offset * n->shape[i] + index[i]);
233
      }
234
      base = base + offset;
235 236 237
    }
  } else {
    CHECK_EQ(n->strides.size(), index.size());
238
    if (is_zero(base)) {
239
      base = MergeMulMod(index[0] * n->strides[0]);
240
    } else {
241
      base = MergeMulMod(base + index[0] * n->strides[0]);
242
    }
243
    for (size_t i = 1; i < index.size(); ++i) {
244
      base = MergeMulMod(base + index[i] * n->strides[i]);
245 246 247 248 249
    }
  }
  return base;
}

250
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
251 252
  Expr offset = ElemOffset(n, index);
  if (n->dtype.lanes() != 1) {
253 254 255 256 257 258
    offset = offset * make_const(offset.type(), dtype.lanes());
  }
  if (dtype.lanes() != 1) {
    return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
  } else {
    return offset;
259 260 261
  }
}

262
Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
263
  // specially handle bool, stored as Int(8)
264
  const BufferNode* n = operator->();
265 266 267 268
  CHECK(dtype.element_of() == n->dtype.element_of() &&
        dtype.lanes() % n->dtype.lanes() == 0)
      << "Cannot load " << dtype
      << " from buffer of " << n->dtype;
269 270 271 272 273 274 275 276 277 278 279
  if (dtype == Bool()) {
    return ir::Cast::make(
        Bool(),
        ir::Load::make(
            Int(8), n->data, BufferOffset(n, begin, Int(8)),
            const_true()));
  } else {
    return ir::Load::make(
        dtype, n->data, BufferOffset(n, begin, dtype),
        const_true(dtype.lanes()));
  }
280 281
}

282
Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
283
  // specially handle bool, stored as Int(8)
284
  const BufferNode* n = operator->();
285 286 287 288 289
  Type dtype = value.type();
  CHECK(dtype.element_of() == n->dtype.element_of() &&
        dtype.lanes() % n->dtype.lanes() == 0)
      << "Cannot load " << dtype
      << " from buffer of " << n->dtype;
290 291 292 293 294 295 296 297 298
  if (value.type() == Bool()) {
    return ir::Store::make(n->data,
                           ir::Cast::make(Int(8), value),
                           BufferOffset(n, begin, Int(8)),
                           const_true());
  } else {
    return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
                           const_true(dtype.lanes()));
  }
299 300
}

301 302
Buffer Buffer::MakeStrideView() const {
  if ((*this)->strides.size() != 0) return *this;
303
  if ((*this)->shape.size() == 0) return *this;
304
  std::vector<Expr> temp;
305
  auto n = make_node<BufferNode>(*operator->());
306
  Expr acc = make_const(n->DefaultIndexType(), 1);
307 308 309 310 311 312 313 314 315 316 317 318
  for (size_t i = n->shape.size(); i != 0 ; --i) {
    temp.push_back(acc);
    acc = acc * n->shape[i - 1];
  }
  for (size_t i = temp.size(); i != 0; --i) {
    n->strides.push_back(temp[i - 1]);
  }
  return Buffer(n);
}

// Create a buffer describing the sub-region of this buffer that starts at
// `begins` and has the given `extents`.
//
// The slice shares this buffer's data variable, dtype, scope and data
// alignment; its name is the original name with "_slice" appended and its
// element offset is the simplified offset of `begins`.
//
// For a buffer without strides, explicit strides are required when some
// dimension that follows a non-unit extent does not span the whole original
// axis (non-zero begin, or extent different from the shape). In that case
// the slice is taken from the strided view of this buffer instead.
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
  const BufferNode* n = operator->();
  begins = SimplifyArray(begins);
  Expr elem_offset = ir::Simplify(ElemOffset(n, begins));
  Array<Expr> strides = n->strides;
  if (strides.size() == 0) {
    bool can_relax = true;
    bool need_stride = false;
    // check if stride is needed.
    for (size_t i = 0; i < extents.size(); ++i) {
      if (!can_relax) {
        if (!is_zero(begins[i]) ||
            !is_zero(ir::Simplify(extents[i] - n->shape[i]))) {
          need_stride = true;
        }
      }
      // Leading dimensions of extent 1 never force strides.
      if (!is_one(extents[i])) can_relax = false;
    }
    // make stride.
    if (need_stride) {
      return MakeStrideView().MakeSlice(begins, extents);
    }
  }
  return BufferNode::make(n->data,
                          n->dtype,
                          extents,
                          strides,
                          elem_offset,
                          n->name + "_slice",
                          n->scope,
                          n->data_alignment,
                          0);
}

351
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
352
  const BufferNode* self = operator->();
353
  Expr e_dtype;
354 355 356 357 358 359
  Expr extent;
  if (self->shape.size() == 0) {
    extent = make_const(self->DefaultIndexType(), 1);
  } else if (self->strides.size() == self->shape.size()) {
    int highest_dim = 0;
    extent = arith::ComputeExpr<ir::Mul>(
360
        self->strides[highest_dim], self->shape[highest_dim]) - offset;
361
  } else {
362
    extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
363
  }
364
  Expr elem_offset = self->elem_offset + offset;
365
  if (content_lanes > 1) {
366
    e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
367 368 369 370
    extent = extent / make_const(self->elem_offset.type(), content_lanes);
    elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
                                                 content_lanes);
  } else {
371
    e_dtype = ir::TypeAnnotation(self->dtype);
372
  }
373
  Array<Expr> acc_args{
374
    e_dtype, self->data, elem_offset,
375 376 377 378 379
        extent, make_const(Int(32), access_mask)};
  return ir::Call::make(
      ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
}

380 381
Buffer BufferNode::make(Var data,
                        Type dtype,
382 383
                        Array<Expr> shape,
                        Array<Expr> strides,
384
                        Expr elem_offset,
385 386
                        std::string name,
                        std::string scope,
387 388
                        int data_alignment,
                        int offset_factor) {
389
  auto n = make_node<BufferNode>();
390
  n->data = std::move(data);
391
  n->dtype = dtype;
392 393 394
  n->shape = std::move(shape);
  n->strides = std::move(strides);
  n->name = std::move(name);
395 396 397
  if (scope.length() == 0) {
    scope = "global";
  }
398
  n->scope = std::move(scope);
399
  if (!elem_offset.defined()) {
400
    elem_offset = make_const(n->DefaultIndexType(), 0);
401
  }
402
  if (data_alignment <= 0) {
403 404 405 406
    data_alignment = runtime::kAllocAlignment;
  }
  if (offset_factor == 0) {
    offset_factor = 1;
407
  }
408
  n->elem_offset = std::move(elem_offset);
409 410
  n->data_alignment = data_alignment;
  n->offset_factor = offset_factor;
411 412 413 414 415 416 417 418 419 420 421
  return Buffer(n);
}

// Printer dispatch for BufferNode: writes "buffer(<name>, <node address>)"
// to the printer's stream.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) {
    p->stream << "buffer(" << op->name << ", " << op << ")";
});

// NOTE(review): presumably registers BufferNode with the node type
// registry — confirm against the macro definition.
TVM_REGISTER_NODE_TYPE(BufferNode);

}  // namespace tvm