Commit 65038950 by Tianqi Chen Committed by GitHub

[ARITH] Improve detect linear equation (#529)

* [ARITH] Improve detect linear equation

* fix doc
parent 46082223
Subproject commit dbf043a8d8bf379b05c56d8aa9025db55f589d6d Subproject commit a40a3e2fedee88d2f7b97ba4caf8a9d0eb25886f
...@@ -118,7 +118,7 @@ class IntSet : public NodeRef { ...@@ -118,7 +118,7 @@ class IntSet : public NodeRef {
* \brief Range of a linear integer function. * \brief Range of a linear integer function.
* Use to do specify the possible index values. * Use to do specify the possible index values.
* *
* set = { base + coeff * x | x in Z } * set = { coeff * x + base | x in Z }
* *
* When coeff != 0, it can also be written as * When coeff != 0, it can also be written as
* set = { n | n % coeff == base } * set = { n | n % coeff == base }
...@@ -127,16 +127,17 @@ class IntSet : public NodeRef { ...@@ -127,16 +127,17 @@ class IntSet : public NodeRef {
* For example, if index = 0 + 4 x, then we know it can be divided by 4. * For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/ */
struct ModularEntry { struct ModularEntry {
/*! \brief The base */
int base{0};
/*! \brief linear co-efficient */ /*! \brief linear co-efficient */
int coeff{1}; int coeff{1};
/*! \brief The base */
int base{0};
/*! \return entry represent everything */ /*! \return entry represent everything */
static ModularEntry everything() { static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything. // always safe to set 0 + x, so it can be everything.
ModularEntry e; ModularEntry e;
e.base = 0; e.coeff = 1; e.coeff = 1;
e.base = 0;
return e; return e;
} }
/*! /*!
...@@ -157,14 +158,25 @@ struct IntSetNode : public Node { ...@@ -157,14 +158,25 @@ struct IntSetNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
}; };
/*! /*!
* \brief Detect if e can be rewritten as e = base + var * coeff * \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n]
* Where coeff and base are invariant of var. * Where coeff and base are invariant of var.
* *
* \return [base, coeff] if it is possible, empty array if it is not. * \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/ */
Array<Expr> DetectLinearEquation(Expr e, Var var); Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
/*! /*!
* \brief Find an symbolic integer set that contains all possible values of * \brief Find an symbolic integer set that contains all possible values of
......
...@@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation") ...@@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation")
*ret = DetectLinearEquation(args[0], args[1]); *ret = DetectLinearEquation(args[0], args[1]);
}); });
TVM_REGISTER_API("arith.DetectClipBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectClipBound(args[0], args[1]);
});
TVM_REGISTER_API("arith.DeduceBound") TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], *ret = DeduceBound(args[0], args[1],
......
...@@ -21,22 +21,27 @@ struct LinearEqEntry { ...@@ -21,22 +21,27 @@ struct LinearEqEntry {
Expr coeff; Expr coeff;
}; };
struct IntervalEntry {
Expr min_value;
Expr max_value;
};
class LinearEqDetector class LinearEqDetector
: public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> { : public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
public: public:
explicit LinearEqDetector(Var var) explicit LinearEqDetector(Var var)
: var_(var) {} : var_(var) {}
Array<Expr> Detect(const Expr& e) { bool Detect(const Expr& e, LinearEqEntry* ret) {
LinearEqEntry ret = VisitExpr(e, e); *ret = VisitExpr(e, e);
if (fail_) return Array<Expr>(); if (fail_) return false;
if (!ret.base.defined()) { if (!ret->base.defined()) {
ret.base = make_zero(var_.type()); ret->base = make_zero(var_.type());
} }
if (!ret.coeff.defined()) { if (!ret->coeff.defined()) {
ret.coeff = make_zero(var_.type()); ret->coeff = make_zero(var_.type());
} }
return Array<Expr>{ret.base, ret.coeff}; return true;
} }
LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final { LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
...@@ -48,6 +53,17 @@ class LinearEqDetector ...@@ -48,6 +53,17 @@ class LinearEqDetector
ret.coeff = AddCombine(a.coeff, b.coeff); ret.coeff = AddCombine(a.coeff, b.coeff);
return ret; return ret;
} }
LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
LinearEqEntry ret;
ret.base = SubCombine(a.base, b.base);
ret.coeff = SubCombine(a.coeff, b.coeff);
return ret;
}
LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final { LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
if (fail_) return LinearEqEntry(); if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry a = VisitExpr(op->a, op->a);
...@@ -94,6 +110,11 @@ class LinearEqDetector ...@@ -94,6 +110,11 @@ class LinearEqDetector
if (!b.defined()) return a; if (!b.defined()) return a;
return ComputeExpr<Add>(a, b); return ComputeExpr<Add>(a, b);
} }
Expr SubCombine(Expr a, Expr b) {
if (!a.defined()) return -b;
if (!b.defined()) return a;
return ComputeExpr<Sub>(a, b);
}
Expr MulCombine(Expr a, Expr b) { Expr MulCombine(Expr a, Expr b) {
if (!a.defined()) return a; if (!a.defined()) return a;
if (!b.defined()) return b; if (!b.defined()) return b;
...@@ -101,9 +122,134 @@ class LinearEqDetector ...@@ -101,9 +122,134 @@ class LinearEqDetector
} }
}; };
Array<Expr> DetectLinearEquation(Expr e, Var var) { Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
return LinearEqDetector(var).Detect(e); CHECK_GE(vars.size(), 1U);
Expr base = e;
Array<Expr> coeff;
for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}
std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>();
}
}
coeff.push_back(base);
return coeff;
} }
// Detect clip condition as min max value
bool DetectClipBound(
const Expr& cond,
std::unordered_map<const Variable*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
auto fvisit = [&bmap, &flag, &var](const NodeRef& n) {
if (const Variable* v = n.as<Variable>()) {
if (bmap->count(v)) {
if (flag == 0) {
var = Var(n.node_);
flag = 1;
} else if (flag == 1) {
if (!var.same_as(n)) {
flag = -1;
}
}
}
}
};
PostOrderVisit(cond, fvisit);
if (flag != 1) return false;
// canonical form: exp >= 0
Expr canonical;
if (const LT* op = cond.as<LT>()) {
if (!op->a.type().is_int()) return false;
canonical = op->b - op->a - make_const(op->a.type(), 1);
} else if (const LE* op = cond.as<LE>()) {
if (!op->a.type().is_int()) return false;
canonical = op->b - op->a;
} else if (const GT* op = cond.as<GT>()) {
if (!op->a.type().is_int()) return false;
canonical = op->a - op->b - make_const(op->a.type(), 1);
} else if (const GE* op = cond.as<GE>()) {
if (!op->a.type().is_int()) return false;
canonical = op->a - op->b;
} else {
return false;
}
LinearEqEntry ret;
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
if (is_one(ret.coeff)) {
// var + shift >=0 -> var >= -shift
if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base);
} else {
p.min_value = -ret.base;
}
return true;
}
if (is_const(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift
if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base);
} else {
p.max_value = ret.base;
}
return true;
}
return false;
}
template<typename OP>
void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
if (const OP* op = e.as<OP>()) {
SplitCommExpr<OP>(op->a, ret);
SplitCommExpr<OP>(op->b, ret);
} else {
ret->push_back(e);
}
}
// Detect the lower and upper bound from the expression.
// e must be connected by and.
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
std::vector<Expr> splits;
SplitCommExpr<ir::And>(e, &splits);
std::unordered_map<const Variable*, IntervalEntry> rmap;
for (Var v : vars) {
rmap[v.get()] = IntervalEntry();
}
for (Expr cond : splits) {
if (!DetectClipBound(cond, &rmap)) return Array<Expr>();
}
Array<Expr> ret;
for (Var v : vars) {
IntervalEntry e = rmap[v.get()];
if (e.min_value.defined()) {
e.min_value = Simplify(e.min_value);
}
if (e.max_value.defined()) {
e.max_value = Simplify(e.max_value);
}
ret.push_back(e.min_value);
ret.push_back(e.max_value);
}
return ret;
}
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
...@@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator { ...@@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator {
r = Range::make_by_min_extent( r = Range::make_by_min_extent(
ir::Simplify(r->min), ir::Simplify(r->extent)); ir::Simplify(r->min), ir::Simplify(r->extent));
if (ExprUseVar(r->extent, var)) return body; if (ExprUseVar(r->extent, var)) return body;
Array<Expr> linear_eq = DetectLinearEquation(r->min, var); Array<Expr> linear_eq = DetectLinearEquation(r->min, {var});
if (linear_eq.size() == 0) return body; if (linear_eq.size() == 0) return body;
Expr base = linear_eq[0]; Expr coeff = linear_eq[0];
Expr coeff = linear_eq[1]; Expr base = linear_eq[1];
if (!is_zero(base)) return body; if (!is_zero(base)) return body;
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
if (!can_prove(left >= 0)) return body; if (!can_prove(left >= 0)) return body;
......
import tvm
def test_basic():
a = tvm.var("a")
b = tvm.var("b")
c = tvm.var("c")
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
a - 1 > 0), [a])
assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
assert m[0].value == 2
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
a - 1 > 0), [a, b])
assert len(m) == 0
m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
b - 1 > 0), [a, b])
assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
if __name__ == "__main__":
test_basic()
...@@ -3,22 +3,41 @@ import tvm ...@@ -3,22 +3,41 @@ import tvm
def test_basic(): def test_basic():
a = tvm.var("a") a = tvm.var("a")
b = tvm.var("b") b = tvm.var("b")
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, a) m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a])
assert m[1].value == 4 assert m[0].value == 4
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7)).value == 0 assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, a) m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a])
assert len(m) == 0 assert len(m) == 0
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, a) m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, [a])
assert m[1].value == 5 assert m[0].value == 5
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0 assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0
m = tvm.arith.DetectLinearEquation(a * b + 7, a) m = tvm.arith.DetectLinearEquation(a * b + 7, [a])
assert m[1] == b assert m[0] == b
m = tvm.arith.DetectLinearEquation(b * 7, a) m = tvm.arith.DetectLinearEquation(b * 7, [a])
assert m[1].value == 0 assert m[0].value == 0
def test_multivariate():
v = [tvm.var("v%d" % i) for i in range(4)]
b = tvm.var("b")
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5))
assert(m[1].value == 8)
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
assert(len(m) == 0)
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
assert(len(m) == 0)
m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
assert(m[1].value == 16)
assert(m[2].value == 2)
assert(m[len(m)-1].value == 2)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_multivariate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment