Commit 3555769e by Ziheng Jiang Committed by Tianqi Chen

[ARITH] Add CombineInterval<Div> in IntSet (#48)

* [FIX] add CombineInterval<Div>

* fix error message and add comment about rounding

* fix comment
parent c8ec4111
...@@ -244,7 +244,7 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) { ...@@ -244,7 +244,7 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
if (is_one(b.min)) return IntervalSet::make(a); if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min; Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max; Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
// This is relaxiation // no relaxation is needed in here due to set is inclusive
// TODO(tqchen): consider convert to StrideSet. // TODO(tqchen): consider convert to StrideSet.
if (is_positive_const(b.min)) { if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2); return IntervalSet::make(e1, e2);
...@@ -260,6 +260,32 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) { ...@@ -260,6 +260,32 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
} }
template<> template<>
inline IntSet CombineInterval<Div>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
}
if (b.is_single_point()) {
if (is_zero(b.min)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
// no relaxation is needed in here due to set is inclusive
if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2);
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Div";
return IntSet::everything();
}
template<>
inline IntSet CombineInterval<Max>(Interval a, Interval b) { inline IntSet CombineInterval<Max>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) { if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Max>(a.min, b.min)); return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment