/*!
 *  Copyright (c) 2017 by Contributors
 * \file int_set.cc
 * \brief The integer set functions
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <arithmetic/Interval.h>
#include <memory>
#include <unordered_map>
#include <utility>
#include "./compute_expr.h"
#include "./int_set_internal.h"

namespace tvm {
namespace arith {

using Halide::Internal::Interval;
using namespace ir;

// Over-approximate this set by a single IntervalSet.
// IntervalSets are returned as-is. A StrideSet is widened to
// [base.min, base.max + sum_k (extent_k - 1) * stride_k].
// Any other set kind is a fatal error.
inline IntSet IntSet::cover_interval() const {
  if (this->as<IntervalSet>() != nullptr) return *this;
  const StrideSet* s = this->as<StrideSet>();
  if (s == nullptr) {
    LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval";
    return IntSet::everything();
  }
  CHECK_NE(s->extents.size(), 0U);
  Expr upper = s->base.max;
  for (size_t dim = 0; dim != s->extents.size(); ++dim) {
    // Contribution of one stride dimension: extent * stride - stride.
    upper = upper + s->extents[dim] * s->strides[dim] - s->strides[dim];
  }
  return IntervalSet::make(s->base.min, upper);
}

// Convert this set to a Range [min, max] covering it.
// Falls back to max_range when the covering interval is not bounded on both
// sides.
Range IntSet::cover_range(Range max_range) const {
  IntSet holder;
  const IntervalSet* interval = this->as<IntervalSet>();
  if (!interval) {
    // Keep the converted set alive in `holder` while `interval` points into it.
    holder = this->cover_interval();
    interval = holder.as<IntervalSet>();
  }
  if (!interval->i.is_bounded()) {
    return max_range;
  }
  const Interval& iv = interval->i;
  return Range::make_by_min_extent(iv.min, Simplify(iv.max + 1 - iv.min));
}

// Lower bound of the set. CHECK-fails unless this is an IntervalSet.
Expr IntSet::min() const {
  const IntervalSet* s_int = this->as<IntervalSet>();
  CHECK(s_int);
  const Interval& iv = s_int->i;
  return iv.min;
}

// Upper bound of the set. CHECK-fails unless this is an IntervalSet.
Expr IntSet::max() const {
  const IntervalSet* s_int = this->as<IntervalSet>();
  CHECK(s_int);
  const Interval& iv = s_int->i;
  return iv.max;
}

bool IntSet::is_nothing() const {
  const IntervalSet* s_int = (*this).as<IntervalSet>();
  return (s_int && s_int->i.is_empty());
}

66 67 68 69
bool IntSet::is_everything() const {
  const IntervalSet* s_int = (*this).as<IntervalSet>();
  return (s_int && s_int->i.is_everything());
}
70

71 72 73
bool IntSet::is_single_point() const {
  const IntervalSet* s_int = (*this).as<IntervalSet>();
  return (s_int && s_int->i.is_single_point());
74 75
}

76 77 78 79 80
bool IntSet::can_prove_positive() const {
  const IntervalSet* s_int = (*this).as<IntervalSet>();
  return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
bool IntSet::can_prove_negative() const {
  const IntervalSet* s_int = (*this).as<IntervalSet>();
  return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
}

// Classify the sign of every element of the set.
// Returns kUnknown when no single sign can be proven.
SignType IntSet::sign_type() const {
  if (can_prove_positive()) return kPositive;
  if (can_prove_negative()) return kNegative;
  if (is_single_point() && is_zero(point_value())) return kZero;
  return kUnknown;
}

// The only element of a single-point IntervalSet. CHECK-fails otherwise.
Expr IntSet::point_value() const {
  const IntervalSet* s_int = this->as<IntervalSet>();
  CHECK(s_int && s_int->i.is_single_point());
  const Interval& iv = s_int->i;
  return iv.min;
}

// The empty set, stored as an IntervalSet wrapping Interval::nothing().
IntSet IntSet::nothing() {
  Interval empty = Interval::nothing();
  return IntervalSet::make(empty);
}

// The universal set, stored as an IntervalSet wrapping Interval::everything().
IntSet IntSet::everything() {
  Interval all = Interval::everything();
  return IntervalSet::make(all);
}

// The set containing only the value x.
IntSet IntSet::single_point(Expr x) {
  Interval point = Interval::single_point(x);
  return IntervalSet::make(point);
}

// The set [r->min, r->min + r->extent - 1] described by a Range.
// The exact expression shapes built here must stay in sync with
// IntSet::match_range, which recognises sets created by this function.
IntSet IntSet::range(Range r) {
  if (is_one(r->extent)) return IntSet::single_point(r->min);
  const bool constant_range = is_positive_const(r->extent) && is_const(r->min);
  if (!constant_range) {
    return IntervalSet::make(r->min, (r->extent + r->min) - 1);
  }
  // Constant endpoints: fold the upper bound eagerly.
  Expr upper = ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1);
  return IntervalSet::make(r->min, upper);
}

// The closed interval [min, max].
// Identical endpoint nodes collapse to a single point.
IntSet IntSet::interval(Expr min, Expr max) {
  if (!min.same_as(max)) {
    return IntervalSet::make(min, max);
  }
  return IntSet::single_point(min);
}

// Check if a is created from b.
135 136
bool IntSet::match_range(const Range& b) const {
  const IntSet& a = *this;
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
  const IntervalSet* a_int = a.as<IntervalSet>();
  if (!a_int) return false;
  const Interval& i = a_int->i;
  if (!i.min.same_as(b)) return false;
  if (is_one(b->extent)) return i.is_single_point();
  if (is_positive_const(b->extent) && is_const(b->min)) {
    // deep equality
    return Equal(
        ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1),
        a_int->i.max);
  }
  const Sub* sub = i.max.as<Sub>();
  if (!sub) return false;
  if (is_one(sub->b)) return false;
  const Add* add = sub->a.as<Add>();
  return add &&
      add->a.same_as(b->min) &&
      add->b.same_as(b->extent);
155 156
}

157 158 159 160 161 162
inline bool MatchPoint(const IntSet& a,
                       const Expr& b) {
  const IntervalSet* a_int = a.as<IntervalSet>();
  if (!a_int) return false;
  const Interval& i = a_int->i;
  return i.is_single_point() && i.min.same_as(b);
163 164
}

165
// Union of sets, over-approximated by a single covering interval whose
// bounds are simplified.
// An empty array yields the empty set; a single set is returned unchanged.
IntSet Union(const Array<IntSet>& sets) {
  if (sets.size() == 0) return IntSet::nothing();
  if (sets.size() == 1) return sets[0];
  Interval acc = sets[0].cover_interval().as<IntervalSet>()->i;
  for (size_t k = 1; k != sets.size(); ++k) {
    IntSet covered = sets[k].cover_interval();
    acc.include(covered.as<IntervalSet>()->i);
  }
  acc.max = ir::Simplify(acc.max);
  acc.min = ir::Simplify(acc.min);
  return IntervalSet::make(acc);
}

// Intersection of sets, each over-approximated by its covering interval.
// An empty array yields everything, the neutral element of intersection.
// (Previously sets[0] was read unconditionally, which is out of bounds for
// an empty array.)
IntSet Intersect(const Array<IntSet>& sets) {
  if (sets.size() == 0) return IntSet::everything();
  Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
  for (size_t i = 1; i < sets.size(); ++i) {
    Interval y = sets[i].cover_interval().as<IntervalSet>()->i;
    x = Interval::make_intersection(x, y);
  }
  return IntervalSet::make(x);
}

// type traits
// is_logical_op<OP>::value is true for IR node types specialized through
// TVM_DECLARE_LOGICAL_OP, and false for every other OP.
template<typename OP>
struct is_logical_op {
  static constexpr bool value = false;
};

// Specialize is_logical_op for ir::OP, marking it as a logical operation.
#define TVM_DECLARE_LOGICAL_OP(OP)              \
  template<>                                    \
  struct is_logical_op<ir::OP> {                \
    static constexpr bool value = true;         \
  };

// interval related.
// Generic interval combination under binary operator OP.
// Only point-by-point combination is computed exactly; anything else is
// conservatively widened to everything (with a warning).
template<typename OP>
inline IntSet CombineInterval(Interval a, Interval b) {
  const bool both_points = a.is_single_point() && b.is_single_point();
  if (!both_points) {
    LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key;
    return IntSet::everything();
  }
  return IntSet::single_point(ComputeExpr<OP>(a.min, b.min));
}

// [a.min, a.max] + [b.min, b.max] = [a.min + b.min, a.max + b.max].
// A side is bounded only when both operands are bounded on that side.
template<>
inline IntSet CombineInterval<Add>(Interval a, Interval b) {
  if (a.is_single_point() && b.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Add>(a.min, b.min));
  }
  const bool lower_known = a.has_lower_bound() && b.has_lower_bound();
  const bool upper_known = a.has_upper_bound() && b.has_upper_bound();
  Interval result = Interval::everything();
  if (lower_known) result.min = ComputeExpr<Add>(a.min, b.min);
  if (upper_known) result.max = ComputeExpr<Add>(a.max, b.max);
  return IntervalSet::make(result);
}

template<>
226 227 228 229 230 231 232 233 234 235 236 237
inline IntSet CombineInterval<Sub>(Interval a, Interval b) {
  if (a.is_single_point() && b.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min));
  }
  Interval r = Interval::everything();
  if (a.has_lower_bound() && b.has_upper_bound()) {
    r.min = ComputeExpr<Sub>(a.min, b.max);
  }
  if (a.has_upper_bound() && b.has_lower_bound()) {
    r.max = ComputeExpr<Sub>(a.max, b.min);
  }
  return IntervalSet::make(r);
238 239
}

240 241 242 243 244 245 246 247 248 249 250 251 252
// Interval multiplication. Exact for two points. When exactly one operand is
// a point, the other interval is scaled by it and the endpoints are ordered
// by the point's sign. Anything else is widened to everything.
template<>
inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
  if (a.is_single_point() && b.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min));
  }
  // Multiplication commutes: move the point operand (if any) into `b`.
  if (a.is_single_point() && !b.is_single_point()) {
    std::swap(a, b);
  }
  if (b.is_single_point()) {
    // Use a zero of the operand type; a bare literal 0 is always int32.
    if (is_zero(b.min)) return IntSet::single_point(make_zero(b.min.type()));
    if (is_one(b.min)) return IntervalSet::make(a);
    Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
    Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
    // no relaxation is needed in here due to set is inclusive
    // TODO(tqchen): consider convert to StrideSet.
    if (is_positive_const(b.min)) {
      return IntervalSet::make(e1, e2);
    } else if (is_negative_const(b.min)) {
      // A negative factor reverses the endpoints. An unbounded side must
      // move to the opposite end too: make(e2, e1) would use a's infinite
      // lower end (e1 == a.min) as the upper bound and yield an empty set.
      Interval r = Interval::everything();
      if (a.has_upper_bound()) r.min = e2;
      if (a.has_lower_bound()) r.max = e1;
      return IntervalSet::make(r);
    } else if (a.is_bounded()) {
      // Sign of the factor is not a known constant: select at runtime.
      Expr cmp = b.min >= make_zero(b.min.type().element_of());
      return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
    }
  }
  LOG(WARNING) << "Return Everything in CombineInterval Mul";
  return IntSet::everything();
}

template<>
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
inline IntSet CombineInterval<Div>(Interval a, Interval b) {
  if (a.is_single_point() && b.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
  }
  if (b.is_single_point()) {
    if (is_zero(b.min)) {
      LOG(FATAL) << "Divide by zero in CombineInterval Div";
    }
    if (is_one(b.min)) return IntervalSet::make(a);
    Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
    Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
    // no relaxation is needed in here due to set is inclusive
    if (is_positive_const(b.min)) {
      return IntervalSet::make(e1, e2);
    } else if (is_negative_const(b.min)) {
      return IntervalSet::make(e2, e1);
    } else if (a.is_bounded()) {
      Expr cmp = b.min >= make_zero(b.min.type().element_of());
      return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
    }
  }
  LOG(WARNING) << "Return Everything in CombineInterval Div";
  return IntSet::everything();
}

template<>
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
  if (a.is_single_point() && b.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
  }
  if (b.is_single_point()) {
    Expr divisor = b.min;
    if (is_zero(divisor)) {
      LOG(FATAL) << "Modular by zero in CombineInterval Mod";
    }
    return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
  }

  LOG(WARNING) << "Return Everything in CombineInterval Mod";
  return IntSet::everything();
}

template<>
312 313 314
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
  if (a.is_single_point() && b.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
315
  }
316 317
  return IntervalSet::make(Interval::make_max(a.min, b.min),
                           Interval::make_max(a.max, b.max));
318 319
}

320 321 322 323 324 325 326
// min over intervals: the lower bound is the min of the lower bounds and
// the upper bound is the min of the upper bounds.
template<>
inline IntSet CombineInterval<Min>(Interval a, Interval b) {
  if (!(a.is_single_point() && b.is_single_point())) {
    Expr lower = Interval::make_min(a.min, b.min);
    Expr upper = Interval::make_min(a.max, b.max);
    return IntervalSet::make(lower, upper);
  }
  return IntSet::single_point(ComputeExpr<Min>(a.min, b.min));
}

// Adapter: unwrap two IntervalSet handles and forward them to
// CombineInterval<OP>. Both arguments must hold an IntervalSet.
template<typename OP>
inline IntSet CombineInterval_(IntSet a, IntSet b) {
  const Interval& lhs = a.as<IntervalSet>()->i;
  const Interval& rhs = b.as<IntervalSet>()->i;
  return CombineInterval<OP>(lhs, rhs);
}

// stride related
// Promote a bounded IntervalSet to a StrideSet whose base is that interval.
// The other StrideSet fields are left default-constructed. StrideSets are
// returned unchanged.
inline IntSet AsStrideSet(IntSet a) {
  if (a.as<StrideSet>() != nullptr) return a;
  const IntervalSet* s = a.as<IntervalSet>();
  CHECK(s->i.is_bounded());
  auto node = std::make_shared<StrideSet>();
  node->base = s->i;
  return IntSet(node);
}
// Generic set combination: cover both operands by intervals, then fall back
// to interval arithmetic.
template<typename OP>
inline IntSet CombineSets(IntSet a, IntSet b) {
  IntSet lhs = a.cover_interval();
  IntSet rhs = b.cover_interval();
  return CombineInterval_<OP>(lhs, rhs);
}

// Addition over possibly-strided sets. Adding the single point 0 is the
// identity. Otherwise both operands are promoted to StrideSets, their stride
// dimensions are concatenated, and their base intervals are added.
template<>
inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
  const IntervalSet* a_int = a.as<IntervalSet>();
  const IntervalSet* b_int = b.as<IntervalSet>();
  // Only the single point {0} is an additive identity. An interval that
  // merely starts at 0 (e.g. [0, n - 1]) must still be added in.
  if (a_int && a_int->i.is_single_point() && is_zero(a_int->i.min)) return b;
  if (b_int && b_int->i.is_single_point() && is_zero(b_int->i.min)) return a;
  a = AsStrideSet(a);
  b = AsStrideSet(b);
  const StrideSet* a_stride = a.as<StrideSet>();
  const StrideSet* b_stride = b.as<StrideSet>();
  auto n = std::make_shared<StrideSet>(*a_stride);
  for (size_t i = 0; i < b_stride->extents.size(); ++i) {
    n->extents.push_back(b_stride->extents[i]);
    n->strides.push_back(b_stride->strides[i]);
  }
  n->base = CombineInterval<Add>(
      a_stride->base, b_stride->base).as<IntervalSet>()->i;
  return IntSet(n);
}

// Negate a set: -[lo, hi] = [-hi, -lo]. Unbounded sides stay unbounded.
// Non-interval sets are first covered by an interval.
inline IntSet NegateSet(IntSet a) {
  const IntervalSet* a_int = a.as<IntervalSet>();
  if (!a_int) {
    return NegateSet(a.cover_interval());
  }
  const Interval& src = a_int->i;
  if (src.is_single_point()) {
    return IntSet::single_point(-src.min);
  }
  Interval negated = Interval::everything();
  if (src.has_upper_bound()) negated.min = -(src.max);
  if (src.has_lower_bound()) negated.max = -(src.min);
  return IntervalSet::make(negated);
}

// a - b is computed as a + (-b).
template<>
inline IntSet CombineSets<Sub>(IntSet a, IntSet b) {
  IntSet negated_b = NegateSet(b);
  return CombineSets<Add>(a, negated_b);
}

// Comparison and boolean node types are logical ops. For these, Combine
// returns the interval [0, 1].
TVM_DECLARE_LOGICAL_OP(EQ);
TVM_DECLARE_LOGICAL_OP(NE);
TVM_DECLARE_LOGICAL_OP(LT);
TVM_DECLARE_LOGICAL_OP(LE);
TVM_DECLARE_LOGICAL_OP(GT);
TVM_DECLARE_LOGICAL_OP(GE);
TVM_DECLARE_LOGICAL_OP(And);
TVM_DECLARE_LOGICAL_OP(Or);
TVM_DECLARE_LOGICAL_OP(Not);

// generic combine operations of two sets
// Over-approximate { x OP y : x in a, y in b }.
template<typename OP>
inline IntSet Combine(const IntSet& a, const IntSet &b) {
  // Logical ops produce the interval [0, 1].
  if (is_logical_op<OP>::value) return IntervalSet::make(0, 1);
  const IntervalSet* a_int = a.as<IntervalSet>();
  const IntervalSet* b_int = b.as<IntervalSet>();
  // An operand that is everything makes the result everything.
  if (a_int != nullptr && a_int->i.is_everything()) return a;
  if (b_int != nullptr && b_int->i.is_everything()) return b;
  if (a_int != nullptr) {
    if (b_int != nullptr) return CombineInterval<OP>(a_int->i, b_int->i);
    // Unbounded intervals cannot become StrideSets; use interval arithmetic.
    if (!a_int->i.is_bounded()) return CombineInterval_<OP>(a, b.cover_interval());
  }
  if (b_int != nullptr && !b_int->i.is_bounded()) {
    return CombineInterval_<OP>(a.cover_interval(), b);
  }
  return CombineSets<OP>(a, b);
}

// Evaluator to evalute the epxression.
class IntSetEvaluator {
428
 public:
429 430 431
  explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
      : dom_map(dom_map) {}

432
  inline virtual IntSet Eval(Expr expr) {
433 434 435 436 437
    static const FType& f = vtable();
    if (f.can_dispatch(expr)) {
      return f(expr, expr, this);
    } else {
      LOG(WARNING) << "cannot evaluate set type " << expr->type_key();
438
      return IntSet::nothing();
439 440 441
    }
  }

442
  using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IntSetEvaluator *)>;
443 444 445 446
  static FType& vtable() {  // NOLINT(*)
    static FType inst; return inst;
  }

447
  const std::unordered_map<const Variable*, IntSet>& dom_map;
448 449
};

450
// Handler for immediate constants: a constant evaluates to itself as a
// single point. The node and evaluator arguments are unused.
inline IntSet ConstOp(const NodeRef& /*node*/, const Expr& e, IntSetEvaluator* /*m*/) {
  return IntSet::single_point(e);
}

// Float, signed and unsigned integer immediates all map to single points.
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<FloatImm>(ConstOp)
.set_dispatch<IntImm>(ConstOp)
.set_dispatch<UIntImm>(ConstOp);

// A variable evaluates to its set in dom_map when it has one. Otherwise it
// evaluates to itself as a single point.
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IntSetEvaluator* m) {
    auto found = m->dom_map.find(op);
    if (found == m->dom_map.end()) {
      return IntSet::single_point(e);
    }
    return found->second;
  });

// binary operator
template<typename T>
471
inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) {
472 473
  IntSet a = m->Eval(op->a);
  IntSet b = m->Eval(op->b);
474 475
  if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
    return IntSet::single_point(e);
476
  }
477
  return Combine<T>(a, b);
478 479
}

480
// Every binary arithmetic, comparison and boolean node goes through Binary<T>.
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
// arithmetic
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<Min>(Binary<Min>)
// comparisons
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
// boolean connectives
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);

// Evaluate the set of values e may take, given per-variable sets in dom_map.
IntSet EvalSet(Expr e,
               const std::unordered_map<const Variable*, IntSet>& dom_map) {
  IntSetEvaluator evaluator(dom_map);
  return evaluator.Eval(e);
}

// Overload taking domains keyed by IterVar. Each entry is re-keyed by the
// IterVar's underlying Variable before evaluation.
IntSet EvalSet(Expr e,
               const Map<IterVar, IntSet>& dom_map) {
  std::unordered_map<const Variable*, IntSet> var_dom;
  for (const auto& entry : dom_map) {
    const Variable* v = entry.first->var.as<Variable>();
    var_dom[v] = entry.second;
  }
  return EvalSet(e, var_dom);
}

// Set of values covered by Range r once its min and extent are evaluated
// under dom_map: min_set + [0, max(extent) - 1].
// Returns everything when the extent has no upper bound.
IntSet EvalSet(Range r,
               const std::unordered_map<const Variable*, IntSet>& dom_map) {
  IntSetEvaluator evaluator(dom_map);
  IntSet min_set = evaluator.Eval(r->min);
  IntSet ext_set = evaluator.Eval(r->extent).cover_interval();
  const Interval& ext = ext_set.as<IntervalSet>()->i;
  if (!ext.has_upper_bound()) {
    return IntSet::everything();
  }
  IntSet offsets = IntervalSet::make(0, ComputeExpr<Sub>(ext.max, 1));
  return Combine<Add>(min_set, offsets);
}

// IntSetEvaluator that also records, in expr_map, the set computed for every
// expression passed to Eval (including operands evaluated through m->Eval).
class SubExprIntSetEvaluator : public IntSetEvaluator {
 public:
  explicit SubExprIntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
      : IntSetEvaluator(dom_map) {}

  inline IntSet Eval(Expr expr) override {
    IntSet result = IntSetEvaluator::Eval(expr);
    expr_map[expr] = result;
    return result;
  }

  // Expression -> evaluated set, filled in by Eval.
  ExprIntSetMap expr_map;
};

// Evaluate e under dom_map and return the set recorded for each evaluated
// sub-expression, including e itself.
ExprIntSetMap EvalSetForEachSubExpr(Expr e,
    const std::unordered_map<const Variable*, IntSet>& dom_map) {
  SubExprIntSetEvaluator evaluator(dom_map);
  evaluator.Eval(e);
  return evaluator.expr_map;
}

// Range overload taking domains keyed by IterVar. Each entry is re-keyed by
// the IterVar's underlying Variable.
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map) {
  std::unordered_map<const Variable*, IntSet> var_dom;
  for (const auto& entry : dom_map) {
    const Variable* v = entry.first->var.as<Variable>();
    var_dom[v] = entry.second;
  }
  return EvalSet(r, var_dom);
}

// Print an IntervalSet as "interval-set[min, max]".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
    const Interval& iv = op->i;
    p->stream << "interval-set" << "[" << iv.min << ", " << iv.max << ']';
  });

}  // namespace arith
}  // namespace tvm