Commit 39c116f0 by Siyuan Feng Committed by Lianmin Zheng

Fix intersect of modular set (#2904)

Fix comment bugs and code style
parent fb7fa8e4
...@@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry { ...@@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
int64_t coeff{1}; int64_t coeff{1};
int64_t base{0}; int64_t base{0};
Entry() = default;
Entry(int64_t coeff, int64_t base) {
CHECK_GE(coeff, 0);
this->coeff = coeff;
if (coeff != 0) {
base = base % coeff;
if (base < 0) base += coeff;
}
this->base = base;
}
bool is_const() const { bool is_const() const {
return coeff == 0; return coeff == 0;
} }
...@@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl : ...@@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
if (!override) { if (!override) {
CHECK(!var_map_.count(var)); CHECK(!var_map_.count(var));
} }
Entry e; var_map_[var] = Entry(info->coeff, info->base);
e.coeff = info->coeff;
e.base = info->base;
var_map_[var] = e;
} }
// Detect useful constraints and use them in the analysis scope. // Detect useful constraints and use them in the analysis scope.
...@@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl : ...@@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
PVar<Integer> coeff, base; PVar<Integer> coeff, base;
// pattern match interesting constraints // pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) { if (((var % coeff) == base).Match(constraint)) {
Entry entry; Entry entry(coeff.Eval()->value, base.Eval()->value);
entry.coeff = coeff.Eval()->value;
entry.base = base.Eval()->value;
return UpdateByIntersect(var.Eval(), entry); return UpdateByIntersect(var.Eval(), entry);
} }
return nullptr; return nullptr;
...@@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl : ...@@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
} }
Entry VisitExpr_(const IntImm* op) final { Entry VisitExpr_(const IntImm* op) final {
Entry ret; return Entry(0, op->value);
ret.base = op->value;
ret.coeff = 0;
return ret;
} }
Entry VisitExpr_(const UIntImm* op) final { Entry VisitExpr_(const UIntImm* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) { if (op->value < std::numeric_limits<int64_t>::max()) {
Entry ret; return Entry(0, static_cast<int>(op->value));
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
} else { } else {
return Everything(); return Everything();
} }
...@@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl : ...@@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
Entry VisitExpr_(const Add* op) final { Entry VisitExpr_(const Add* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
Entry ret; int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); return Entry(coeff, a.base + b.base);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
} }
Entry VisitExpr_(const Sub* op) final { Entry VisitExpr_(const Sub* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
Entry ret; int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); return Entry(coeff, a.base - b.base);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
} }
Entry VisitExpr_(const Mul* op) final { Entry VisitExpr_(const Mul* op) final {
...@@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl : ...@@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
int64_t pq = a.coeff * b.coeff; int64_t pq = a.coeff * b.coeff;
int64_t pm = a.coeff * b.base; int64_t pm = a.coeff * b.base;
int64_t qn = a.base * b.coeff; int64_t qn = a.base * b.coeff;
Entry ret; int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); return Entry(coeff, a.base * b.base);
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
} }
Entry DivByConst(const Expr& lhs, Entry DivByConst(const Expr& lhs,
...@@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl : ...@@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
Entry a = VisitExpr(lhs); Entry a = VisitExpr(lhs);
CHECK_NE(val, 0); CHECK_NE(val, 0);
if (a.coeff % val == 0) { if (a.coeff % val == 0) {
Entry ret;
if (a.base == 0) { if (a.base == 0) {
// a c x / c -> a x // a c x / c -> a x
ret.coeff = std::abs(a.coeff / val); return Entry(std::abs(a.coeff / val), 0);
ret.base = 0;
return ret;
} }
// positive division have a clear rounding mode. // positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down. // Only handle case where we clearly know we need to round down.
if (a.base > 0 && val > 0 && if (a.base > 0 && val > 0 &&
(round_down || parent_->CanProveGreaterEqual(lhs, 0))) { (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
ret.coeff = a.coeff / val; return Entry(a.coeff / val, a.base / val);
ret.base = a.base / val;
return ret;
} }
} }
return Everything(); return Everything();
...@@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl : ...@@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
} }
int64_t base0 = a.base % coeff; int64_t base0 = a.base % coeff;
int64_t base1 = b.base % coeff; int64_t base1 = b.base % coeff;
Entry ret;
if (base0 == base1) { if (base0 == base1) {
ret.coeff = coeff; return Entry(coeff, base0);
ret.base = base0;
return ret;
} else { } else {
ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff); return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
ret.base = 0;
return ret;
} }
} }
/*! /*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient.
* \param b The second coefficient.
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
// Extended Euclidean algorithm
// if a < 0, the problem can be convert into
// |a|* (-x) + b * y = gcd(|a|, b)
//
// initial condition:
// a * 0 + b * 1 = b
// a * 1 + b * 0 = a
int64_t s = 0, old_s = 1;
int64_t r = b, old_r = a >= 0 ? a : -a;
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r1 / r2)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
// Because r3 < r2, the iteration can eventually terminate
while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}
*x = a >= 0 ? old_s : -old_s;
if (b != 0) {
*y = (old_r - (*x) * a) / b;
} else {
*y = 1;
}
return old_r;
}
/*!
* \brief Create interect of two sets. * \brief Create interect of two sets.
* \param a The left operand. * \param a The left operand.
* \param b the right operand. * \param b the right operand.
*/ */
static Entry Intersect(Entry a, Entry b) { static Entry Intersect(Entry a, Entry b) {
// simple rule for now: pick higher constraints. int64_t x, y;
// TODO(team-team): Use extended euclidean algorithm. int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
if (a.coeff == 0) return a; // z = c1 * p + b1
if (b.coeff == 0) return b; // z = c2 * q + b2
if (a.coeff >= b.coeff) return a; // c1 * x + c2 * y = gcd(c1, c2)
return b; // -> c1 * p - c2 * q = b2 - b1
} // -> p = (b2 - b1) / gcd * x
/*! // -> q = (b2 - b1) / gcd * (-y)
* \brief Simplify base so that it is in [0, coeff) when coeff != 0. // -> z = LCM(x, y) * k + (c1 * p + b1)
* \param base The base value. int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y);
* \param coeff The coeff value. int64_t v = b2 - b1;
* \return The simplified base. if (v % gcd == 0) {
*/ x = v / gcd * x;
static int64_t BaseSimplify(int64_t base, int64_t coeff) { y = v / gcd * (-y);
if (coeff == 0) return base; int64_t coeff = c1 / gcd * c2;
base = base % coeff; return Entry(coeff, x * c1 + b1);
if (base < 0) base += coeff; } else {
return base; return Nothing();
}
} }
/*! /*!
* \brief Take GCD of a and b. * \brief Take GCD of a and b.
...@@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl : ...@@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent. * \return Bound that represent everything dtype can represent.
*/ */
static Entry Everything() { static Entry Everything() {
Entry ret; return Entry(1, 0);
ret.coeff = 1; ret.base = 0; }
return ret; /*!
* \brief return an empty set
* \return Bound that represent everything dtype can represent.
*/
static Entry Nothing() {
return Entry(0, 1);
} }
}; };
......
...@@ -117,6 +117,22 @@ def test_constraint_scope(): ...@@ -117,6 +117,22 @@ def test_constraint_scope():
assert m.coeff == 1 assert m.coeff == 1
assert m.base == 0 assert m.base == 0
def test_intersect():
a = tvm.var("a")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(a % 4 == 1):
with analyzer.constraint_scope(a % 3 == 1):
m = analyzer.modular_set(a)
assert m.coeff == 12
assert m.base == 1
with analyzer.constraint_scope(a % 3 == 2):
with analyzer.constraint_scope(a % 5 == 3):
with analyzer.constraint_scope(a % 7 == 2):
m = analyzer.modular_set(a)
assert m.coeff == 105
assert m.base == 23
if __name__ == "__main__": if __name__ == "__main__":
test_cast() test_cast()
...@@ -126,3 +142,4 @@ if __name__ == "__main__": ...@@ -126,3 +142,4 @@ if __name__ == "__main__":
test_min_max_select() test_min_max_select()
test_mix_index() test_mix_index()
test_constraint_scope() test_constraint_scope()
test_intersect()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment