Commit 4b431c67 by Umang Yadav Committed by Tianqi Chen

1) Add EQ op to the deduce_bound and add unittests for the same (#3775)

2) Add EQ support in the loop partition and add test for the same
3) Change typo truc to trunc
parent 2536465c
......@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
return v.path_;
}
enum CompareOp {kGreater, kLess, kEqual};
// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
......@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
} else {
result_ -= op->a;
result_ = - result_;
is_greater_ = !is_greater_;
comp_op = ReverseOp(comp_op);
}
Visit(left ? op->a : op->b);
}
......@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
}
if (sign_operand == SignType::kNegative) {
is_greater_ = !is_greater_;
comp_op = ReverseOp(comp_op);
} else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
success_ = false;
......@@ -151,11 +153,15 @@ class BoundDeducer: public IRVisitor {
if (!divided) {
// Handle non-divisible case
// NOTE: this accounts for truc div behavior.
// NOTE: this accounts for trunc div behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
if (is_greater_) {
if (comp_op == kGreater) {
result_ += 1;
} else if (comp_op == kEqual) {
// condition unsatisfiable as with trunc div, it will change the expression
success_ = false;
return;
} else {
// NOTE: this is a bit sutble hack.
//
......@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
}
Expr result_;
bool is_greater_{true};
CompareOp comp_op{kGreater};
bool success_{true};
private:
void Init();
void Transform();
void Relax();
CompareOp ReverseOp(CompareOp comp_op);
Expr target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& hint_map_;
......@@ -228,51 +234,72 @@ void BoundDeducer::Init() {
Transform();
}
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
switch (comp_op) {
case kEqual: return kEqual; // IntSet can not represent range for `NE
case kGreater: return kLess;
case kLess: return kGreater;
default:
LOG(FATAL) << "Not a valid compare op";
return kGreater; // return some default value
}
}
void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_
if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
is_greater_ = true;
comp_op = kGreater;
expr_ = op->b;
result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
is_greater_ = false;
comp_op = kLess;
expr_ = op->a;
result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
is_greater_ = true;
comp_op = kGreater;
expr_ = op->b;
result_ = op->a;
} else {
is_greater_ = false;
comp_op = kLess;
expr_ = op->a;
result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
is_greater_ = false;
comp_op = kLess;
expr_ = op->b;
result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
is_greater_ = true;
comp_op = kGreater;
expr_ = op->a;
result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
is_greater_ = false;
comp_op = kLess;
expr_ = op->b;
result_ = op->a;
} else {
comp_op = kGreater;
expr_ = op->a;
result_ = op->b;
}
} else if (const EQ* op = expr_.as<EQ>()) {
comp_op = kEqual;
if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b
expr_ = op->b;
result_ = op->a;
} else {
is_greater_ = true;
expr_ = op->a;
result_ = op->b;
}
......@@ -304,8 +331,16 @@ void BoundDeducer::Relax() {
success_ = false;
return;
}
expr_ = is_greater_ ? a.min() : a.max();
result_ = is_greater_ ? b.max() : b.min();
// Both LHS and RHS of the EQ should behave as constants e.g. i == j,
// can not be resolved when either `i` or `j` or both are variables with
// some Range OR `i` and `j` both should be a single point in IntSet
if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
|| !analyzer_.CanProve(a.min() == a.max()))) {
success_ = false;
return;
}
expr_ = (comp_op == kGreater) ? a.min() : a.max();
result_ = (comp_op == kGreater) ? b.max() : b.min();
}
IntSet DeduceBound(Expr v, Expr e,
......@@ -315,7 +350,10 @@ IntSet DeduceBound(Expr v, Expr e,
d.Deduce();
if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater_) {
if (d.comp_op == kEqual) {
min = d.result_;
max = d.result_;
} else if (d.comp_op == kGreater) {
min = d.result_;
} else {
max = d.result_;
......
......@@ -226,8 +226,6 @@ class PartitionFinder : public IRVisitor {
private:
Expr InverseCond(const Expr& cond) {
// We expect most condition not to be of EQ or NE form.
// Currently we do not handle inversing EQ or NE.
Expr inverse_cond;
if (const LT* op = cond.as<LT>()) {
// a < b -> a >= b
......@@ -241,6 +239,12 @@ class PartitionFinder : public IRVisitor {
} else if (const GE* op = cond.as<GE>()) {
// a >= b -> a < b
inverse_cond = LT::make(op->a, op->b);
} else if (const EQ* op = cond.as<EQ>()) {
// a == b -> a != b
inverse_cond = NE::make(op->a, op->b);
// a != b -> a == b
} else if (const NE* op = cond.as<NE>()) {
inverse_cond = EQ::make(op->a, op->b);
}
return inverse_cond;
}
......
......@@ -85,6 +85,44 @@ def test_deduce():
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
# tests for `EQ` op
res4 = tvm.arith.DeduceBound(a, a == b, {}, {})
assert_expr_equal(res4.max_value, b)
assert_expr_equal(res4.min_value, b)
# Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s})
assert str(res5.max_value) == "neg_inf"
assert str(res5.min_value) == "pos_inf"
# variable `a` on the RHS side
res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {})
assert_expr_equal(res6.max_value, 10)
assert_expr_equal(res6.min_value, 10)
# Add, Sub in `EQ`
e4 = ((a - c) == (b + d))
ans4 = (b + d + c)
res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res7.max_value, ans4)
assert_expr_equal(res7.min_value, ans4)
# Satisfiable Mul in `EQ` with negative sign
res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {})
assert_expr_equal(res8.max_value, -2)
assert_expr_equal(res8.min_value, -2)
# Unsatisfiable Mul in `EQ`
e5 = (4 * a == b)
res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {})
assert str(res9.max_value) == "neg_inf"
assert str(res9.min_value) == "pos_inf"
# Unsatisfiable Mul in `EQ`
res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
assert str(res10.max_value) == "neg_inf"
assert str(res10.min_value) == "pos_inf"
def test_check():
a = tvm.var('a')
......@@ -175,5 +213,6 @@ def test_deduce_complex():
if __name__ == "__main__":
test_check()
test_deduce()
test_deduce_basic()
test_deduce_complex()
......@@ -171,6 +171,18 @@ def test_condition():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
def test_condition_EQ():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
with ib.for_range(0, 10, 'i') as i:
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
def test_thread_axis2():
n = tvm.convert(4096)
m = tvm.var('m')
......@@ -420,6 +432,7 @@ if __name__ == "__main__":
test_thread_axis()
test_vectorize()
test_condition()
test_condition_EQ()
test_thread_axis2()
test_everything_during_deduction()
test_single_likely()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment