Unverified Commit ec95675c by Tianqi Chen Committed by GitHub

[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod (#2722)

parent 0bf64ee0
......@@ -193,6 +193,39 @@ class ModularSetAnalyzer {
};
/*!
* \brief Rewrite-rule based simplifier.
*/
class RewriteSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);
private:
friend class Analyzer;
friend class ConstraintContext;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};
/*!
* \brief A RAII constraint context.
*
* \code
......@@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify;
/*! \brief constructor */
Analyzer();
/*!
......
......@@ -96,6 +96,7 @@ class Analyzer:
self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr):
......@@ -128,6 +129,21 @@ class Analyzer:
"""
return self._modular_set(expr)
def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._rewrite_simplify(expr)
def bind(self, var, expr):
"""Bind a variable to the expression.
......
......@@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
......
......@@ -2,6 +2,7 @@
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>
namespace tvm {
......@@ -9,19 +10,22 @@ namespace arith {
Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this) {
modular_set(this),
rewrite_simplify(this) {
}
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
}
void Analyzer::Bind(const VarExpr& v, const Range& range) {
Var var(v.node_);
this->const_int_bound.Bind(var, range);
// skip modular_set
// skip rewrite simplify
}
ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
......@@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
auto bd = this->const_int_bound(expr);
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
return false;
}
......
......@@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
inline Expr TryConstFold(Expr a, Expr b);
inline Expr TryConstFold(Expr a, Expr b) {
return Expr();
}
/*!
* \brief Try to run unary compute with constant folding.
......
......@@ -49,6 +49,7 @@
#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"
namespace tvm {
namespace arith {
......@@ -242,7 +243,11 @@ class PBinaryExpr :
}
Expr Eval() const {
return NodeType::make(a_.Eval(), b_.Eval());
Expr lhs = a_.Eval();
Expr rhs = b_.Eval();
Expr ret = TryConstFold<NodeType>(lhs, rhs);
if (ret.defined()) return ret;
return NodeType::make(lhs, rhs);
}
private:
......@@ -250,12 +255,48 @@ class PBinaryExpr :
typename TB::Nested b_;
};
template<typename TA>
class PConstWithTypeLike :
public Pattern<PConstWithTypeLike<TA> > {
public:
PConstWithTypeLike(const TA& ref, int64_t value)
: ref_(ref), value_(value) {}
void InitMatch_() const {}
bool Match_(const NodeRef& node) const {
if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
return ptr->value == value_;
} else {
return false;
}
}
Expr Eval() const {
return make_const(ref_.Eval().type(), value_);
}
private:
typename TA::Nested ref_;
int64_t value_;
};
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}
// arithmetic expressions
......
......@@ -117,6 +117,7 @@ TEST(Pattern, Integer) {
// special case container of Expr
CHECK((v * c).Match(tx * 3));
CHECK_EQ(c.Eval()->value, 3);
CHECK((v * 3).Match(tx * 3));
}
// cannot match c to ty
CHECK(!(v * c).Match(tx * ty));
......
import tvm
class RewriteChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, expected):
res = self.analyzer.rewrite_simplify(data)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(
data, res, expected)
def test_vector_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# Add rules
ck.verify(tvm.expr.Ramp(x, 1, 4) + tvm.expr.Ramp(y, 2, 4),
tvm.expr.Ramp(x + y, 3, 4))
ck.verify(tvm.expr.Ramp(x, 1, 2) + y,
tvm.expr.Ramp(x + y, 1, 2))
ck.verify(y + tvm.expr.Ramp(x, 1, 2) ,
tvm.expr.Ramp(y + x, 1, 2))
ck.verify(y.astype("int32x2") + x.astype("int32x2"),
(y + x).astype("int32x2"))
# Sub rules
ck.verify(tvm.expr.Ramp(x, 4, 4) - tvm.expr.Ramp(y, 2, 4),
tvm.expr.Ramp(x - y, 2, 4))
ck.verify(tvm.expr.Ramp(x, 1, 2) - y,
tvm.expr.Ramp(x - y, 1, 2))
ck.verify(y - tvm.expr.Ramp(x, 1, 2) ,
tvm.expr.Ramp(y - x, -1, 2))
ck.verify(y.astype("int32x2") - x.astype("int32x2"),
(y - x).astype("int32x2"))
# Mul rules
ck.verify(y.astype("int32x2") * x.astype("int32x2"),
(y * x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) * 2,
tvm.expr.Ramp(x * 2, 8, 4))
ck.verify(2 * tvm.expr.Ramp(x, 4, 4),
tvm.expr.Ramp(x * 2, 8, 4))
## Div rules
ck.verify(y.astype("int32x2") / x.astype("int32x2"),
(y / x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) / 2,
tvm.expr.Ramp(x/ 2, 2, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8,
(x).astype("int32x4"))
ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8,
tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8)
## Mod rules
ck.verify(y.astype("int32x2") % x.astype("int32x2"),
(y % x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) % 2,
tvm.expr.Broadcast(x % 2, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8,
tvm.expr.Ramp(1, 1, 4))
ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
tvm.expr.Ramp(1, 15, 4) % 8)
def test_select_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# Add rules
ck.verify(tvm.expr.Select(x > 0, y, 0) + tvm.expr.Select(x > 0, 1, z),
tvm.expr.Select(x > 0, y + 1, z))
ck.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z),
tvm.expr.Select(x > 0, y + (-1), 1 - z))
ck.verify(tvm.expr.Select(x > 0, y, z) - y,
tvm.expr.Select(x > 0, 0, z - y))
ck.verify(tvm.expr.Select(x > 0, y, z) - z,
tvm.expr.Select(x > 0, y - z, 0))
def test_add_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(x + (y - x), y)
ck.verify(x - (y + 1) + (y + 1), x)
ck.verify((x - 10) + (10 - z), x - z)
ck.verify((x - y) + (z - x), z - y)
ck.verify(tvm.min(x, y - z) + z, tvm.min(x + z, y))
ck.verify(tvm.min(x - z, y) + z, tvm.min(x, y + z))
ck.verify(tvm.max(x, y - 10) + 10, tvm.max(x + 10, y))
ck.verify(tvm.max(x - 11, y) + 11, tvm.max(x, y + 11))
ck.verify(tvm.max(x, y * 2) + tvm.min(x, y * 2), x + y * 2);
ck.verify(tvm.min(x, y * 2) + tvm.max(x, y * 2), x + y * 2);
ck.verify(tvm.max(x, y + 2) + (-2), tvm.max(x + (-2), y));
ck.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y));
ck.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1));
ck.verify(x * y + x * 10, x * (y + 10))
ck.verify(y * x + x * 10, x * (y + 10))
ck.verify(y * x + 10 * x, x * (y + 10))
ck.verify(x * y + 10 * x, x * (y + 10))
ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify((x / 8) * 8 + x % 8, x)
# canonicalization
ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9);
ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9);
# conservative bound
try:
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
ck.verify((x / 8) * 8 + x % 8, x)
raise RuntimeError("bad")
except AssertionError:
pass
def test_sub_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(x + y - y, x)
ck.verify(x + y - x, y)
ck.verify(x - (y + x), 0 - y)
ck.verify(x - (x + y), 0 - y)
ck.verify(tvm.min(x, y) - x, tvm.min(0, y - x))
ck.verify(tvm.min(x, y) - y, tvm.min(x - y, 0))
ck.verify(tvm.max(x, y) - x, tvm.max(0, y - x))
ck.verify(tvm.max(x, y) - y, tvm.max(x - y, 0))
ck.verify(x - tvm.min(x, y), tvm.max(0, x - y))
ck.verify(y - tvm.min(x, y), tvm.max(y - x, 0))
ck.verify(x - tvm.max(x, y), tvm.min(0, x - y))
ck.verify(y - tvm.max(x, y), tvm.min(y - x, 0))
# mul co-efficient foldng
ck.verify(x - x, 0)
ck.verify(x * y - x, x * (y + (-1)))
ck.verify(x * y - 10 * x, x * (y + (-10)))
ck.verify(y * x - x * z, x * (y - z))
ck.verify(y * x - z * x, x * (y - z))
ck.verify(x + 10 - 20, x + (-10))
# 4-operands pattern
ck.verify((x + y) - (x + z), y - z)
ck.verify((y + x) - (x + z), y - z)
ck.verify((x + y) - (z + x), y - z)
ck.verify((y + x) - (z + x), y - z)
ck.verify(tvm.min(x + y, z) - x, tvm.min(y, z - x))
ck.verify(tvm.min(y + x, z) - x, tvm.min(y, z - x))
ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y))
ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y))
ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z))
ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z))
ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y))
ck.verify(x - tvm.min(z, y + x), tvm.max(x - z, 0 - y))
ck.verify(tvm.min(x, y) - tvm.min(y, x), 0)
ck.verify(tvm.max(x, y) - tvm.max(y, x), 0)
ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10)
ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10)
# div pattern
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x - (x / 3) * 3, x % 3)
ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)
def test_mul_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify((x + 2) * 3, x * 3 + 6)
ck.verify((x * 2) * 3, x * 6)
ck.verify(tvm.min(x, y) * tvm.max(x, y), x * y)
ck.verify(tvm.max(x, y) * tvm.min(x, y), x * y)
ck.verify((x - y) * (-2), (y - x) * 2)
def test_div_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x / 2 / 3, x / 6)
ck.verify((x / 2 + 1) / 3, (x + 2) / 6)
ck.verify(x * 2 / 4, x / 2)
ck.verify(x * 4 / 2, x * 2)
ck.verify((x * 4 + y) / 2, x * 2 + y / 2)
ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2))
ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2))
ck.verify((y + x * 4) / 2, y / 2 + x * 2)
ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3))
ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3))
# 3-operands
ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2)
ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1)
ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1)
ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2)
ck.verify((x + 4) / 2, x / 2 + 2)
ck.verify((x + y) / x, y / x + 1)
ck.verify((y + x) / x, y / x + 1)
ck.verify(((x + y) + z) / x, (y + z) / x + 1)
ck.verify(((y + x) + z) / x, (y + z) / x + 1)
ck.verify((y + (x + z)) / x, (y + z) / x + 1)
ck.verify((y + (z + x)) / x, (y + z) / x + 1)
ck.verify((x * y) / y, x)
ck.verify((y * x) / y, x)
ck.verify((x * z + y) / z, x + y / z)
ck.verify((z * x + y) / z, x + y / z)
ck.verify((y + x * z) / z, y / z + x)
ck.verify((y + z * x) / z, y / z + x)
def test_mod_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x * 10 % 2, 0)
ck.verify((x * 10 + y) % 2, y % 2)
ck.verify((x + 10) % 2, x % 2)
ck.verify((x + y * 10) % 2, x % 2)
ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
if __name__ == "__main__":
test_mod_index_simplify()
test_vector_simplify()
test_add_index_simplify()
test_sub_index_simplify()
test_mul_index_simplify()
test_div_index_simplify()
test_select_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment