/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file buffer.cc
 */
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include <list>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {

/*!
 * \brief Simplify every expression stored in an array.
 * \param array The expressions to simplify (taken by value).
 * \return The array in which each element has been replaced by
 *         the result of ir::Simplify on it.
 */
Array<Expr> SimplifyArray(Array<Expr> array) {
  const size_t count = array.size();
  for (size_t idx = 0; idx != count; ++idx) {
    Expr simplified = ir::Simplify(array[idx]);
    array.Set(idx, simplified);
  }
  return array;
}

/*!
 * \brief Declare a buffer from a shape, an element type and a name.
 *
 *  A fresh handle-typed Var carrying the same name is created as the data
 *  pointer. Strides and elem_offset are passed as empty/undefined, the scope
 *  as an empty string, and both data_alignment and offset_factor as 0.
 *
 * \param shape The shape of the buffer.
 * \param dtype The element type of the buffer.
 * \param name The name of the buffer (also used for the data variable).
 * \return The created buffer.
 */
Buffer decl_buffer(Array<Expr> shape,
                   Type dtype,
                   std::string name) {
  Var data_var(name, Handle());
  const Array<Expr> no_strides;
  const Expr undefined_offset;
  return BufferNode::make(data_var, dtype, shape,
                          no_strides, undefined_offset,
                          name, "", 0, 0);
}

// Split the given expression w.r.t. the add operator
inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
  using namespace ir;
  std::vector<const Expr*> ret;
  std::stack<const Expr*> split_buffer;
  split_buffer.push(&expr);
  while (!split_buffer.empty()) {
    const Expr* top_ele = split_buffer.top();
    split_buffer.pop();
    auto expr_add_match = top_ele->as<Add>();
    if (expr_add_match) {
      split_buffer.push(&expr_add_match->b);
      split_buffer.push(&expr_add_match->a);
    } else {
      ret.emplace_back(top_ele);
    }
  }
  return ret;
}


// Searches for the following types of expr:
//   mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   mod_l_expr = c
//   mod_r_expr = k1 * k2 * ... * ki
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
// Currently we do not search the add/mult combinations exhaustively,
//   as that would take too much computation.
inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
                                              const Expr &mod_l_expr,
                                              const Expr &mod_r_expr) {
  using namespace ir;
  const Mul* mult_ptr = mult_expr.as<Mul>();
  if (!mult_ptr) return std::make_pair(false, Expr());
  Expr mult_outer = mult_ptr->b;
  const Expr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier
  while (true) {
    mult_ptr = inner->as<Mul>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  //   If Add is found, we need to continue our search for the rhs
  //   If Mult is found, we will expand the inner multiplication factor
  //   If Div is found, we will go on testing whether lhs matches the lhs of mod expr
  //      and returns the optimization result.
  const Expr* search_ptr = inner;
  Expr mult_inner;  // The inner multiplication factor
  Expr no_opt_sum;  // Sum of the exprs that cannot be optimized
  while (true) {
    auto inner_div_ptr = search_ptr->as<Div>();
    auto inner_mult_ptr = search_ptr->as<Mul>();
    auto inner_add_ptr = search_ptr->as<Add>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      return std::make_pair(false, Expr());
    } else if (inner_div_ptr) {
      Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (Equal(overall_mult, inner_div_ptr->b)
          && Equal(overall_mult, mod_r_expr)
          && Equal(inner_div_ptr->a, mod_l_expr)) {
        // Found!
        Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, Expr());
      }
    } else if (inner_mult_ptr) {
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      if (mult_inner.get()) {
        return std::make_pair(false, Expr());
      }
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, Expr());
}

// Insert the elements into the corresponding mult_exprs and mod_exprs.
// If the element is found to match Mul, it will be pushed to the mult_exprs.
// If the element is found to match Mod, it will be pushed to the mod_exprs.
// Otherwise, the element will be added to the no_opt_sum variable.
inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
                                      std::list<Expr>* mult_exprs,
                                      std::list<std::pair<Expr, Expr> >* mod_exprs,
                                      Expr* no_opt_sum,
                                      bool* has_mult,
                                      bool* has_mod) {
  // Parameters:
  //   eles       - the terms to classify.
  //   mult_exprs - receives a copy of every term that is a Mul (appended).
  //   mod_exprs  - receives the (lhs, rhs) operands of every Mod term (appended).
  //   no_opt_sum - running sum of all remaining terms; may be undefined on entry.
  //   has_mult   - set to whether this call appended at least one Mul term.
  //   has_mod    - set to whether this call appended at least one Mod term.
  using namespace ir;
  *has_mult = false;
  *has_mod = false;
  for (const Expr* ele : eles) {
    if (const Mod* mod_ptr = ele->as<Mod>()) {
      *has_mod = true;
      // The operands are reached through a pointer-to-const, so they can only
      // be copied; construct the pair in place from them.
      mod_exprs->emplace_back(mod_ptr->a, mod_ptr->b);
    } else if (ele->as<Mul>()) {
      *has_mult = true;
      mult_exprs->emplace_back(*ele);
    } else {
      *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele;
    }
  }
}

// Searches for this type of expr:
//   (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   + c % (k1 * k2 * ... * ki)
// and simplifies to (a1 + a2 + ... + aj) * kt * ... * ki + c
// The search will be performed repeatedly until no pattern is found.
// Return: the simplified input expression if no pattern is found,
//         otherwise the expression with every matched pattern merged.
inline Expr MergeMulMod(const Expr &base) {
  using namespace ir;
  // 1. Prepare the lists.
  // We store two lists, a list that contains all the elements that match Mul and
  //                     a list that contains all the elements that match Mod.
  // The elements in the Mod list will be used to match against the elements in Mul.
  // The result will then be split and pushed back to these two lists.
  Expr simplified_base = Simplify(base);
  // The pointers in eles refer into simplified_base, which outlives them.
  std::vector<const Expr*> eles = ExprSplitAddition(simplified_base);
  std::list<Expr> mult_exprs;
  // Each entry holds the (lhs, rhs) operands of one Mod term.
  std::list<std::pair<Expr, Expr> > mod_exprs;
  // Sum of the terms that are neither Mul nor Mod; stays undefined until one is seen.
  Expr no_opt_sum;
  bool has_mult;
  bool has_mod;
  MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
                            &no_opt_sum, &has_mult, &has_mod);
  // Whether at least one (Mul, Mod) pair has been merged.
  bool find_opt = false;
  std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin();
  // 2. Exhaustive Search
  // For every Mod term, try each Mul term until one of them merges with it.
  while (search_mod_it != mod_exprs.end()) {
    std::list<Expr>::iterator mult_it = mult_exprs.begin();
    bool inner_find_opt = false;
    while (mult_it != mult_exprs.end()) {
      std::pair<bool, Expr> ret = MergeMulModInner(*mult_it,
                                                   search_mod_it->first,
                                                   search_mod_it->second);
      if (ret.first) {
        inner_find_opt = true;
        // Advance past the matched Mod entry before erasing it, so that
        // search_mod_it never refers to an erased element.
        auto temp_mod_it = search_mod_it;
        ++search_mod_it;
        mod_exprs.erase(temp_mod_it);
        mult_exprs.erase(mult_it);
        // Re-classify the terms of the merged expression; they are appended
        // to the back of the lists (or added to no_opt_sum).
        std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second);
        MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
                                  &no_opt_sum, &has_mult, &has_mod);
        if (has_mult) {
          // A new Mul term may match any remaining Mod term: restart the scan.
          search_mod_it = mod_exprs.begin();
        } else if (has_mod && search_mod_it == mod_exprs.end()) {
          // New Mod terms were appended while the cursor was already at the end:
          // step back onto the last element so it still gets examined.
          search_mod_it--;
        }
        // mult_it was erased above and must not be used again.
        break;
      } else {
        ++mult_it;
      }
    }
    find_opt = find_opt || inner_find_opt;
    if (!inner_find_opt) {
      // No Mul term matched this Mod term; on a match the cursor was
      // already repositioned above.
      ++search_mod_it;
    }
  }
  if (!find_opt) {
    return simplified_base;
  }
  // 3. Rebuild the expression: non-optimizable sum, then the remaining Mul
  //    terms, then the remaining Mod terms.
  for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
  }
  for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
                                                   it != mod_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
  }
  return no_opt_sum;
}

// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
  // Returns n->elem_offset plus the linearized position of `index`,
  // passing the intermediate sums through MergeMulMod.
  Expr base = n->elem_offset;
  const size_t ndim = index.size();
  if (n->strides.size() != 0) {
    // Explicit strides: accumulate index[i] * strides[i] on top of the base.
    CHECK_EQ(n->strides.size(), index.size());
    Expr leading = index[0] * n->strides[0];
    // A zero base is left out so that no "0 + ..." term is built.
    base = is_zero(base) ? MergeMulMod(leading) : MergeMulMod(base + leading);
    for (size_t dim = 1; dim < ndim; ++dim) {
      base = MergeMulMod(base + index[dim] * n->strides[dim]);
    }
    return base;
  }
  // No strides: derive the position from the shape,
  //   offset = (..((index[0] * shape[1] + index[1]) * shape[2] + index[2])..)
  CHECK_EQ(n->shape.size(), index.size());
  if (ndim == 0) {
    return base;
  }
  Expr offset = index[0];
  for (size_t dim = 1; dim < ndim; ++dim) {
    offset = MergeMulMod(offset * n->shape[dim] + index[dim]);
  }
  return base + offset;
}

/*!
 * \brief Compute the index expression used to access the buffer with a given type.
 * \param n The buffer node.
 * \param index The per-dimension begin index.
 * \param dtype The type being accessed.
 * \return The element offset for a single-lane dtype, otherwise a Ramp with
 *         stride 1 and dtype.lanes() lanes starting at the offset.
 */
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
  Expr offset = ElemOffset(n, index);
  const int access_lanes = dtype.lanes();
  if (n->dtype.lanes() != 1) {
    // NOTE(review): the scale factor is the lane count of the accessed dtype,
    // not of n->dtype; the two differ when dtype.lanes() is a larger multiple
    // of n->dtype.lanes() -- confirm this is intended.
    offset = offset * make_const(offset.type(), access_lanes);
  }
  if (access_lanes == 1) {
    return offset;
  }
  return ir::Ramp::make(offset, make_const(offset.type(), 1), access_lanes);
}

/*!
 * \brief Build an expression that loads a value of type dtype from the buffer.
 * \param begin The per-dimension begin index of the load.
 * \param dtype The type to load; its element type must equal the buffer's and
 *        its lane count must be a multiple of the buffer's lane count.
 * \return The load expression.
 */
Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
  const BufferNode* n = operator->();
  CHECK(dtype.element_of() == n->dtype.element_of() &&
        dtype.lanes() % n->dtype.lanes() == 0)
      << "Cannot load " << dtype
      << " from buffer of " << n->dtype;
  if (!(dtype == Bool())) {
    Expr index = BufferOffset(n, begin, dtype);
    return ir::Load::make(dtype, n->data, index, const_true(dtype.lanes()));
  }
  // specially handle bool, stored as Int(8): load an Int(8) and cast it back
  Expr index = BufferOffset(n, begin, Int(8));
  Expr raw = ir::Load::make(Int(8), n->data, index, const_true());
  return ir::Cast::make(Bool(), raw);
}

/*!
 * \brief Build a statement that stores a value into the buffer.
 * \param begin The per-dimension begin index of the store.
 * \param value The value to store; its element type must equal the buffer's
 *        and its lane count must be a multiple of the buffer's lane count.
 * \return The store statement.
 */
Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
  const BufferNode* n = operator->();
  Type dtype = value.type();
  CHECK(dtype.element_of() == n->dtype.element_of() &&
        dtype.lanes() % n->dtype.lanes() == 0)
      << "Cannot store " << dtype
      << " to buffer of " << n->dtype;
  // specially handle bool, stored as Int(8)
  if (dtype == Bool()) {
    return ir::Store::make(n->data,
                           ir::Cast::make(Int(8), value),
                           BufferOffset(n, begin, Int(8)),
                           const_true());
  } else {
    return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
                           const_true(dtype.lanes()));
  }
}

/*!
 * \brief Return a view of this buffer that carries explicit strides.
 *
 *  A buffer that already has strides, or whose shape is empty, is returned
 *  unchanged. Otherwise the node is copied and its strides are filled in so
 *  that the last dimension has stride 1 and every other dimension has the
 *  stride of the next one multiplied by the extent of the next one.
 *
 * \return The strided buffer.
 */
Buffer Buffer::MakeStrideView() const {
  if ((*this)->strides.size() != 0 || (*this)->shape.size() == 0) {
    return *this;
  }
  auto n = make_node<BufferNode>(*operator->());
  const size_t ndim = n->shape.size();
  std::vector<Expr> computed(ndim);
  Expr running = make_const(n->DefaultIndexType(), 1);
  // Walk from the innermost dimension outwards.
  for (size_t dim = ndim; dim != 0; --dim) {
    computed[dim - 1] = running;
    running = running * n->shape[dim - 1];
  }
  for (const Expr& stride : computed) {
    n->strides.push_back(stride);
  }
  return Buffer(n);
}

/*!
 * \brief Create a buffer that describes a slice of this buffer.
 * \param begins The per-dimension begin index of the slice.
 * \param extents The per-dimension extent of the slice.
 * \return A buffer named "<name>_slice" that shares the data variable, dtype,
 *         scope and data_alignment of this buffer, has `extents` as its shape
 *         and the offset of `begins` as its elem_offset.
 */
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
  const BufferNode* n = operator->();
  begins = SimplifyArray(begins);
  Expr elem_offset = ir::Simplify(ElemOffset(n, begins));
  Array<Expr> strides = n->strides;
  if (strides.size() == 0) {
    // Without strides the slice can only be described directly when, after
    // the first dimension with a non-unit extent, every dimension is covered
    // completely (begin 0, extent equal to the shape).
    bool only_unit_extents = true;
    bool need_stride = false;
    const size_t ndim = extents.size();
    for (size_t dim = 0; dim < ndim; ++dim) {
      if (!only_unit_extents) {
        const bool covers_whole_dim =
            is_zero(begins[dim]) &&
            is_zero(ir::Simplify(extents[dim] - n->shape[dim]));
        if (!covers_whole_dim) {
          need_stride = true;
        }
      }
      if (!is_one(extents[dim])) {
        only_unit_extents = false;
      }
    }
    if (need_stride) {
      // Slice the strided view of this buffer instead.
      return MakeStrideView().MakeSlice(begins, extents);
    }
  }
  const std::string slice_name = n->name + "_slice";
  return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset,
                          slice_name, n->scope, n->data_alignment, 0);
}

/*!
 * \brief Build a tvm_access_ptr intrinsic call for this buffer.
 * \param access_mask The access mask, passed through as an Int(32) constant.
 * \param ptr_type The type of the resulting call expression.
 * \param content_lanes The number of lanes of the content type; when greater
 *        than 1 the element offset and the extent are divided by it.
 * \param offset Offset added to the buffer's elem_offset and subtracted from
 *        the extent.
 * \return The call with arguments
 *         (type annotation, data, elem_offset, extent, access_mask).
 */
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
  const BufferNode* self = operator->();
  Expr e_dtype;
  Expr extent;
  if (self->shape.size() == 0) {
    extent = make_const(self->DefaultIndexType(), 1);
  } else if (self->strides.size() == self->shape.size()) {
    int highest_dim = 0;
    extent = arith::ComputeExpr<ir::Mul>(
        self->strides[highest_dim], self->shape[highest_dim]) - offset;
  } else {
    extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
  }
  Expr elem_offset = self->elem_offset + offset;
  if (content_lanes > 1) {
    const Expr lanes = make_const(self->elem_offset.type(), content_lanes);
    e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
    extent = extent / lanes;
    // Rescale the offset that already includes `offset`; recomputing it from
    // self->elem_offset alone would silently discard the caller's offset.
    elem_offset = elem_offset / lanes;
  } else {
    e_dtype = ir::TypeAnnotation(self->dtype);
  }
  Array<Expr> acc_args{
    e_dtype, self->data, elem_offset,
        extent, make_const(Int(32), access_mask)};
  return ir::Call::make(
      ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
}

/*!
 * \brief Construct a Buffer from its fields, filling in defaults.
 * \param data The data variable of the buffer.
 * \param dtype The element type.
 * \param shape The shape of the buffer.
 * \param strides The strides; may be empty.
 * \param elem_offset The element offset; an undefined value becomes the
 *        constant 0 of the buffer's default index type.
 * \param name The buffer name.
 * \param scope The storage scope; an empty string becomes "global".
 * \param data_alignment The alignment; a non-positive value becomes
 *        runtime::kAllocAlignment.
 * \param offset_factor The offset factor; 0 becomes 1.
 * \return The constructed buffer.
 */
Buffer BufferNode::make(Var data,
                        Type dtype,
                        Array<Expr> shape,
                        Array<Expr> strides,
                        Expr elem_offset,
                        std::string name,
                        std::string scope,
                        int data_alignment,
                        int offset_factor) {
  auto n = make_node<BufferNode>();
  n->data = std::move(data);
  n->dtype = dtype;
  n->shape = std::move(shape);
  n->strides = std::move(strides);
  n->name = std::move(name);
  if (scope.empty()) {
    scope = "global";
  }
  n->scope = std::move(scope);
  // DefaultIndexType() is queried only after the shape has been stored above.
  n->elem_offset = elem_offset.defined()
      ? std::move(elem_offset)
      : make_const(n->DefaultIndexType(), 0);
  n->data_alignment = (data_alignment > 0) ? data_alignment : runtime::kAllocAlignment;
  n->offset_factor = (offset_factor != 0) ? offset_factor : 1;
  return Buffer(n);
}

// IRPrinter dispatch for BufferNode: prints "buffer(<name>, <node pointer>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const BufferNode *node, IRPrinter *printer) {
    auto& os = printer->stream;
    os << "buffer(" << node->name << ", " << node << ")";
});

// Register BufferNode as a node type.
TVM_REGISTER_NODE_TYPE(BufferNode);

}  // namespace tvm