/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file int_set.cc
 * \brief Integer set (interval) functions: interval arithmetic and
 *        symbolic bound evaluation of expressions.
 */
24
#include <tvm/arith/int_set.h>
25 26
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
27 28
#include <tvm/runtime/registry.h>

29 30
#include <utility>
#include <algorithm>
31
#include <unordered_map>
32
#include "interval_set.h"
33
#include "pattern_match.h"
34 35

namespace tvm {
36
namespace arith {
37

38 39 40 41 42
using tir::make_const;
using tir::make_zero;
using tir::is_zero;
using tir::is_one;

43 44
// Sentinels representing the unbounded ends of an interval.  They are
// handle-typed variables, so they are distinct from any integer bound.
PrimExpr SymbolicLimits::pos_inf_(Var("pos_inf", DataType::Handle()));
PrimExpr SymbolicLimits::neg_inf_(Var("neg_inf", DataType::Handle()));
45

46
/*!
 * \brief Construct the closed interval [min_value, max_value].
 * \param min_value Inclusive lower bound (may be the neg_inf sentinel).
 * \param max_value Inclusive upper bound (may be the pos_inf sentinel).
 */
IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) {
  auto n = make_object<IntervalSetNode>();
  n->max_value = std::move(max_value);
  n->min_value = std::move(min_value);
  data_ = std::move(n);
}

53
/*! \brief Factory registered as "arith.IntervalSet"; builds [min_value, max_value]. */
IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) {
  IntervalSet result(std::move(min_value), std::move(max_value));
  return result;
}

57
// Register MakeIntervalSet as the global packed function "arith.IntervalSet".
TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet);
59

60

61
/*!
 * \brief Intersect two intervals.
 * \return [max(a.min, b.min), min(a.max, b.max)].  Returns the empty set
 *         instead when both bounds are integer-typed and the analyzer can
 *         prove that the lower bound exceeds the upper bound.
 */
IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
  PrimExpr hi = min(a->max_value, b->max_value);
  PrimExpr lo = max(a->min_value, b->min_value);
  auto is_integral = [](const PrimExpr& e) {
    return e.dtype().is_int() || e.dtype().is_uint();
  };
  const bool provably_empty =
      is_integral(hi) && is_integral(lo) &&
      analyzer->CanProveGreaterEqual(lo - hi, 1);
  if (provably_empty) return IntervalSet::Empty();
  return IntervalSet(lo, hi);
}

73
/*!
 * \brief Smallest interval covering both inputs:
 *        [min(a.min, b.min), max(a.max, b.max)].
 * \note The analyzer argument is not used.
 */
IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
  PrimExpr hi = max(a->max_value, b->max_value);
  PrimExpr lo = min(a->min_value, b->min_value);
  return IntervalSet(lo, hi);
}

79 80 81 82 83 84 85 86
// Type trait: is_logical_op<OP>::value is true only for the TIR node types
// registered through TVM_DECLARE_LOGICAL_OP below, i.e. the comparison and
// boolean operators.  All other node types default to false.
template<typename OP>
struct is_logical_op {
  static constexpr bool value = false;
};

#define TVM_DECLARE_LOGICAL_OP(OP)                \
  template<>                                      \
  struct is_logical_op<tir::OP> {                 \
    static constexpr bool value = true;           \
  };

TVM_DECLARE_LOGICAL_OP(AndNode);
TVM_DECLARE_LOGICAL_OP(OrNode);
TVM_DECLARE_LOGICAL_OP(EQNode);
TVM_DECLARE_LOGICAL_OP(NENode);
TVM_DECLARE_LOGICAL_OP(GENode);
TVM_DECLARE_LOGICAL_OP(GTNode);
TVM_DECLARE_LOGICAL_OP(LENode);
TVM_DECLARE_LOGICAL_OP(LTNode);
TVM_DECLARE_LOGICAL_OP(NotNode);
100 101 102 103 104 105 106 107 108 109

/*!
 * \brief Combine two interval set under arithmetic operations.
 *
 * Generic fallback for operators without a dedicated specialization below:
 *  - two single points are constant-folded (or rebuilt as Op(a, b));
 *  - logical/comparison operators yield the boolean range [0, 1];
 *  - an empty or unbounded operand is propagated;
 *  - anything else is relaxed to Everything.
 *
 * \note this can possibly relax the set.
 */
template<typename Op>
inline IntervalSet Combine(Analyzer* analyzer,
                           IntervalSet a,
                           IntervalSet b) {
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
    if (!res.defined()) res = Op::make(a->min_value, b->min_value);
    return IntervalSet::SinglePoint(res);
  }
  if (is_logical_op<Op>::value) {
    // The result of a comparison/boolean operator is a boolean, so it lies in
    // [0, 1] whatever the operands are.  Build the constants with the boolean
    // dtype rather than the operand's dtype: for an unbounded or empty operand
    // a->min_value is a handle-typed infinity sentinel, which is not a valid
    // dtype for a numeric constant.
    return IntervalSet(make_const(DataType::Bool(), 0),
                       make_const(DataType::Bool(), 1));
  }
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
  if (a->IsEverything()) return a;
  if (b->IsEverything()) return b;
  return IntervalSet::Everything();
}

125
template<>
126
inline IntervalSet Combine<tir::AddNode>(Analyzer* analyer,
127 128
                                        IntervalSet a,
                                        IntervalSet b) {
129 130 131 132 133
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(a->min_value + b->min_value);
  }
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
134
  PrimExpr min_value =
135 136
      a->HasLowerBound() && b->HasLowerBound() ?
      a->min_value + b->min_value : neg_inf();
137
  PrimExpr max_value =
138 139 140
      a->HasUpperBound() && b->HasUpperBound() ?
      a->max_value + b->max_value : pos_inf();
  return IntervalSet(min_value, max_value);
141 142 143
}

template<>
144
inline IntervalSet Combine<tir::SubNode>(Analyzer* analyer,
145 146
                                        IntervalSet a,
                                        IntervalSet b) {
147 148
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(a->min_value - b->min_value);
149
  }
150 151
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
152
  PrimExpr min_value =
153 154
      a->HasLowerBound() && b->HasUpperBound() ?
      a->min_value - b->max_value : neg_inf();
155
  PrimExpr max_value =
156 157 158
      a->HasUpperBound() && b->HasLowerBound() ?
      a->max_value - b->min_value : pos_inf();
  return IntervalSet(min_value, max_value);
159 160
}

161

162
template<>
163
inline IntervalSet Combine<tir::MulNode>(Analyzer* analyzer,
164 165
                                        IntervalSet a,
                                        IntervalSet b) {
166 167 168 169 170 171
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(a->min_value * b->min_value);
  }
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
  if (a->IsSinglePoint()) {
172 173
    std::swap(a, b);
  }
174 175 176 177
  if (b->IsSinglePoint()) {
    if (is_zero(b->min_value)) return b;
    if (is_one(b->min_value)) return a;
    if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
178 179
      PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
      PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
180 181
      return IntervalSet(min_value, max_value);
    } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
182 183
      PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
      PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
184 185
      return IntervalSet(min_value, max_value);
    } else if (a->HasUpperBound() && a->HasLowerBound()) {
186
      using tir::SelectNode;
187 188 189
      PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
      PrimExpr e1 = a->min_value * b->min_value;
      PrimExpr e2 = a->max_value * b->min_value;
190
      return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
191 192
    }
  }
193 194
  DLOG(WARNING) << "Return Everything in CombineInterval Mul";
  return IntervalSet::Everything();
195 196 197
}

template<>
198
inline IntervalSet Combine<tir::DivNode>(Analyzer* analyzer,
199 200
                                        IntervalSet a,
                                        IntervalSet b) {
201 202 203 204 205 206 207
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(a->min_value / b->min_value);
  }
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
  if (b->IsSinglePoint()) {
    if (is_zero(b->min_value)) {
208 209
      LOG(FATAL) << "Divide by zero in CombineInterval Div";
    }
210
    if (is_one(b->min_value)) return a;
211
    // no relaxation is needed in here due to set is inclusive
212
    if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
213 214
      PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
      PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
215 216
      return IntervalSet(min_value, max_value);
    } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
217 218
      PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
      PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
219 220
      return IntervalSet(min_value, max_value);
    } else if (a->HasUpperBound() && a->HasLowerBound()) {
221
      using tir::SelectNode;
222 223 224
      PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
      PrimExpr e1 = a->min_value / b->min_value;
      PrimExpr e2 = a->max_value / b->min_value;
225
      return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
226 227
    }
  }
228 229
  DLOG(WARNING) << "Return Everything in CombineInterval Div";
  return IntervalSet::Everything();
230 231 232
}

template<>
233
inline IntervalSet Combine<tir::ModNode>(Analyzer* analyzer,
234 235
                                        IntervalSet a,
                                        IntervalSet b) {
236
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
237
    return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
238
  }
239 240 241 242
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;

  if (b->IsSinglePoint()) {
243
    const PrimExpr& divisor = b->min_value;
244 245 246
    if (is_zero(divisor)) {
      LOG(FATAL) << "Modular by zero in CombineInterval Mod";
    }
247 248 249 250 251
    // We need to add more bound constraints throughout the code.
    // The logic below assumes a is non-negative, which usually
    // is the case of our application.
    // TODO(tqchen): add bound constraints for a.
    if (analyzer->CanProveGreaterEqual(divisor, 0)) {
252
      return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
253
    } else {
254
      PrimExpr bound = abs(divisor) - 1;
255 256
      return IntervalSet(-bound, bound);
    }
257
  }
258 259
  DLOG(WARNING) << "Return Everything in CombineInterval Mod";
  return IntervalSet::Everything();
260 261
}

262 263

template<>
264
inline IntervalSet Combine<tir::FloorDivNode>(Analyzer* analyzer,
265 266
                                             IntervalSet a,
                                             IntervalSet b) {
267 268 269 270 271 272 273 274 275 276 277 278
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
  }
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
  if (b->IsSinglePoint()) {
    if (is_zero(b->min_value)) {
      LOG(FATAL) << "Divide by zero in CombineInterval Div";
    }
    if (is_one(b->min_value)) return a;
    // no relaxation is needed in here due to set is inclusive
    if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
279 280
      PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
      PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
281 282
      return IntervalSet(min_value, max_value);
    } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
283 284
      PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
      PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
285 286
      return IntervalSet(min_value, max_value);
    } else if (a->HasUpperBound() && a->HasLowerBound()) {
287
      using tir::SelectNode;
288 289 290
      PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
      PrimExpr e1 = floordiv(a->min_value, b->min_value);
      PrimExpr e2 = floordiv(a->max_value, b->min_value);
291
      return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
292 293 294 295 296 297 298
    }
  }
  DLOG(WARNING) << "Return Everything in CombineInterval Div";
  return IntervalSet::Everything();
}

template<>
299
inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
300 301
                                             IntervalSet a,
                                             IntervalSet b) {
302 303 304 305 306 307 308
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
  }
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;

  if (b->IsSinglePoint()) {
309
    const PrimExpr& divisor = b->min_value;
310 311 312 313
    if (is_zero(divisor)) {
      LOG(FATAL) << "Modular by zero in CombineInterval Mod";
    }
    if (analyzer->CanProveGreaterEqual(divisor, 0)) {
314
      return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
315
    } else {
316
      PrimExpr bound = abs(divisor) - 1;
317 318 319 320 321 322 323
      return IntervalSet(-bound, bound);
    }
  }
  DLOG(WARNING) << "Return Everything in CombineInterval Mod";
  return IntervalSet::Everything();
}

324
template<>
325
inline IntervalSet Combine<tir::MaxNode>(Analyzer* analzyer,
326 327
                                        IntervalSet a,
                                        IntervalSet b) {
328 329
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(max(a->min_value,  b->min_value));
330
  }
331 332 333 334
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
  return IntervalSet(max(a->min_value, b->min_value),
                     max(a->max_value, b->max_value));
335 336
}

337
template<>
338
inline IntervalSet Combine<tir::MinNode>(Analyzer* analzyer,
339 340
                                        IntervalSet a,
                                        IntervalSet b) {
341 342
  if (a->IsSinglePoint() && b->IsSinglePoint()) {
    return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
343
  }
344 345 346 347
  if (a->IsEmpty()) return a;
  if (b->IsEmpty()) return b;
  return IntervalSet(min(a->min_value, b->min_value),
                     min(a->max_value, b->max_value));
348
}
349

350 351 352 353
// Internal helper: view an IntSet as an IntervalSet.  A set whose node is
// not an IntervalSetNode is relaxed to Everything.
IntervalSet ToIntervalSet(IntSet set) {
  const auto* node = set.as<IntervalSetNode>();
  if (node == nullptr) {
    DLOG(INFO) << "cannot resolve int set " << set;
    return IntervalSet::Everything();
  }
  return GetRef<IntervalSet>(node);
}

359
using namespace tir;
360

361 362 363
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
class IntervalSetEvaluator :
364
      public ExprFunctor<IntervalSet(const PrimExpr&)> {
365 366 367 368 369 370 371
 public:
  IntervalSetEvaluator(Analyzer* analyzer,
                       const Map<Var, IntSet>& dom_map,
                       bool eval_vec = false)
      : analyzer_(analyzer),
        dom_map_(dom_map),
        eval_vec_(eval_vec) {
372
  }
373

374
  IntervalSet Eval(const PrimExpr& val) {
375
    return this->VisitExpr(val);
376
  }
377 378 379 380 381 382 383 384 385 386
  // evaluate and relax the set
  IntervalSet Eval(IntervalSet val) {
    // avoid recursive indefinite recursive expansion.
    if (static_cast<size_t>(recur_depth_) >= dom_map_.size()) return val;
    ++recur_depth_;
    IntervalSet min_set = this->Eval(val->min_value);
    IntervalSet max_set = this->Eval(val->max_value);
    --recur_depth_;
    return IntervalSet(min_set->min_value, max_set->max_value);
  }
387

388
  IntervalSet VisitExpr_(const IntImmNode* op) final {
389
    return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
390
  }
391

392
  IntervalSet VisitExpr_(const VarNode* op) final {
393 394
    Var var = GetRef<Var>(op);
    auto it = dom_map_.find(var);
395
    if (it != dom_map_.end()) {
396 397 398 399 400 401 402 403
      IntervalSet res = ToIntervalSet((*it).second);
      if (res->min_value.same_as(var) &&
          res->max_value.same_as(var)) {
        return res;
      }
      // recursively evaluate mapped result
      // in case the domain contains variables to be relaxed.
      return Eval(res);
404
    } else {
405
      return IntervalSet::SinglePoint(var);
406
    }
407
  }
408

409

410
  IntervalSet VisitExpr_(const AddNode* op) final {
411
    return VisitBinaryExpr_(op);
412
  }
413

414
  IntervalSet VisitExpr_(const SubNode* op) final {
415
    return VisitBinaryExpr_(op);
416
  }
417

418
  IntervalSet VisitExpr_(const MulNode* op) final {
419
    return VisitBinaryExpr_(op);
420
  }
421

422
  IntervalSet VisitExpr_(const DivNode* op) final {
423
    return VisitBinaryExpr_(op);
424
  }
425

426
  IntervalSet VisitExpr_(const ModNode* op) final {
427
    return VisitBinaryExpr_(op);
428
  }
429

430
  IntervalSet VisitExpr_(const FloorDivNode* op) final {
431 432 433
    return VisitBinaryExpr_(op);
  }

434
  IntervalSet VisitExpr_(const FloorModNode* op) final {
435 436 437
    return VisitBinaryExpr_(op);
  }

438
  IntervalSet VisitExpr_(const MinNode* op) final {
439
    return VisitBinaryExpr_(op);
440
  }
441

442
  IntervalSet VisitExpr_(const MaxNode* op) final {
443
    return VisitBinaryExpr_(op);
444
  }
445

446
  IntervalSet VisitExpr_(const EQNode* op) final {
447
    return VisitBinaryExpr_(op);
448
  }
449

450
  IntervalSet VisitExpr_(const NENode* op) final {
451
    return VisitBinaryExpr_(op);
452
  }
453

454
  IntervalSet VisitExpr_(const LTNode* op) final {
455
    return VisitBinaryExpr_(op);
456
  }
457

458
  IntervalSet VisitExpr_(const LENode* op) final {
459
    return VisitBinaryExpr_(op);
460
  }
461

462
  IntervalSet VisitExpr_(const GTNode* op) final {
463
    return VisitBinaryExpr_(op);
464
  }
465

466
  IntervalSet VisitExpr_(const GENode* op) final {
467
    return VisitBinaryExpr_(op);
468
  }
469

470
  IntervalSet VisitExpr_(const AndNode* op) final {
471
    return VisitBinaryExpr_(op);
472
  }
473

474
  IntervalSet VisitExpr_(const OrNode* op) final {
475
    return VisitBinaryExpr_(op);
476
  }
477

478
  IntervalSet VisitExpr_(const RampNode* op) final {
479
    CHECK(eval_vec_);
480
    IntervalSet base = Eval(op->base);
481
    PVar<IntImm> stride;
482
    if (stride.Match(op->stride)) {
483
      DataType t = op->base.dtype();
484 485
      int64_t vstride = stride.Eval()->value;
      if (vstride> 0) {
486
        return Combine<AddNode>(
487
            analyzer_,
488
            base,
489
            IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
490
      } else {
491
        return Combine<AddNode>(
492
            analyzer_,
493
            base,
494
            IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
495 496
      }
    }
497
    DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op);
498
    return IntervalSet::Everything();
499
  }
500

501
  IntervalSet VisitExpr_(const BroadcastNode* op) final {
502
    CHECK(eval_vec_);
503
    return VisitExpr(op->value);
504
  }
505

506
  IntervalSet VisitExpr_(const SelectNode* op) final {
507 508 509
    IntervalSet true_set = this->Eval(op->true_value);
    IntervalSet false_set = this->Eval(op->false_value);
    return Union(analyzer_, false_set, true_set);
510
  }
511

512
  IntervalSet VisitExprDefault_(const Object* op) final {
513
    DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
514
    return IntervalSet::Everything();
515
  }
516

517
 private:
518 519
  // whether set is exactly single point that equals value.
  bool MatchPoint(const IntervalSet& set,
520
                  const PrimExpr& value) const {
521 522 523
    return set->min_value.same_as(value) && set->max_value.same_as(value);
  }

524
  template<typename T>
525 526 527
  inline IntervalSet VisitBinaryExpr_(const T* op) {
    IntervalSet a = this->Eval(op->a);
    IntervalSet b = this->Eval(op->b);
528
    if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
529
      return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
530
    }
531
    return Combine<T>(analyzer_, a, b);
532
  }
533

534 535 536
  // recursive depth
  int recur_depth_{0};
  // analyzer
537 538
  Analyzer* analyzer_;
  const Map<Var, IntSet>& dom_map_;
539 540
  bool eval_vec_{false};
};
541

542 543 544 545 546 547
// Private implementation of IntSetAnalyzer: each query runs a fresh
// IntervalSetEvaluator against the given domain map.
class IntSetAnalyzer::Impl {
 public:
  explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {}

  IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
    IntervalSetEvaluator evaluator(analyzer_, dom_map);
    return evaluator.Eval(expr);
  }

 private:
  // Parent analyzer; borrowed, never deleted here.
  Analyzer* analyzer_;
};

IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent)
    : impl_(new Impl(parent)) {
}

IntSetAnalyzer::~IntSetAnalyzer() {
  delete impl_;
}

564
IntSet IntSetAnalyzer::operator()(const PrimExpr& expr,
565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581
                                  const Map<Var, IntSet>& dom_map) {
  return impl_->Eval(expr, dom_map);
}

// Quickly adapt to IntSet interface
// TODO(tqchen): revisit IntSet interface as well.
/*!
 * \brief Express this set as a Range when both of its bounds are finite.
 * \param max_range Returned unchanged when either bound is missing.
 * \return Range starting at min with extent Simplify(max + 1 - min), or max_range.
 * \note Fails a CHECK if this set is not an IntervalSet.
 */
Range IntSet::cover_range(Range max_range) const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  CHECK(s_int != nullptr);
  if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
    return Range::make_by_min_extent(
        s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value));
  }
  return max_range;
}

582
PrimExpr IntSet::min() const {
583 584 585 586 587
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  CHECK(s_int);
  return s_int->min_value;
}

588
PrimExpr IntSet::max() const {
589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  CHECK(s_int);
  return s_int->max_value;
}

bool IntSet::is_nothing() const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  return (s_int && s_int->IsEmpty());
}

bool IntSet::is_everything() const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  return (s_int && s_int->IsEverything());
}

bool IntSet::is_single_point() const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  return (s_int && s_int->IsSinglePoint());
}

bool IntSet::can_prove_positive() const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
611
  return (s_int && is_positive_const(tir::Simplify(s_int->min_value)));
612 613 614 615
}

bool IntSet::can_prove_negative() const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
616
  return (s_int && is_negative_const(tir::Simplify(s_int->max_value)));
617 618 619 620
}

bool IntSet::can_prove_non_positive() const {
  if (const auto* s_int = (*this).as<IntervalSetNode>()) {
621
    auto max = tir::Simplify(s_int->max_value);
622 623 624 625 626 627 628
    return is_zero(max) || is_negative_const(max);
  }
  return false;
}

bool IntSet::can_prove_non_negative() const {
  if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
629
    auto min = tir::Simplify(s_int->min_value);
630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645
    return is_zero(min) || is_positive_const(min);
  }
  return false;
}

SignType IntSet::sign_type() const {
  if (can_prove_positive()) {
    return kPositive;
  } else if (can_prove_negative()) {
    return kNegative;
  } else if (is_single_point() && is_zero(point_value())) {
    return kZero;
  } else {
    return kUnknown;
  }
}
646
// Value of a single-point set; CHECK-fails for any other set.
PrimExpr IntSet::point_value() const {
  const auto* node = this->as<IntervalSetNode>();
  CHECK(node && node->IsSinglePoint());
  return node->min_value;
}

IntSet IntSet::nothing() { return IntervalSet::Empty(); }

IntSet IntSet::everything() { return IntervalSet::Everything(); }

IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); }

// [min, max]; when both bounds are the same object, the set is a point.
IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
  if (!min.same_as(max)) return IntervalSet(min, max);
  return IntSet::single_point(min);
}

// Range related code
672
inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) {
673
  return is_zero(tir::Simplify(lhs - rhs));
674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699
}

// Convert a Range to the inclusive interval [min, min + extent - 1].
// The result must stay recognizable by match_range.
IntSet IntSet::range(Range r) {
  if (!is_one(r->extent)) {
    return IntervalSet(r->min, r->extent + r->min - 1);
  }
  return IntSet::single_point(r->min);
}

/*!
 * \brief Whether this set is provably the interval covered by range b,
 *        i.e. [b->min, b->min + b->extent - 1].
 * \return false for non-interval sets and for sets missing either bound.
 */
bool IntSet::match_range(const Range& b) const {
  const IntSet& a = *this;
  const IntervalSetNode* a_int = a.as<IntervalSetNode>();
  if (!a_int) return false;
  // A Range always has finite bounds, so a set lacking either bound cannot
  // match it.  Checking first also avoids building `sentinel - expr`, which
  // would mix a handle-typed infinity sentinel into integer arithmetic.
  if (!a_int->HasLowerBound() || !a_int->HasUpperBound()) return false;
  return ProveEqual(a_int->min_value, b->min) &&
      ProveEqual(a_int->max_value, b->extent + b->min - 1);
}

// Fold Union over all sets, then simplify the resulting bounds.
IntSet Union(const Array<IntSet>& sets) {
  switch (sets.size()) {
    case 0: return IntSet::nothing();
    case 1: return sets[0];
    default: break;
  }
  Analyzer analyzer;
  IntervalSet acc = ToIntervalSet(sets[0]);
  for (size_t i = 1; i < sets.size(); ++i) {
    acc = Union(&analyzer, acc, ToIntervalSet(sets[i]));
  }
  PrimExpr lo = tir::Simplify(acc->min_value);
  PrimExpr hi = tir::Simplify(acc->max_value);
  return IntervalSet(lo, hi);
}

// Fold Intersect over all sets, then simplify the resulting bounds.
IntSet Intersect(const Array<IntSet>& sets) {
  switch (sets.size()) {
    case 0: return IntSet::nothing();
    case 1: return sets[0];
    default: break;
  }
  Analyzer analyzer;
  IntervalSet acc = ToIntervalSet(sets[0]);
  for (size_t i = 1; i < sets.size(); ++i) {
    acc = Intersect(&analyzer, acc, ToIntervalSet(sets[i]));
  }
  PrimExpr lo = tir::Simplify(acc->min_value);
  PrimExpr hi = tir::Simplify(acc->max_value);
  return IntervalSet(lo, hi);
}

// Re-key a domain map by the IterVar's underlying Var.
Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
  Map<Var, IntSet> result;
  for (const auto& entry : dom_map) {
    result.Set(entry.first->var, entry.second);
  }
  return result;
}

// Re-key a domain map from raw VarNode pointers to Var references.
Map<Var, IntSet> ConvertDomMap(
    const std::unordered_map<const VarNode*, IntSet>& dom_map) {
  Map<Var, IntSet> result;
  for (const auto& entry : dom_map) {
    result.Set(GetRef<Var>(entry.first), entry.second);
  }
  return result;
}

733
// Interval of e when each variable in dom_map ranges over its set.
IntSet EvalSet(PrimExpr e,
               const Map<Var, IntSet>& dom_map) {
  Analyzer analyzer;
  IntervalSetEvaluator evaluator(&analyzer, dom_map, false);
  return evaluator.Eval(e);
}

// Interval of a vector expression: no variables are relaxed, but Ramp and
// Broadcast nodes are evaluated.
IntSet IntSet::vector(PrimExpr x) {
  Analyzer analyzer;
  Map<Var, IntSet> empty_dom;
  IntervalSetEvaluator evaluator(&analyzer, empty_dom, true);
  return evaluator.Eval(x);
}

IntSet EvalSet(PrimExpr e,
               const Map<IterVar, IntSet>& dom_map) {
  Map<Var, IntSet> var_dom = ConvertDomMap(dom_map);
  return EvalSet(e, var_dom);
}

IntSet EvalSet(PrimExpr e,
               const std::unordered_map<const VarNode*, IntSet>& dom_map) {
  Map<Var, IntSet> var_dom = ConvertDomMap(dom_map);
  return EvalSet(e, var_dom);
}

// Relax the inclusive interval [r->min, r->min + r->extent - 1] over dom_map.
IntSet EvalSet(Range r,
               const Map<Var, IntSet>& dom_map) {
  Analyzer analyzer;
  IntervalSetEvaluator evaluator(&analyzer, dom_map);
  // Simplifying first can give tighter bounds if r->min and r->extent share variables
  PrimExpr last = Simplify(r->min + r->extent - 1);
  IntervalSet relaxed = evaluator.Eval(IntervalSet(r->min, last));
  return std::move(relaxed);
}

IntSet EvalSet(Range r,
               const std::unordered_map<const VarNode*, IntSet>& dom_map) {
  Map<Var, IntSet> var_dom = ConvertDomMap(dom_map);
  return EvalSet(r, var_dom);
}

770
/*!
 * \brief Relax the bounds of an interval set over dom_map.
 *
 * A lower bound present per HasLowerBound() is replaced by the min of its
 * evaluated interval, and an upper bound present per HasUpperBound() by the
 * max of its evaluated interval.  Bounds that are absent are kept as-is.
 *
 * \note CHECK-fails if s is undefined or not an IntervalSet.
 */
IntSet EvalSet(IntSet s,
               const std::unordered_map<const VarNode*, IntSet>& dom_map) {
  Analyzer ana;
  auto dmap = ConvertDomMap(dom_map);
  IntervalSetEvaluator m(&ana, dmap);
  const IntervalSetNode* s_int = s.as<IntervalSetNode>();
  // Report a clear failure instead of dereferencing a null pointer below.
  CHECK(s_int != nullptr);
  PrimExpr vmax = s_int->HasUpperBound() ?
      m.Eval(s_int->max_value).max() : s_int->max_value;
  PrimExpr vmin = s_int->HasLowerBound() ?
      m.Eval(s_int->min_value).min() : s_int->min_value;
  return IntervalSet(vmin, vmax);
}

class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
784
 public:
785 786 787 788
  explicit SubExprIntervalSetEvaluator(
      Analyzer* analyzer,
      const Map<Var, IntSet>& dom_map)
      : IntervalSetEvaluator(analyzer, dom_map) {}
789

790
  IntervalSet VisitExpr(const PrimExpr& n) final {
791
    IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
792
    expr_map[n] = ret;
793 794 795 796 797 798
    return ret;
  }

  ExprIntSetMap expr_map;
};

799
// Evaluate e over dom_map and return the interval of every visited sub-expression.
ExprIntSetMap EvalSetForEachSubExpr(
    PrimExpr e,
    const std::unordered_map<const VarNode*, IntSet>& dom_map) {
  Analyzer analyzer;
  Map<Var, IntSet> var_dom = ConvertDomMap(dom_map);
  SubExprIntervalSetEvaluator recorder(&analyzer, var_dom);
  recorder.Eval(e);
  return recorder.expr_map;
}

809 810
// Same as EvalSet(Range, Map<Var, IntSet>) with the map keyed by IterVar.
IntSet EvalSet(Range r,
               const Map<IterVar, IntSet>& dom_map) {
  Map<Var, IntSet> var_dom = ConvertDomMap(dom_map);
  return EvalSet(r, var_dom);
}

814 815
TVM_REGISTER_NODE_TYPE(IntervalSetNode);

// Printed form: IntervalSet[min, max]
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* interval = static_cast<const IntervalSetNode*>(node.get());
    p->stream << "IntervalSet"
              << "[" << interval->min_value << ", "
              << interval->max_value << ']';
  });
823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845


// Global packed functions exposing IntSet construction and queries.
TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::single_point);
TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::vector);
TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::interval);
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::is_everything);

846
}  // namespace arith
847
}  // namespace tvm