Commit 4b431c67 by Umang Yadav Committed by Tianqi Chen

1) Add EQ op to the deduce_bound and add unittests for the same (#3775)

2) Add EQ support in the loop partition and add test for the same
3) Change typo truc to trunc
parent 2536465c
...@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) { ...@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
return v.path_; return v.path_;
} }
enum CompareOp {kGreater, kLess, kEqual};
// a visitor to deduce the bound of a variable from a expression // a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor { class BoundDeducer: public IRVisitor {
public: public:
...@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor { ...@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
} else { } else {
result_ -= op->a; result_ -= op->a;
result_ = - result_; result_ = - result_;
is_greater_ = !is_greater_; comp_op = ReverseOp(comp_op);
} }
Visit(left ? op->a : op->b); Visit(left ? op->a : op->b);
} }
...@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor { ...@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
} }
if (sign_operand == SignType::kNegative) { if (sign_operand == SignType::kNegative) {
is_greater_ = !is_greater_; comp_op = ReverseOp(comp_op);
} else if (sign_operand == SignType::kUnknown) { } else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand // unable to get the sign of operand
success_ = false; success_ = false;
...@@ -151,11 +153,15 @@ class BoundDeducer: public IRVisitor { ...@@ -151,11 +153,15 @@ class BoundDeducer: public IRVisitor {
if (!divided) { if (!divided) {
// Handle non-divisible case // Handle non-divisible case
// NOTE: this accounts for truc div behavior. // NOTE: this accounts for trunc div behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
if (is_greater_) { if (comp_op == kGreater) {
result_ += 1; result_ += 1;
} else if (comp_op == kEqual) {
// condition unsatisfiable as with trunc div, it will change the expression
success_ = false;
return;
} else { } else {
// NOTE: this is a bit sutble hack. // NOTE: this is a bit sutble hack.
// //
...@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor { ...@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
} }
Expr result_; Expr result_;
bool is_greater_{true}; CompareOp comp_op{kGreater};
bool success_{true}; bool success_{true};
private: private:
void Init(); void Init();
void Transform(); void Transform();
void Relax(); void Relax();
CompareOp ReverseOp(CompareOp comp_op);
Expr target_; Expr target_;
Expr expr_; Expr expr_;
const std::unordered_map<const Variable*, IntSet>& hint_map_; const std::unordered_map<const Variable*, IntSet>& hint_map_;
...@@ -228,51 +234,72 @@ void BoundDeducer::Init() { ...@@ -228,51 +234,72 @@ void BoundDeducer::Init() {
Transform(); Transform();
} }
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
switch (comp_op) {
case kEqual: return kEqual; // IntSet can not represent range for `NE
case kGreater: return kLess;
case kLess: return kGreater;
default:
LOG(FATAL) << "Not a valid compare op";
return kGreater; // return some default value
}
}
void BoundDeducer::Transform() { void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_ // We will ensure to set expr_ such that it contains target_
if (const LT* op = expr_.as<LT>()) { if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1 // a < b -> b >= a + 1
is_greater_ = true; comp_op = kGreater;
expr_ = op->b; expr_ = op->b;
result_ = op->a + 1; result_ = op->a + 1;
} else { } else {
// a < b -> a <= b - 1 // a < b -> a <= b - 1
is_greater_ = false; comp_op = kLess;
expr_ = op->a; expr_ = op->a;
result_ = op->b - 1; result_ = op->b - 1;
} }
} else if (const LE* op = expr_.as<LE>()) { } else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a // a <= b -> b >= a
is_greater_ = true; comp_op = kGreater;
expr_ = op->b; expr_ = op->b;
result_ = op->a; result_ = op->a;
} else { } else {
is_greater_ = false; comp_op = kLess;
expr_ = op->a; expr_ = op->a;
result_ = op->b; result_ = op->b;
} }
} else if (const GT* op = expr_.as<GT>()) { } else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1 // a > b -> b <= a - 1
is_greater_ = false; comp_op = kLess;
expr_ = op->b; expr_ = op->b;
result_ = op->a - 1; result_ = op->a - 1;
} else { } else {
// a > b -> a >= b + 1 // a > b -> a >= b + 1
is_greater_ = true; comp_op = kGreater;
expr_ = op->a; expr_ = op->a;
result_ = op->b + 1; result_ = op->b + 1;
} }
} else if (const GE* op = expr_.as<GE>()) { } else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a // a >= b -> b <= a
is_greater_ = false; comp_op = kLess;
expr_ = op->b;
result_ = op->a;
} else {
comp_op = kGreater;
expr_ = op->a;
result_ = op->b;
}
} else if (const EQ* op = expr_.as<EQ>()) {
comp_op = kEqual;
if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b
expr_ = op->b; expr_ = op->b;
result_ = op->a; result_ = op->a;
} else { } else {
is_greater_ = true;
expr_ = op->a; expr_ = op->a;
result_ = op->b; result_ = op->b;
} }
...@@ -304,8 +331,16 @@ void BoundDeducer::Relax() { ...@@ -304,8 +331,16 @@ void BoundDeducer::Relax() {
success_ = false; success_ = false;
return; return;
} }
expr_ = is_greater_ ? a.min() : a.max(); // Both LHS and RHS of the EQ should behave as constants e.g. i == j,
result_ = is_greater_ ? b.max() : b.min(); // can not be resolved when either `i` or `j` or both are variables with
// some Range OR `i` and `j` both should be a single point in IntSet
if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
|| !analyzer_.CanProve(a.min() == a.max()))) {
success_ = false;
return;
}
expr_ = (comp_op == kGreater) ? a.min() : a.max();
result_ = (comp_op == kGreater) ? b.max() : b.min();
} }
IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e,
...@@ -315,7 +350,10 @@ IntSet DeduceBound(Expr v, Expr e, ...@@ -315,7 +350,10 @@ IntSet DeduceBound(Expr v, Expr e,
d.Deduce(); d.Deduce();
if (!d.success_) return IntSet::nothing(); if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf(); Expr min = neg_inf(), max = pos_inf();
if (d.is_greater_) { if (d.comp_op == kEqual) {
min = d.result_;
max = d.result_;
} else if (d.comp_op == kGreater) {
min = d.result_; min = d.result_;
} else { } else {
max = d.result_; max = d.result_;
......
...@@ -226,8 +226,6 @@ class PartitionFinder : public IRVisitor { ...@@ -226,8 +226,6 @@ class PartitionFinder : public IRVisitor {
private: private:
Expr InverseCond(const Expr& cond) { Expr InverseCond(const Expr& cond) {
// We expect most condition not to be of EQ or NE form.
// Currently we do not handle inversing EQ or NE.
Expr inverse_cond; Expr inverse_cond;
if (const LT* op = cond.as<LT>()) { if (const LT* op = cond.as<LT>()) {
// a < b -> a >= b // a < b -> a >= b
...@@ -241,6 +239,12 @@ class PartitionFinder : public IRVisitor { ...@@ -241,6 +239,12 @@ class PartitionFinder : public IRVisitor {
} else if (const GE* op = cond.as<GE>()) { } else if (const GE* op = cond.as<GE>()) {
// a >= b -> a < b // a >= b -> a < b
inverse_cond = LT::make(op->a, op->b); inverse_cond = LT::make(op->a, op->b);
} else if (const EQ* op = cond.as<EQ>()) {
// a == b -> a != b
inverse_cond = NE::make(op->a, op->b);
// a != b -> a == b
} else if (const NE* op = cond.as<NE>()) {
inverse_cond = EQ::make(op->a, op->b);
} }
return inverse_cond; return inverse_cond;
} }
......
...@@ -85,6 +85,44 @@ def test_deduce(): ...@@ -85,6 +85,44 @@ def test_deduce():
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
# tests for `EQ` op
res4 = tvm.arith.DeduceBound(a, a == b, {}, {})
assert_expr_equal(res4.max_value, b)
assert_expr_equal(res4.min_value, b)
# Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s})
assert str(res5.max_value) == "neg_inf"
assert str(res5.min_value) == "pos_inf"
# variable `a` on the RHS side
res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {})
assert_expr_equal(res6.max_value, 10)
assert_expr_equal(res6.min_value, 10)
# Add, Sub in `EQ`
e4 = ((a - c) == (b + d))
ans4 = (b + d + c)
res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res7.max_value, ans4)
assert_expr_equal(res7.min_value, ans4)
# Satisfiable Mul in `EQ` with negative sign
res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {})
assert_expr_equal(res8.max_value, -2)
assert_expr_equal(res8.min_value, -2)
# Unsatisfiable Mul in `EQ`
e5 = (4 * a == b)
res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {})
assert str(res9.max_value) == "neg_inf"
assert str(res9.min_value) == "pos_inf"
# Unsatisfiable Mul in `EQ`
res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
assert str(res10.max_value) == "neg_inf"
assert str(res10.min_value) == "pos_inf"
def test_check(): def test_check():
a = tvm.var('a') a = tvm.var('a')
...@@ -175,5 +213,6 @@ def test_deduce_complex(): ...@@ -175,5 +213,6 @@ def test_deduce_complex():
if __name__ == "__main__": if __name__ == "__main__":
test_check() test_check()
test_deduce()
test_deduce_basic() test_deduce_basic()
test_deduce_complex() test_deduce_complex()
...@@ -171,6 +171,18 @@ def test_condition(): ...@@ -171,6 +171,18 @@ def test_condition():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select)))) assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
def test_condition_EQ():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
with ib.for_range(0, 10, 'i') as i:
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
def test_thread_axis2(): def test_thread_axis2():
n = tvm.convert(4096) n = tvm.convert(4096)
m = tvm.var('m') m = tvm.var('m')
...@@ -420,6 +432,7 @@ if __name__ == "__main__": ...@@ -420,6 +432,7 @@ if __name__ == "__main__":
test_thread_axis() test_thread_axis()
test_vectorize() test_vectorize()
test_condition() test_condition()
test_condition_EQ()
test_thread_axis2() test_thread_axis2()
test_everything_during_deduction() test_everything_during_deduction()
test_single_likely() test_single_likely()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment