Commit 52ad69fc by ziheng Committed by Tianqi Chen

[FIX] Add CombineInternal<Mod> & Fix LoopPartition (#138)

* Add CombineInternal<Mod> & Fix LoopPartition

* Add check for path
parent 979623e5
...@@ -204,10 +204,14 @@ void BoundDeducer::Transform() { ...@@ -204,10 +204,14 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() { void BoundDeducer::Deduce() {
Init(); Init();
if (!success) return; if (!success) return;
Relax(); Relax();
if (!success) return;
// get the path // get the path
path_ = GetPath(target_, expr_); path_ = GetPath(target_, expr_);
if (!path_.size()) {
success = false;
return;
}
// get the sign of every subexpr // get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
...@@ -215,13 +219,14 @@ void BoundDeducer::Deduce() { ...@@ -215,13 +219,14 @@ void BoundDeducer::Deduce() {
} }
void BoundDeducer::Relax() { void BoundDeducer::Relax() {
if (is_greater) { IntSet a = EvalSet(expr_, relax_map_);
expr_ = EvalSet(expr_ , relax_map_).min(); IntSet b = EvalSet(result, relax_map_);
result = EvalSet(result, relax_map_).max(); if (a.is_everything() || b.is_everything()) {
} else { success = false;
expr_ = EvalSet(expr_ , relax_map_).max(); return;
result = EvalSet(result, relax_map_).min();
} }
expr_ = is_greater ? a.min() : a.max();
result = is_greater ? b.max() : b.min();
} }
IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e,
......
...@@ -113,6 +113,12 @@ inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) { ...@@ -113,6 +113,12 @@ inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
} }
template<> template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
if (is_zero(a)) return make_zero(a.type());
return ir::Mod::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
return Halide::Internal::Interval::make_max(a, b); return Halide::Internal::Interval::make_max(a, b);
} }
......
...@@ -292,6 +292,23 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) { ...@@ -292,6 +292,23 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
} }
template<> template<>
inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
}
if (b.is_single_point()) {
Expr divisor = b.min;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
}
LOG(WARNING) << "Return Everything in CombineInterval Mod";
return IntSet::everything();
}
template<>
inline IntSet CombineInterval<Max>(Interval a, Interval b) { inline IntSet CombineInterval<Max>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) { if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Max>(a.min, b.min)); return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
......
...@@ -153,8 +153,10 @@ class PartitionFinder : public IRVisitor { ...@@ -153,8 +153,10 @@ class PartitionFinder : public IRVisitor {
std::unordered_set<const Variable*>({current_var_.get()}))) { std::unordered_set<const Variable*>({current_var_.get()}))) {
IntSet interval = IntSet interval =
DeduceBound(current_var_, cond, hint_map_, relax_map_); DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.is_nothing()) {
partitions[cond.get()] = Partition{cond, interval}; partitions[cond.get()] = Partition{cond, interval};
} }
}
} else { } else {
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
} }
......
...@@ -148,6 +148,20 @@ def test_thread_axis2(): ...@@ -148,6 +148,20 @@ def test_thread_axis2():
for_body = stmt.body.body.body.body.body.first for_body = stmt.body.body.body.body.body.first
assert('threadIdx' not in str(for_body.extent)) assert('threadIdx' not in str(for_body.extent))
def test_everything_during_deduction():
m = tvm.var('m')
n = tvm.var('n')
ib = tvm.ir_builder.create()
with ib.for_range(0, n, 'i') as i:
with ib.for_range(0, 32, 'j') as j:
with ib.if_scope(ib.likely(i/j < m)):
# this guard will produce everything during deduction
ib.emit(tvm.make.Evaluate(m))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_multi_loop() test_multi_loop()
...@@ -156,3 +170,4 @@ if __name__ == "__main__": ...@@ -156,3 +170,4 @@ if __name__ == "__main__":
test_vectorize() test_vectorize()
test_select() test_select()
test_thread_axis2() test_thread_axis2()
test_everything_during_deduction()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment