Commit 52ad69fc by ziheng Committed by Tianqi Chen

[FIX] Add CombineInterval<Mod> & Fix LoopPartition (#138)

* Add CombineInterval<Mod> & Fix LoopPartition

* Add check for path
parent 979623e5
......@@ -204,10 +204,14 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() {
Init();
if (!success) return;
Relax();
if (!success) return;
// get the path
path_ = GetPath(target_, expr_);
if (!path_.size()) {
success = false;
return;
}
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
......@@ -215,13 +219,14 @@ void BoundDeducer::Deduce() {
}
// Relaxes expr_ and result over relax_map_.
// Each side is evaluated to an IntSet exactly once. If either set is
// "everything" the deduction is marked as failed (success = false) and the
// members are left untouched, so min()/max() are never taken on such a set.
void BoundDeducer::Relax() {
  IntSet a = EvalSet(expr_, relax_map_);
  IntSet b = EvalSet(result, relax_map_);
  if (a.is_everything() || b.is_everything()) {
    success = false;
    return;
  }
  // When is_greater is set, expr_ takes its set's min and result takes its
  // set's max; otherwise the two choices are swapped.
  expr_ = is_greater ? a.min() : a.max();
  result = is_greater ? b.max() : b.min();
}
IntSet DeduceBound(Expr v, Expr e,
......
......@@ -113,6 +113,12 @@ inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
}
// Mod specialization of ComputeExpr: a zero dividend yields a zero of the
// dividend's type; any other input builds an ir::Mod node from a and b.
template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
  if (!is_zero(a)) {
    return ir::Mod::make(a, b);
  }
  return make_zero(a.type());
}
// Max specialization of ComputeExpr: forwards both operands to
// Halide::Internal::Interval::make_max and returns what it produces.
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
  Expr larger = Halide::Internal::Interval::make_max(a, b);
  return larger;
}
......
......@@ -292,6 +292,23 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
}
// Mod specialization of CombineInterval for the intervals a and b.
//  - b is not a single point: logs a warning and returns everything.
//  - a and b are both single points: single point holding
//    ComputeExpr<Mod>(a.min, b.min).
//  - only b is a single point: the interval [0, b.min - 1]; a zero divisor
//    is reported through LOG(FATAL).
// NOTE(review): the zero-divisor check is not reached when both operands are
// single points -- confirm that this is intended.
template<>
inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
  if (!b.is_single_point()) {
    LOG(WARNING) << "Return Everything in CombineInterval Mod";
    return IntSet::everything();
  }
  if (a.is_single_point()) {
    return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
  }
  Expr divisor = b.min;
  if (is_zero(divisor)) {
    LOG(FATAL) << "Modular by zero in CombineInterval Mod";
  }
  Expr lower = make_zero(divisor.type());
  return IntervalSet::make(lower, divisor - 1);
}
template<>
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
......
......@@ -153,7 +153,9 @@ class PartitionFinder : public IRVisitor {
std::unordered_set<const Variable*>({current_var_.get()}))) {
IntSet interval =
DeduceBound(current_var_, cond, hint_map_, relax_map_);
partitions[cond.get()] = Partition{cond, interval};
if (!interval.is_nothing()) {
partitions[cond.get()] = Partition{cond, interval};
}
}
} else {
IRVisitor::Visit_(op);
......
......@@ -148,6 +148,20 @@ def test_thread_axis2():
for_body = stmt.body.body.body.body.body.first
assert('threadIdx' not in str(for_body.extent))
def test_everything_during_deduction():
    """Check that LoopPartition leaves the guard in place when bound
    deduction for its condition produces everything."""
    m = tvm.var('m')
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    with ib.for_range(0, n, 'i') as i:
        with ib.for_range(0, 32, 'j') as j:
            with ib.if_scope(ib.likely(i/j < m)):
                # this guard will produce everything during deduction
                ib.emit(tvm.make.Evaluate(m))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)
    # After both passes the body of the inner loop is still the IfThenElse.
    assert isinstance(stmt.body.body, tvm.stmt.IfThenElse)
if __name__ == "__main__":
test_basic()
test_multi_loop()
......@@ -156,3 +170,4 @@ if __name__ == "__main__":
test_vectorize()
test_select()
test_thread_axis2()
test_everything_during_deduction()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment