/*! * Copyright (c) 2017 by Contributors * \file int_set.cc * \brief The integer set functions */ #include <tvm/ir.h> #include <tvm/ir_pass.h> #include <tvm/arithmetic.h> #include <tvm/ir_functor_ext.h> #include <arithmetic/Interval.h> #include <unordered_map> #include "./compute_expr.h" #include "./int_set_internal.h" namespace tvm { namespace arith { using HalideIR::Internal::Interval; using namespace ir; inline IntSet IntSet::cover_interval() const { if ((*this).as<IntervalSet>()) return *this; const StrideSet* s = (*this).as<StrideSet>(); if (s) { CHECK_NE(s->extents.size(), 0U); Expr max = s->base.max; for (size_t i = 0; i < s->extents.size(); ++i) { max = max + s->extents[i] * s->strides[i] - s->strides[i]; } return IntervalSet::make(s->base.min, Simplify(max)); } LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval"; return IntSet::everything(); } Range IntSet::cover_range(Range max_range) const { IntSet temp; const IntervalSet* s_int = (*this).as<IntervalSet>(); if (s_int == nullptr) { temp = this->cover_interval(); s_int = temp.as<IntervalSet>(); } if (s_int->i.is_bounded()) { return Range::make_by_min_extent( s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min)); } return max_range; } Expr IntSet::min() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); CHECK(s_int); return s_int->i.min; } Expr IntSet::max() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); CHECK(s_int); return s_int->i.max; } bool IntSet::is_nothing() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); return (s_int && s_int->i.is_empty()); } bool IntSet::is_everything() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); return (s_int && s_int->i.is_everything()); } bool IntSet::is_single_point() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); return (s_int && s_int->i.is_single_point()); } bool IntSet::can_prove_positive() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); return (s_int && is_positive_const(ir::Simplify(s_int->i.min))); } bool IntSet::can_prove_negative() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); } SignType IntSet::sign_type() const { if (can_prove_positive()) { return kPositive; } else if (can_prove_negative()) { return kNegative; } else if (is_single_point() && is_zero(point_value())) { return kZero; } else { return kUnknown; } } Expr IntSet::point_value() const { const IntervalSet* s_int = (*this).as<IntervalSet>(); CHECK(s_int && s_int->i.is_single_point()); return s_int->i.min; } IntSet IntSet::nothing() { return IntervalSet::make(Interval::nothing()); } IntSet IntSet::everything() { return IntervalSet::make(Interval::everything()); } IntSet IntSet::single_point(Expr x) { return IntervalSet::make(Interval::single_point(x)); } IntSet IntSet::range(Range r) { // must make sure it can be matched back by MatchRange. if (is_one(r->extent)) { return IntSet::single_point(r->min); } if (is_positive_const(r->extent) && is_const(r->min)) { return IntervalSet::make( r->min, ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1)); } return IntervalSet::make(r->min, (r->extent + r->min) - 1); } IntSet IntSet::interval(Expr min, Expr max) { if (min.same_as(max)) { return IntSet::single_point(min); } return IntervalSet::make(min, max); } inline bool prove_equal(Expr lhs, Expr rhs) { return is_zero(ir::Simplify(lhs - rhs)); } // Check if a is created from b. bool IntSet::match_range(const Range& b) const { const IntSet& a = *this; const IntervalSet* a_int = a.as<IntervalSet>(); if (!a_int) return false; const Interval& i = a_int->i; return prove_equal(i.min, b->min) && prove_equal(i.max, ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1)); } inline bool MatchPoint(const IntSet& a, const Expr& b) { const IntervalSet* a_int = a.as<IntervalSet>(); if (!a_int) return false; const Interval& i = a_int->i; return i.is_single_point() && i.min.same_as(b); } IntSet Union(const Array<IntSet>& sets) { if (sets.size() == 0) return IntSet::nothing(); if (sets.size() == 1) return sets[0]; Interval x = sets[0].cover_interval().as<IntervalSet>()->i; for (size_t i = 1; i < sets.size(); ++i) { IntSet s = sets[i].cover_interval(); const Interval& y = s.as<IntervalSet>()->i; x.include(y); } x.max = ir::Simplify(x.max); x.min = ir::Simplify(x.min); return IntervalSet::make(x); } IntSet Intersect(const Array<IntSet>& sets) { Interval x = sets[0].cover_interval().as<IntervalSet>()->i; for (size_t i = 1; i < sets.size(); ++i) { Interval y = sets[i].cover_interval().as<IntervalSet>()->i; x = Interval::make_intersection(x, y); } return IntervalSet::make(x); } // type traits template<typename OP> struct is_logical_op { static const bool value = false; }; #define TVM_DECLARE_LOGICAL_OP(OP) \ template<> \ struct is_logical_op<ir::OP> { \ static const bool value = true; \ }; // interval related. template<typename OP> inline IntSet CombineInterval(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<OP>(a.min, b.min)); } LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key; return IntSet::everything(); } template<> inline IntSet CombineInterval<Add>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Add>(a.min, b.min)); } Interval r = Interval::everything(); if (a.has_lower_bound() && b.has_lower_bound()) { r.min = ComputeExpr<Add>(a.min, b.min); } if (a.has_upper_bound() && b.has_upper_bound()) { r.max = ComputeExpr<Add>(a.max, b.max); } return IntervalSet::make(r); } template<> inline IntSet CombineInterval<Sub>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min)); } Interval r = Interval::everything(); if (a.has_lower_bound() && b.has_upper_bound()) { r.min = ComputeExpr<Sub>(a.min, b.max); } if (a.has_upper_bound() && b.has_lower_bound()) { r.max = ComputeExpr<Sub>(a.max, b.min); } return IntervalSet::make(r); } template<> inline IntSet CombineInterval<Mul>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min)); } if (a.is_single_point() && !b.is_single_point()) { std::swap(a, b); } if (b.is_single_point()) { if (is_zero(b.min)) return IntSet::single_point(0); if (is_one(b.min)) return IntervalSet::make(a); Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min; Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max; // no relaxation is needed in here due to set is inclusive // TODO(tqchen): consider convert to StrideSet. if (is_positive_const(b.min)) { return IntervalSet::make(e1, e2); } else if (is_negative_const(b.min)) { return IntervalSet::make(e2, e1); } else if (a.is_bounded()) { Expr cmp = b.min >= make_zero(b.min.type().element_of()); return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1)); } } LOG(WARNING) << "Return Everything in CombineInterval Mul"; return IntSet::everything(); } template<> inline IntSet CombineInterval<Div>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Div>(a.min, b.min)); } if (b.is_single_point()) { if (is_zero(b.min)) { LOG(FATAL) << "Divide by zero in CombineInterval Div"; } if (is_one(b.min)) return IntervalSet::make(a); Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min; Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max; // no relaxation is needed in here due to set is inclusive if (is_positive_const(b.min)) { return IntervalSet::make(e1, e2); } else if (is_negative_const(b.min)) { return IntervalSet::make(e2, e1); } else if (a.is_bounded()) { Expr cmp = b.min >= make_zero(b.min.type().element_of()); return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1)); } } LOG(WARNING) << "Return Everything in CombineInterval Div"; return IntSet::everything(); } template<> inline IntSet CombineInterval<Mod>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min)); } if (b.is_single_point()) { Expr divisor = b.min; if (is_zero(divisor)) { LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } return IntervalSet::make(make_zero(divisor.type()), divisor - 1); } LOG(WARNING) << "Return Everything in CombineInterval Mod"; return IntSet::everything(); } template<> inline IntSet CombineInterval<Max>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Max>(a.min, b.min)); } return IntervalSet::make(Interval::make_max(a.min, b.min), Interval::make_max(a.max, b.max)); } template<> inline IntSet CombineInterval<Min>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { return IntSet::single_point(ComputeExpr<Min>(a.min, b.min)); } return IntervalSet::make(Interval::make_min(a.min, b.min), Interval::make_min(a.max, b.max)); } template<typename OP> inline IntSet CombineInterval_(IntSet a, IntSet b) { return CombineInterval<OP>( a.as<IntervalSet>()->i, b.as<IntervalSet>()->i); } // stride related inline IntSet AsStrideSet(IntSet a) { if (a.as<StrideSet>()) return a; const IntervalSet* s = a.as<IntervalSet>(); CHECK(s->i.is_bounded()); std::shared_ptr<StrideSet> n = std::make_shared<StrideSet>(); n->base = s->i; return IntSet(n); } template<typename OP> inline IntSet CombineSets(IntSet a, IntSet b) { return CombineInterval_<OP>(a.cover_interval(), b.cover_interval()); } template<> inline IntSet CombineSets<Add>(IntSet a, IntSet b) { const IntervalSet* a_int = a.as<IntervalSet>(); const IntervalSet* b_int = b.as<IntervalSet>(); if (a_int && is_zero(a_int->i.min)) return b; if (b_int && is_zero(b_int->i.min)) return a; a = AsStrideSet(a); b = AsStrideSet(b); const StrideSet* a_stride = a.as<StrideSet>(); const StrideSet* b_stride = b.as<StrideSet>(); auto n = std::make_shared<StrideSet>(*a_stride); for (size_t i = 0; i < b_stride->extents.size(); ++i) { n->extents.push_back(b_stride->extents[i]); n->strides.push_back(b_stride->strides[i]); } n->base = CombineInterval<Add>( a_stride->base, b_stride->base).as<IntervalSet>()->i; return IntSet(n); } inline IntSet NegateSet(IntSet a) { const IntervalSet* a_int = a.as<IntervalSet>(); if (a_int) { if (a_int->i.is_single_point()) { return IntSet::single_point(-a_int->i.min); } else { Interval r = Interval::everything(); if (a_int->i.has_upper_bound()) { r.min = -(a_int->i.max); } if (a_int->i.has_lower_bound()) { r.max = -(a_int->i.min); } return IntervalSet::make(r); } } else { return NegateSet(a.cover_interval()); } } template<> inline IntSet CombineSets<Sub>(IntSet a, IntSet b) { return CombineSets<Add>(a, NegateSet(b)); } TVM_DECLARE_LOGICAL_OP(And); TVM_DECLARE_LOGICAL_OP(Or); TVM_DECLARE_LOGICAL_OP(EQ); TVM_DECLARE_LOGICAL_OP(NE); TVM_DECLARE_LOGICAL_OP(GE); TVM_DECLARE_LOGICAL_OP(GT); TVM_DECLARE_LOGICAL_OP(LE); TVM_DECLARE_LOGICAL_OP(LT); TVM_DECLARE_LOGICAL_OP(Not); // generic combine operations of two sets template<typename OP> inline IntSet Combine(const IntSet& a, const IntSet &b) { if (is_logical_op<OP>::value) { return IntervalSet::make(0, 1); } const IntervalSet* a_int = a.as<IntervalSet>(); const IntervalSet* b_int = b.as<IntervalSet>(); if (a_int && a_int->i.is_everything()) return a; if (b_int && b_int->i.is_everything()) return b; if (a_int && b_int) { return CombineInterval<OP>(a_int->i, b_int->i); } if (a_int && !(a_int->i.is_bounded())) { return CombineInterval_<OP>(a, b.cover_interval()); } if (b_int && !(b_int->i.is_bounded())) { return CombineInterval_<OP>(a.cover_interval(), b); } return CombineSets<OP>(a, b); } class IntSetEvaluator : public ExprFunctor<IntSet(const Expr&, const Expr&)> { public: explicit IntSetEvaluator( const std::unordered_map<const Variable*, IntSet>& dom_map, bool eval_vec = false) : dom_map_(dom_map), eval_vec_(eval_vec) {} // Evaluate. IntSet Eval(const Expr& e) { return this->VisitExpr(e, e); } IntSet VisitExpr_(const IntImm* op, const Expr& e) final { return IntSet::single_point(e); } IntSet VisitExpr_(const UIntImm* op, const Expr& e) final { return IntSet::single_point(e); } IntSet VisitExpr_(const Variable* op, const Expr& e) final { auto it = dom_map_.find(op); if (it != dom_map_.end()) { return it->second; } else { return IntSet::single_point(e); } } IntSet VisitExpr_(const Add* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Sub* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Mul* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Div* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Mod* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Min* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Max* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const EQ* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const NE* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const LT* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const LE* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const GT* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const GE* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const And* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Or* op, const Expr& e) final { return Binary(op, e); } IntSet VisitExpr_(const Ramp* op, const Expr& e) final { CHECK(eval_vec_); IntSet base = Eval(op->base); int vstride; if (GetConstInt(op->stride, &vstride)) { Type t = op->base.type(); if (vstride > 0) { return Combine<Add>( base, IntSet::interval(make_zero(t), make_const(t, vstride * op->lanes -1))); } else { return Combine<Add>( base, IntSet::interval(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } LOG(WARNING) << "cannot evaluate set on expression " << e; return IntSet::everything(); } IntSet VisitExpr_(const Broadcast* op, const Expr& e) final { CHECK(eval_vec_); return Eval(op->value); } IntSet VisitExprDefault_(const Node* op, const Expr& e) final { LOG(WARNING) << "cannot evaluate set type " << e->type_key(); return IntSet::everything(); } private: template<typename T> inline IntSet Binary(const T* op, const Expr& e) { IntSet a = this->Eval(op->a); IntSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntSet::single_point(e); } return Combine<T>(a, b); } const std::unordered_map<const Variable*, IntSet>& dom_map_; bool eval_vec_{false}; }; IntSet EvalSet(Expr e, const std::unordered_map<const Variable*, IntSet>& dom_map) { return IntSetEvaluator(dom_map, false).Eval(e); } IntSet IntSet::vector(Expr x) { std::unordered_map<const Variable*, IntSet> dmap; return IntSetEvaluator(dmap, true).Eval(x); } IntSet EvalSet(Expr e, const Map<IterVar, IntSet>& dom_map) { std::unordered_map<const Variable*, IntSet> dmap; for (auto kv : dom_map) { dmap[kv.first->var.as<Variable>()] = kv.second; } return EvalSet(e, dmap); } IntSet EvalSet(Range r, const std::unordered_map<const Variable*, IntSet>& dom_map) { IntSetEvaluator m(dom_map); IntSet min_set = m.Eval(r->min); IntSet ext_set = m.Eval(r->extent).cover_interval(); const Interval& ei = ext_set.as<IntervalSet>()->i; if (!ei.has_upper_bound()) return IntSet::everything(); ext_set = IntervalSet::make(make_zero(ei.max.type()), ComputeExpr<Sub>(ei.max, 1)); return Combine<Add>(min_set, ext_set); } IntSet EvalSet(IntSet s, const std::unordered_map<const Variable*, IntSet>& dom_map) { IntSetEvaluator m(dom_map); s = s.cover_interval(); const IntervalSet* s_int = s.as<IntervalSet>(); Expr vmax = s_int->i.has_upper_bound() ? m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max; Expr vmin = s_int->i.has_lower_bound() ? m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min; return IntervalSet::make(vmin, vmax); } class SubExprIntSetEvaluator : public IntSetEvaluator { public: explicit SubExprIntSetEvaluator( const std::unordered_map<const Variable*, IntSet>& dom_map) : IntSetEvaluator(dom_map) {} IntSet VisitExpr(const Expr& n, const Expr& e) final { IntSet ret = IntSetEvaluator::VisitExpr(n, e); expr_map[n] = ret; return ret; } ExprIntSetMap expr_map; }; ExprIntSetMap EvalSetForEachSubExpr(Expr e, const std::unordered_map<const Variable*, IntSet>& dom_map) { SubExprIntSetEvaluator m(dom_map); m.Eval(e); return m.expr_map; } IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) { std::unordered_map<const Variable*, IntSet> dmap; for (auto kv : dom_map) { dmap[kv.first->var.as<Variable>()] = kv.second; } return EvalSet(r, dmap); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) { p->stream << "interval-set" << "[" << op->i.min << ", " << op->i.max << ']'; }); } // namespace arith } // namespace tvm