Unverified Commit 75892d2b by Tianqi Chen Committed by GitHub

[ARITH][IR] Introduce FloorDiv/Mod (#3479)

* [ARITH][IR] Introduce FloorDiv/Mod

* Address review comments

* address review comments, fix div sub rule
parent 45878ff2
......@@ -332,6 +332,26 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute floor(a / b)
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr floordiv(Expr a, Expr b);
/*!
* \brief compute the remainder of floordiv
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr floormod(Expr a, Expr b);
/*!
* \brief take maximum of two values
*
* \param a left operand
......
......@@ -178,6 +178,18 @@ class Mod : public BinaryOpNode<Mod> {
static constexpr const char* _type_key = "Mod";
};
/*! \brief Floor division, floor(a/b) */
class FloorDiv : public BinaryOpNode<FloorDiv> {
public:
static constexpr const char* _type_key = "FloorDiv";
};
/*! \brief The remainder of the floordiv */
class FloorMod : public BinaryOpNode<FloorMod> {
public:
static constexpr const char* _type_key = "FloorMod";
};
/*! \brief min(a, b) */
class Min : public BinaryOpNode<Min> {
public:
......
......@@ -140,6 +140,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
......@@ -180,6 +182,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(Mul);
IR_EXPR_FUNCTOR_DISPATCH(Div);
IR_EXPR_FUNCTOR_DISPATCH(Mod);
IR_EXPR_FUNCTOR_DISPATCH(FloorDiv);
IR_EXPR_FUNCTOR_DISPATCH(FloorMod);
IR_EXPR_FUNCTOR_DISPATCH(Min);
IR_EXPR_FUNCTOR_DISPATCH(Max);
IR_EXPR_FUNCTOR_DISPATCH(EQ);
......
......@@ -98,6 +98,8 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const Mul* op, const Expr& e);
virtual Expr Mutate_(const Div* op, const Expr& e);
virtual Expr Mutate_(const Mod* op, const Expr& e);
virtual Expr Mutate_(const FloorDiv* op, const Expr& e);
virtual Expr Mutate_(const FloorMod* op, const Expr& e);
virtual Expr Mutate_(const Min* op, const Expr& e);
virtual Expr Mutate_(const Max* op, const Expr& e);
virtual Expr Mutate_(const EQ* op, const Expr& e);
......
......@@ -114,6 +114,8 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const FloorDiv* op);
virtual void Visit_(const FloorMod* op);
virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op);
......
......@@ -888,7 +888,46 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer
def floordiv(a, b):
"""Compute the floordiv of two expressions.
Parameters
----------
a : Expr
The left hand operand
b : Expr
The right hand operand
Returns
-------
res : Expr
The result expression.
"""
return _make._OpFloorDiv(a, b)
def floormod(a, b):
"""Compute the floormod of two expressions.
Parameters
----------
a : Expr
The left hand operand
b : Expr
The right hand operand
Returns
-------
res : Expr
The result expression.
"""
return _make._OpFloorMod(a, b)
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
......
......@@ -463,6 +463,40 @@ class Mod(BinaryOpExpr):
@register_node
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.FloorDiv, a, b)
@register_node
class FloorMod(BinaryOpExpr):
"""FloorMod node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.FloorMod, a, b)
@register_node
class Min(BinaryOpExpr):
"""Min node.
......
......@@ -126,6 +126,8 @@ REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul);
REGISTER_MAKE(Div);
REGISTER_MAKE(Mod);
REGISTER_MAKE(FloorDiv);
REGISTER_MAKE(FloorMod);
REGISTER_MAKE(Min);
REGISTER_MAKE(Max);
REGISTER_MAKE(EQ);
......@@ -192,6 +194,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
......
......@@ -36,6 +36,7 @@ using namespace ir;
class SumExpr;
class SplitExpr;
/*!
* \brief Base class of all temporary expression introduced
* for canonicalization.
......@@ -57,6 +58,31 @@ class CanonicalExprNode : public BaseExprNode {
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
};
enum DivMode {
/*! \brief Truncated division. */
kTruncDiv,
/*! \brief Floor division. */
kFloorDiv
};
inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) {
return a % b;
} else {
CHECK_EQ(mode, kFloorDiv);
return floormod(a, b);
}
}
inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) {
return a / b;
} else {
CHECK_EQ(mode, kFloorDiv);
return floordiv(a, b);
}
}
/*!
* \brief Internal "Split normal form" of expression.
*
......@@ -78,6 +104,8 @@ class SplitExprNode : public CanonicalExprNode {
int64_t upper_factor{kPosInf};
/*! \brief scale to the expression. */
int64_t scale{1};
/*! \brief Division mode. */
DivMode div_mode{kTruncDiv};
/*! \brief verify that this is a valid entry. */
void Verify() const {
......@@ -91,10 +119,10 @@ class SplitExprNode : public CanonicalExprNode {
return make_const(dtype, 0);
}
if (this->upper_factor != SplitExprNode::kPosInf) {
res = res % make_const(dtype, this->upper_factor);
res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode);
}
if (this->lower_factor != 1) {
res = res / make_const(dtype, this->lower_factor);
res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode);
}
sscale *= this->scale;
if (sscale != 1) {
......@@ -113,6 +141,7 @@ class SplitExprNode : public CanonicalExprNode {
}
inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;
/*! \brief positive infty */
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
......@@ -127,6 +156,12 @@ inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
return ir::Equal(index, other->index);
}
inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
if (this->div_mode == mode) return true;
if (lower_factor == 1 && upper_factor == kPosInf) return true;
return false;
}
/*!
* \brief Normal form that represents sum of expressions.
*
......@@ -143,6 +178,10 @@ class SumExprNode : public CanonicalExprNode {
std::vector<SplitExpr> args;
/*! \brief Base value in the summation. */
int64_t base{0};
/*! \brief The expression equals zero. */
bool IsZero() const {
return base == 0 && args.size() == 0;
}
/*!
* \brief Return the normal Expr that is equivalent to self.
* \return The normal expression.
......@@ -218,7 +257,8 @@ class SumExprNode : public CanonicalExprNode {
return;
}
if (other->lower_factor == args[j]->lower_factor &&
other->upper_factor == args[j]->upper_factor) {
other->upper_factor == args[j]->upper_factor &&
other->DivModeCompatibleTo(args[j]->div_mode)) {
args[j].CopyOnWrite()->scale += other->scale * scale;
return;
}
......@@ -251,14 +291,16 @@ class SumExprNode : public CanonicalExprNode {
if (!lhs->IndexEqual(rhs)) break;
if (lhs->upper_factor < rhs->lower_factor) break;
if (lhs->upper_factor == rhs->upper_factor &&
lhs->lower_factor == rhs->lower_factor) {
lhs->lower_factor == rhs->lower_factor &&
lhs->DivModeCompatibleTo(rhs->div_mode)) {
// folding same co-efficient.
rhs.CopyOnWrite()->scale += lhs->scale;
lhs.CopyOnWrite()->scale = 0;
} else if (lhs->lower_factor == rhs->upper_factor &&
rhs->scale != 0 &&
lhs->scale % rhs->scale == 0 &&
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor) {
lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor &&
lhs->DivModeCompatibleTo(rhs->div_mode)) {
// Rules used in the proof:
//
// Rule 1: (x % (c * s)) / c = (x / c) % s
......@@ -270,12 +312,12 @@ class SumExprNode : public CanonicalExprNode {
// Thus, lhs = rhs
//
// The above proof is for the floordiv.
// The same rule also holds for trucdiv(division rule in C).
// The same rule also holds for truncdiv(division rule in C).
// Because both sides only involve mul, div and mod,
// we can take abs of x, c and s, apply the floordiv proof,
// and finally add the sign back.
//
// Rule 2: (x / s) * s + x % s = x (true for both truc and floor div)
// Rule 2: (x / s) * s + x % s = x (true for both trunc and floor div)
//
// General merge condition and proof:
// - x = lhs->index % lhs->upper_factor
......@@ -324,6 +366,9 @@ class SumExprNode : public CanonicalExprNode {
// then order by upper factor
if (lhs->upper_factor > rhs->upper_factor) return true;
if (lhs->upper_factor < rhs->upper_factor) return false;
// then order by div mode
if (lhs->div_mode > rhs->div_mode) return true;
if (lhs->div_mode < rhs->div_mode) return false;
// tie.
// TODO(tvm-team) We might consider index as the last comparison point,
// after we make deep comparator more derministic.
......@@ -402,6 +447,8 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const FloorDiv* op, const Expr& self) final;
Expr Mutate_(const FloorMod* op, const Expr& self) final;
Expr Mutate_(const Reduce* op, const Expr& self) final;
private:
......@@ -409,28 +456,29 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
* \brief compute lhs / cval
* \param lhs The left operand.
* \param cval The constant value.
* \param div_mode The division mode.
* \return The result expression;
*/
SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval);
SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode);
/*!
* \brief compute lhs % cval
* \param lhs The left operand.
* \param cval The constant value.
* \param div_mode The division mode.
* \return The result expression;
*/
SplitExpr SplitModConst(SplitExpr lhs, int64_t cval);
SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode);
/*!
* \brief Detect if psum = q * coeff + r such that (q >= 0 && r >= 0)
* \brief Separate psum into divisible and non-divisible parts.
* \param psum The sum expression.
* \param coeff The co-efficient.
* \param out_divisible The result divisible component.
* \param out_non_divisible The non-divisible component.
* \return Whether detection is successful.
*/
bool TryLinearEquation(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible);
void SeparateDivisibleParts(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible);
/*!
* \brief Normalize expr to normal expr.
* \param expr The input expression.
......@@ -461,9 +509,32 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
n->type = expr.type();
n->index = std::move(expr);
n->div_mode = kTruncDiv;
return SplitExpr(n);
}
/*!
* \brief Convert expr to an equivalent SplitExpr
* that has the specified div_mode.
*
* This function will return the same expr if its
* div_mode already satisfies the need.
*
* \param expr The input expr.
* \param div_mode The new div_mode.
* \return The transformed SplitExpr.
*/
SplitExpr ConvertDivMode(SplitExpr expr, DivMode div_mode) {
if (expr->div_mode == div_mode) return expr;
if (expr->DivModeCompatibleTo(div_mode)) {
expr.CopyOnWrite()->div_mode = div_mode;
return expr;
}
expr = ToSplitExpr(Normalize(expr));
CHECK(expr->DivModeCompatibleTo(div_mode));
expr.CopyOnWrite()->div_mode = div_mode;
return expr;
}
/*!
* \brief Create a SumExpr from expr.
* \param expr The input expr.
* \return The transformed SumExpr.
......@@ -578,12 +649,11 @@ Mutate_(const Mul* op, const Expr& self) {
}
}
bool CanonicalSimplifier::Impl::
TryLinearEquation(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible) {
void CanonicalSimplifier::Impl::
SeparateDivisibleParts(const SumExprNode* psum,
int64_t coeff,
SumExpr* out_divisible,
SumExpr* out_non_divisible) {
auto divisible = make_node<SumExprNode>();
auto non_divisible = make_node<SumExprNode>();
divisible->type = psum->type;
......@@ -603,20 +673,14 @@ TryLinearEquation(const SumExprNode* psum,
}
*out_divisible = SumExpr(divisible);
*out_non_divisible = SumExpr(non_divisible);
if (non_divisible->base == 0 && non_divisible->args.size() == 0) {
return true;
}
if (parent_->CanProveGreaterEqual(divisible->Normalize(), 0) &&
parent_->CanProveGreaterEqual(non_divisible->Normalize(), 0)) {
return true;
} else {
return false;
}
}
SplitExpr CanonicalSimplifier::Impl::
SplitDivConst(SplitExpr lhs, int64_t cval) {
SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
CHECK_GT(cval, 0);
lhs = ConvertDivMode(lhs, div_mode);
// the following rule works for both floordiv and truncdiv
if (lhs->scale % cval == 0) {
lhs.CopyOnWrite()->scale /= cval;
return lhs;
......@@ -637,7 +701,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval) {
} else {
// move the upper_factor modular into index.
lhs.CopyOnWrite()->index =
lhs->index % make_const(lhs.type(), lhs->upper_factor);
ModImpl(lhs->index, make_const(lhs.type(), lhs->upper_factor), div_mode);
lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
......@@ -647,6 +711,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval) {
}
// directly return the split with cval == 1
lhs = ToSplitExpr(Normalize(lhs));
CHECK(lhs->DivModeCompatibleTo(div_mode));
CHECK_EQ(lhs->scale, 1);
lhs.CopyOnWrite()->lower_factor *= cval;
return lhs;
......@@ -657,6 +722,7 @@ Mutate_(const Div* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
......@@ -671,16 +737,24 @@ Mutate_(const Div* op, const Expr& self) {
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
if (TryLinearEquation(psum, cval, &lhs, &extra)) {
SeparateDivisibleParts(psum, cval, &lhs, &extra);
// can be divided by cval
if (extra->IsZero()) {
lhs.CopyOnWrite()->DivideBy(cval);
return std::move(lhs);
}
// both lhs and extra are non-negative
if (parent_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
parent_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
// if extra <= cval, it means the extra can be eliminated.
// if 0 <= extra < cval, it means the extra can be eliminated.
if (TryCompare(temp, cval) != kLT) {
lhs.CopyOnWrite()->AddToSelf(
SplitDivConst(ToSplitExpr(temp), cval), 1);
SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
}
}
return std::move(lhs);
......@@ -692,7 +766,7 @@ Mutate_(const Div* op, const Expr& self) {
return make_zero(a.type());
}
}
return SplitDivConst(ToSplitExpr(std::move(a)), cval);
return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
}
// normal path
a = Normalize(a);
......@@ -704,8 +778,67 @@ Mutate_(const Div* op, const Expr& self) {
}
}
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<FloorDiv>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
if (cval == 1) return a;
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra);
if (extra->IsZero()) {
lhs.CopyOnWrite()->DivideBy(cval);
return std::move(lhs);
}
// continue simplification.
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (!(TryCompare(temp, cval) == kLT && parent_->CanProveGreaterEqual(temp, 0))) {
lhs.CopyOnWrite()->AddToSelf(
SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
}
}
return std::move(lhs);
} else {
// if a >= 0 && a < cval, then result == 0
auto cbound = parent_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return make_zero(a.type());
}
}
return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
}
// normal path
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
} else {
return FloorDiv::make(a, b);
}
}
SplitExpr CanonicalSimplifier::Impl::
SplitModConst(SplitExpr lhs, int64_t cval) {
SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
CHECK_GT(cval, 0);
lhs = ConvertDivMode(lhs, div_mode);
if (lhs->scale % cval == 0) {
lhs.CopyOnWrite()->scale = 0;
return lhs;
......@@ -718,9 +851,24 @@ SplitModConst(SplitExpr lhs, int64_t cval) {
// try to see if we can reduce the existing upper modular.
if (lhs->upper_factor == SplitExprNode::kPosInf ||
lhs->upper_factor % new_upper_factor == 0) {
lhs.CopyOnWrite()->upper_factor = new_upper_factor;
lhs->Verify();
return lhs;
// we gained a new upper factor that is smaller
// than the original one
// Perhaps there are more chances in simplifying the index
// Do a recursive call to simplify the mod with the new factor.
if (new_upper_factor < lhs->upper_factor &&
lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(Mutate(ModImpl(
lhs->index, make_const(lhs.type(), new_upper_factor), div_mode)));
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
return SplitDivConst(updated, lhs->lower_factor, div_mode);
} else {
return updated;
}
} else {
lhs.CopyOnWrite()->upper_factor = new_upper_factor;
return lhs;
}
} else if (new_upper_factor % lhs->upper_factor == 0) {
// (x % 2) % 4 => x % 2
return lhs;
......@@ -728,8 +876,10 @@ SplitModConst(SplitExpr lhs, int64_t cval) {
}
// Normalize the value.
lhs = ToSplitExpr(Normalize(lhs));
CHECK(lhs->DivModeCompatibleTo(div_mode));
CHECK_EQ(lhs->scale, 1);
CHECK_EQ(lhs->lower_factor, 1);
lhs.CopyOnWrite()->div_mode = div_mode;
lhs.CopyOnWrite()->upper_factor = cval;
return lhs;
}
......@@ -753,7 +903,13 @@ Mutate_(const Mod* op, const Expr& self) {
int64_t cval = c1.Eval()->value;
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
if (TryLinearEquation(psum, cval, &lhs, &extra)) {
SeparateDivisibleParts(psum, cval, &lhs, &extra);
if (extra->IsZero()) {
return make_zero(a.type());
}
// both lhs and extra are non-negative
if (parent_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
parent_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
return temp % c1.Eval();
......@@ -777,7 +933,7 @@ Mutate_(const Mod* op, const Expr& self) {
cbound->min_value - psum->base + new_base >= 0) {
SumExpr sum_expr(std::move(a.node_));
sum_expr.CopyOnWrite()->base = new_base;
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval);
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv);
}
} else {
// if a >= 0 && a < cval, then result == 0
......@@ -786,7 +942,7 @@ Mutate_(const Mod* op, const Expr& self) {
return a;
}
}
return SplitModConst(ToSplitExpr(std::move(a)), cval);
return SplitModConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
}
// normal path
a = Normalize(a);
......@@ -798,6 +954,66 @@ Mutate_(const Mod* op, const Expr& self) {
}
}
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
Expr a = this->CanonicalMutate(op->a);
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<FloorMod>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x % c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra);
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
return floormod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT &&
parent_->CanProveGreaterEqual(temp, 0)) {
return temp;
} else {
// contonue to use logic below.
a = extra;
psum = a.as<SumExprNode>();
CHECK(psum != nullptr);
}
}
// Simplify the offset constant if necessary.
// floormod(x - 5, 3) => floormod(x + 1, 3)
int64_t new_base = floormod(psum->base, cval);
SumExpr sum_expr(std::move(a.node_));
sum_expr.CopyOnWrite()->base = new_base;
return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv);
} else {
// if a >= 0 && a < cval, then result == a
auto cbound = parent_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
return a;
}
}
return SplitModConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
}
// normal path
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
return self;
} else {
return FloorMod::make(a, b);
}
}
// Simplify reduce expression.
Expr CanonicalSimplifier::Impl::
SimplifyReduceCombiner(const Reduce* op) {
......
......@@ -29,6 +29,7 @@
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <algorithm>
#include "int_operator.h"
namespace tvm {
namespace arith {
......@@ -184,9 +185,7 @@ template<>
inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
if (pa && pb) {
return IntImm::make(rtype, pa->value % pb->value);
}
if (pa) {
......@@ -201,6 +200,51 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
}
template<>
inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, std::floor(fa->value / fb->value));
}
if (fa && fa->value == 0) return a;
if (fb) {
if (fb->value == 1) return a;
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) {
return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
......
......@@ -24,7 +24,7 @@
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include "int_op_overflow.h"
#include "int_operator.h"
#include "pattern_match.h"
namespace tvm {
......@@ -215,6 +215,37 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitExpr_(const FloorDiv* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
}
Entry VisitExpr_(const FloorMod* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
// 0 <= [a_min, a_max] < b_min
if (a.max_value < b.min_value) return a;
// other case, we can get close to 0
return MakeBound(0, std::min(a.max_value, b_max_cap));
} else {
return MakeBound(0, b_max_cap);
}
} else {
CHECK(!b.is_const(0)) << "floormod by zero";
// mod by negative value is rare,
// and we just use the simpliest rule.
return Everything(op->type);
}
}
Entry VisitExpr_(const Min* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
......@@ -376,6 +407,20 @@ class ConstIntBoundAnalyzer::Impl :
return x / y;
}
/*!
* \brief Compute floodiv(x, y), aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
static int64_t InfAwareFloorDiv(int64_t x, int64_t y) {
CHECK_NE(y, 0);
if (x == kPosInf || x == kNegInf) {
if (y > 0) return x;
return -x;
}
return floordiv(x, y);
}
/*!
* \brief Compute x / y, aware of inf.
* \param x The left operand.
* \param y The right operand.
......
......@@ -19,11 +19,11 @@
/*!
* Copyright (c) 2019 by Contributors
* \file int_op_overflow.h
* \brief Utility functions to detect if an integer op will overflow.
* \file int_operator.h
* \brief Additional useful operators for integer.
*/
#ifndef TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#define TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#ifndef TVM_ARITHMETIC_INT_OPERATOR_H_
#define TVM_ARITHMETIC_INT_OPERATOR_H_
#include <limits>
......@@ -48,30 +48,30 @@ inline bool WillOverflow(int64_t x,
}
template<>
bool WillOverflow<ir::Add>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Add>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
if ((y > 0) && (x > max_value - y)) return true;
if ((y < 0) && (x < min_value - y)) return true;
return false;
}
template<>
bool WillOverflow<ir::Sub>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Sub>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
if ((y > 0) && (x < min_value + y)) return true;
if ((y < 0) && (x > max_value + y)) return true;
return false;
}
template<>
bool WillOverflow<ir::Mul>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Mul>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
if (y == 0) return false;
if (y > 0) {
if (x < min_value / y) return true;
......@@ -85,13 +85,42 @@ bool WillOverflow<ir::Mul>(int64_t x,
}
template<>
bool WillOverflow<ir::Mod>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Mod>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
return y == 0;
}
/*!
* \brief Peform floor division of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t floordiv(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x / y) : (x / y - 1);
}
/*!
* \brief Compute the floordiv remainder of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t floormod(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x % y) : (x % y + y);
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#endif // TVM_ARITHMETIC_INT_OPERATOR_H_
......@@ -252,6 +252,68 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
if (is_zero(b->min_value)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
Expr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
Expr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
Expr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
Expr e1 = floordiv(a->min_value, b->min_value);
Expr e2 = floordiv(a->max_value, b->min_value);
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Div";
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
const Expr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
return IntervalSet(make_zero(divisor.type()), divisor - 1);
} else {
Expr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Mod";
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
IntervalSet a,
......@@ -361,6 +423,14 @@ class IntervalSetEvaluator :
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const FloorDiv* op) final {
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const FloorMod* op) final {
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const Min* op) final {
return VisitBinaryExpr_(op);
}
......
......@@ -179,6 +179,14 @@ class ModularSetAnalyzer::Impl :
return Everything();
}
Entry VisitExpr_(const FloorDiv* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, true);
}
return Everything();
}
Entry VisitExpr_(const Min* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -326,6 +326,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
// logical expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GT);
......
......@@ -193,21 +193,25 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
TVM_TRY_REWRITE(y * x + z * x, x * (y + z));
// modular-div simplification
// Always pre-condition on positive integer domain
TVM_TRY_REWRITE_IF(
(x / c1) * c1 + x % c1, x,
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0);
// DivMod rules
// truc div
TVM_TRY_REWRITE((x / c1) * c1 + x % c1, x);
// floor div
TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));
TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);
// DivMod rules
// truc div
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));
// floor div
TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1));
}
// condition rules.
......@@ -308,8 +312,9 @@ Mutate_(const Sub* op, const Expr& self) {
TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s2,
CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
// modular-div simplification
// Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division
// DivMod rules
// trucdiv
// NOTE: c*(x/c) + x % c == x is true all division mode.
TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1),
......@@ -355,6 +360,47 @@ Mutate_(const Sub* op, const Expr& self) {
CanProveGreaterEqual(x.Eval(), 0) &&
c1.Eval()->value >= 0 &&
c3.Eval()->value > 0);
// floordiv
TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1),
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1),
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1),
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - floordiv(x - y, c1) * c1, floormod(x - y, c1) + y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
c3.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3),
floordiv(floormod(x, c3) + c1, c3),
c3.Eval()->value > 0);
// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_REWRITE(x - c1, x + (0 - c1));
......@@ -618,7 +664,6 @@ Mutate_(const Div* op, const Expr& self) {
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
......@@ -698,11 +743,223 @@ Mutate_(const Mod* op, const Expr& self) {
ModularSet mod = parent_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 &&
c1val > 0 &&
CanProveGreaterEqual(x.Eval(), 0)) {
return (mod->base % c1).Eval();
} else if (mod->coeff % c1val == 0 &&
mod->base % c1val == 0) {
return make_zero(ret.type());
}
}
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<FloorDiv>();
Expr const_res = TryConstFold<FloorDiv>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(floordiv(x, y), lanes));
// ramp // bcast
if (floordiv(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c1val % c2val == 0) {
return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
ModularSet bmod = parent_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
return broadcast(floordiv(b1, c2), lanes).Eval();
}
}
}
if (IsIndexType(op->type)) {
// Be-aware of the division rules: this is floor division.
TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2),
c1.Eval()->value > 0 && c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3),
c1.Eval()->value > 0 && c3.Eval()->value > 0);
if (floordiv(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c1val > 0 && c2val > 0) {
if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval();
if (c2val % c1val == 0) return (floordiv(x, floordiv(c2, c1))).Eval();
}
}
TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
TVM_TRY_REWRITE(floordiv(x * c1, x), c1);
TVM_TRY_REWRITE(floordiv(c1 * x, x), c1);
// Rules involving 2-operands.
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2),
x * floordiv(c1, c2) + floordiv(y, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2),
min(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2),
max(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2),
floordiv(y, c2) + x * floordiv(c1, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2),
min(floordiv(y, c2), x * floordiv(c1, c2)),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2),
max(floordiv(y, c2), x * floordiv(c1, c2)),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
// Rules involving 3-operands.
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2),
x * floordiv(c1, c2) + floordiv(y + z, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2),
x * floordiv(c1, c2) + floordiv(z - y, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2),
x * floordiv(c1, c2) + floordiv(y - z, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2),
x * floordiv(c1, c2) + floordiv(y + z, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(x + c1, c2),
floordiv(x, c2) + floordiv(c1, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv((y + x) + z, x), floordiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y + (z + x), x), floordiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(x * y, y), x,
CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y * x, y), x,
CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z),
CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(z * x + y, z), x + floordiv(y, z),
CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y + x * z, z), floordiv(y, z) + x,
CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x,
CanProveGreaterEqual(z.Eval(), 0));
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<FloorMod>();
Expr const_res = TryConstFold<FloorMod>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(floormod(x, y), lanes));
// floormod(ramp, bcast)
if (floormod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c1val % c2val == 0) {
return broadcast(floormod(b1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
ModularSet bmod = parent_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0) {
if (ramp_min == ramp_max) {
return ramp(floormod(bmod->base, c2), c1, lanes).Eval();
} else {
return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
}
}
}
}
if (IsIndexType(op->type)) {
// Be-aware of the division rules: we use floordiv/floormod here
TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x),
c2.Eval()->value != 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0);
// try modular analysis
if (floormod(x, c1).Match(ret)) {
ModularSet mod = parent_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && c1val > 0) {
return floormod(mod->base, c1).Eval();
}
}
}
......@@ -766,7 +1023,9 @@ Mutate_(const Min* op, const Expr& self) {
}
}
// Divide up rounding
// DivMod rules
// Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
......@@ -783,6 +1042,26 @@ Mutate_(const Min* op, const Expr& self) {
c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0));
// Divide up rounding: floor div
TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2,
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2,
c2.Eval()->value > 0);
TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y));
TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y));
TVM_TRY_REWRITE(min(min(x, y), max(x, y)), min(x, y));
......@@ -833,6 +1112,13 @@ Mutate_(const Min* op, const Expr& self) {
return (max(x, y) / c1).Eval();
}
}
if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
if (c1.Eval()->value > 0) {
return floordiv(min(x, y), c1).Eval();
} else {
return floordiv(max(x, y), c1).Eval();
}
}
if (min(x * c1, y * c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (min(x, y) * c1).Eval();
......@@ -922,7 +1208,9 @@ Mutate_(const Max* op, const Expr& self) {
}
}
// Divide up rounding
// DivMod rules
// Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
......@@ -930,6 +1218,19 @@ Mutate_(const Max* op, const Expr& self) {
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
// Divide up rounding: floor div
TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x,
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x,
c2.Eval()->value > 0);
TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y));
TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y));
TVM_TRY_REWRITE(max(max(x, y), min(x, y)), max(x, y));
......@@ -983,6 +1284,13 @@ Mutate_(const Max* op, const Expr& self) {
return (min(x, y) / c1).Eval();
}
}
if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
if (c1.Eval()->value > 0) {
return floordiv(max(x, y), c1).Eval();
} else {
return floordiv(min(x, y), c1).Eval();
}
}
if (max(x * c1, y * c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (max(x, y) * c1).Eval();
......@@ -1116,6 +1424,8 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
c1.Eval()->value < 0);
// constant cancelation: only need to make use of one mod
// truc div
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
......@@ -1131,7 +1441,6 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x,
c1.Eval()->value <= 0 &&
c2.Eval()->value < 0);
// NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x,
c1.Eval()->value < 0 &&
......@@ -1147,7 +1456,8 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2,
c1.Eval()->value >= 0 &&
c2.Eval()->value < 0);
// DivMod rules
// trucdiv
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
......@@ -1164,7 +1474,6 @@ Mutate_(const LT* op, const Expr& self) {
c1.Eval()->value < 0 &&
c2.Eval()->value > 0);
// division related simplificationx
// invariance for any div mod: x - (x / c1) * c1 == x % c1
TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1,
c1.Eval()->value > 0);
......@@ -1173,16 +1482,38 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x,
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x,
c2 < (x + c2) % c1,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x + y,
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x + y,
c2 < (x + c2) % c1 + y,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x - y,
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x - y,
y < (x + c2) % c1 + (0 - c2),
c1.Eval()->value > 0);
// floordiv
TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x,
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1),
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, 0 < floormod(x, c1) + y,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, y < floormod(x, c1),
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x,
c2 < floormod(x + c2, c1),
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y,
c2 < floormod(x + c2, c1) + y,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y,
y < floormod(x + c2, c1) + (0 - c2),
c1.Eval()->value > 0);
// canonicalization rule
TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z);
TVM_TRY_RECURSIVE_REWRITE(max(x, y) < z, x < z && y < z);
......
......@@ -53,6 +53,8 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const FloorDiv* op, const Expr& self) override;
Expr Mutate_(const FloorMod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override;
......
......@@ -188,6 +188,19 @@ Expr operator%(Expr a, Expr b) {
return ir::Mod::make(a, b);
}
Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
if (ret.defined()) return ret;
return ir::FloorDiv::make(a, b);
}
Expr floormod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
if (ret.defined()) return ret;
return ir::FloorMod::make(a, b);
}
Expr min(Expr a, Expr b) {
// inf-aware simplificaiton
......
......@@ -59,14 +59,12 @@ Expr StringImm::make(std::string value) {
Expr Cast::make(DataType t, Expr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.type().lanes());
NodePtr<Cast> node = make_node<Cast>();
node->type = t;
node->value = std::move(value);
return Expr(node);
}
Expr And::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
......@@ -700,6 +698,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorDiv>([](const FloorDiv* op, IRPrinter *p) {
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorMod>([](const FloorMod* op, IRPrinter *p) {
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<And>([](const And* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
......@@ -1098,7 +1106,6 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm);
......@@ -1110,6 +1117,8 @@ TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(FloorDiv);
TVM_REGISTER_NODE_TYPE(FloorMod);
TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ);
......
......@@ -302,6 +302,8 @@ class IRDeepCompare :
DEFINE_BIOP_EXPR_CMP_(Mul)
DEFINE_BIOP_EXPR_CMP_(Div)
DEFINE_BIOP_EXPR_CMP_(Mod)
DEFINE_BIOP_EXPR_CMP_(FloorDiv)
DEFINE_BIOP_EXPR_CMP_(FloorMod)
DEFINE_BIOP_EXPR_CMP_(Min)
DEFINE_BIOP_EXPR_CMP_(Max)
DEFINE_BIOP_EXPR_CMP_(EQ)
......
......@@ -401,6 +401,8 @@ DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(FloorDiv)
DEFINE_BIOP_EXPR_MUTATE_(FloorMod)
DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max)
DEFINE_BIOP_EXPR_MUTATE_(EQ)
......@@ -506,6 +508,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Mul)
.DISPATCH_TO_MUTATE_EXPR(Div)
.DISPATCH_TO_MUTATE_EXPR(Mod)
.DISPATCH_TO_MUTATE_EXPR(FloorDiv)
.DISPATCH_TO_MUTATE_EXPR(FloorMod)
.DISPATCH_TO_MUTATE_EXPR(Min)
.DISPATCH_TO_MUTATE_EXPR(Max)
.DISPATCH_TO_MUTATE_EXPR(EQ)
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -139,6 +139,8 @@ DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(FloorDiv)
DEFINE_BINOP_VISIT_(FloorMod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
......@@ -250,6 +252,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(FloorDiv)
.DISPATCH_TO_VISIT(FloorMod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
......
......@@ -151,6 +151,12 @@ class Vectorizer : public IRMutator {
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const FloorDiv* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const FloorMod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
}
......
......@@ -28,21 +28,32 @@ class CanonicalChecker:
def test_mul_sum_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(2 + (3 * x + z + y + 1) * 4 + x,
x * 13 + z * 4 + y * 4 +6)
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
# trucdiv
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
# floordiv
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2))
ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x)
ck.verify(flm(x + y + x + y * 3, 2), 0)
ck.verify(fld(x + x + y * 3, 2), fld(y * 3, 2) + x)
def test_split_index_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
# trucdiv
# split div const
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
ck.verify((x % 2) / 8, 0)
ck.verify((x % 2) / 7, 0)
......@@ -59,11 +70,29 @@ def test_split_index_simplify():
# complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
# floordiv
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 4), fld(flm(x, 16), 4))
ck.verify(fld(flm(x, 2), 8), 0)
ck.verify(fld(flm(x, 2), 7), 0)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
# cannot simplify mixed case, unless we canonicalize into one mode.
ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2)
def test_div_simplify():
ck = CanonicalChecker()
x = tvm.var("x")
# truc div
ck.verify((16+48*x)/16, x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
......@@ -74,6 +103,22 @@ def test_div_simplify():
# Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
# floordiv
fld = tvm.floordiv
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
ck.verify(fld(16+48*x, 16), x*3 + 1)
ck.verify(fld(17+48*x, 16), x * 3 + 1)
ck.verify(fld(17+47*x, 16), fld(x * 47 + 17, 16))
def test_floormod_simplify():
ck = CanonicalChecker()
flm = tvm.floormod
x, y = tvm.var("x"), tvm.var("y")
ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512, 16),
flm((x*4) + y + 12, 16))
def test_canonical_mixed():
ck = CanonicalChecker()
......@@ -86,6 +131,10 @@ def test_canonical_mixed():
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
ck.verify(x * x - x * x, 0)
fld = tvm.floordiv
ck.verify(fld(x, (z*z)) - fld(x, (z*z)), 0)
ck.verify(fld(x, (z+z)) - fld(x, (z+z)), 0)
def test_reduce_combiner_simplify():
ck = CanonicalChecker()
......@@ -218,11 +267,13 @@ def test_complex_cases():
if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
test_simplify_if_then_else()
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()
test_mul_sum_simplify()
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
......@@ -143,6 +143,52 @@ def test_mod_bound():
assert bd.max_value == 9
def test_floordiv_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
fld = tvm.floordiv
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -9 // 4
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -4
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF
def test_floormod_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
flm = tvm.floormod
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
def test_min_max_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
......@@ -229,6 +275,8 @@ if __name__ == "__main__":
test_mul_bound()
test_div_bound()
test_mod_bound()
test_floordiv_bound()
test_floormod_bound()
test_min_max_bound()
test_select_bound()
test_shift_and_bound()
......
......@@ -64,9 +64,14 @@ def test_mul_div():
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.floordiv
ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
ck.verify(fld(x, 2), {x : tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
def test_mod():
ck = IntSetChecker()
......@@ -75,6 +80,10 @@ def test_mod():
ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
def test_max_min():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
......@@ -99,4 +108,3 @@ if __name__ == "__main__":
test_max_min()
test_select()
test_mod()
......@@ -61,6 +61,10 @@ def test_div_shift():
m = analyzer.modular_set((x * 4 + 2) >> 1)
assert m.coeff == 2
assert m.base == 1
fld = tvm.floordiv
m = analyzer.modular_set(fld(x * 4 + 2, 2))
assert m.coeff == 2
assert m.base == 1
# x is non-negative
analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
m = analyzer.modular_set((x * 4 + 2) / 2)
......
......@@ -55,7 +55,8 @@ def test_vector_simplify():
ck.verify(2 * tvm.expr.Ramp(x, 4, 4),
tvm.expr.Ramp(x * 2, 8, 4))
## Div rules
## DivMod rules
# truc div
ck.verify(y.astype("int32x2") / x.astype("int32x2"),
(y / x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) / 2,
......@@ -65,18 +66,36 @@ def test_vector_simplify():
(x).astype("int32x4"))
ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8,
tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8)
## Mod rules
ck.verify(y.astype("int32x2") % x.astype("int32x2"),
(y % x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) % 2,
tvm.expr.Broadcast(x % 2, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8,
tvm.expr.Ramp(1, 1, 4))
ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
tvm.expr.Ramp(1, 15, 4) % 8)
# floor div
fld = tvm.floordiv
flm = tvm.floormod
ck.analyzer.update(x, tvm.arith.ConstIntBound(-10, 1000), override=True)
ck.verify(fld(y.astype("int32x2"), x.astype("int32x2")),
fld(y, x).astype("int32x2"))
ck.verify(fld(tvm.expr.Ramp(x, 4, 4), 2),
tvm.expr.Ramp(fld(x, 2), 2, 4))
ck.verify(fld(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
(x).astype("int32x4"))
ck.verify(fld(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8),
fld(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8))
ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")),
flm(y, x).astype("int32x2"))
ck.verify(flm(tvm.expr.Ramp(x, 4, 4), 2),
tvm.expr.Broadcast(flm(x, 2), 4))
ck.verify(flm(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
tvm.expr.Ramp(1, 1, 4))
ck.verify(flm(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8),
flm(tvm.expr.Ramp(1, 15, 4), 8))
# Min/Max rules
vx = tvm.var("vx", dtype="int32x2")
vc = tvm.var("vc", dtype="uint1")
......@@ -162,21 +181,23 @@ def test_add_index_simplify():
ck.verify(y * x + 10 * x, x * (y + 10))
ck.verify(x * y + 10 * x, x * (y + 10))
ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify((x / 8) * 8 + x % 8, x)
# canonicalization
ck.verify(x + 2 + 3 + 4 + x, x * 2 + 9);
ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9);
# conservative bound
try:
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
ck.verify((x / 8) * 8 + x % 8, x)
raise RuntimeError("bad")
except AssertionError:
pass
# DivMod rules
# truc div
ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
ck.verify((x / 8) * 8 + x % 8, x)
# floor div
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10))
ck.verify(fld(x, 8) * 8 + flm(x, 8), x)
def test_sub_index_simplify():
......@@ -233,7 +254,8 @@ def test_sub_index_simplify():
ck.verify(tvm.min(x, y) - tvm.min(x + 10, y + 10), -10)
ck.verify(tvm.min(x + 10, y + 1) - tvm.min(x, y - 9), 10)
# div pattern
# DivMod patterns
# truc div
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x - (x / 3) * 3, x % 3)
......@@ -260,6 +282,33 @@ def test_sub_index_simplify():
ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)
# floor div
fld = tvm.floordiv
flm = tvm.floormod
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-1000, 1000), override=True)
ck.verify(x - fld(x, 3) * 3, flm(x, 3))
ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3))
ck.verify(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1)
ck.verify(fld(y, 3) * 3 - y, 0 - flm(y, 3))
ck.verify(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6)
ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5))
ck.verify(y - fld(y + z, 5) * 5, flm(y + z, 5) - z)
ck.verify(fld(y + z, 5) * 5 - y, z - flm(y + z, 5))
ck.verify(y - fld(y - z, 5) * 5, flm(y - z, 5) + z)
ck.verify(fld(y - z, 5) * 5 - y, 0 - flm(y - z, 5) - z)
ck.verify(y * 3 - fld(y, 2) * 6, flm(y, 2) * 3)
ck.verify(fld(y, 3) * 6 - y * 2, flm(y, 3) * (-2))
ck.verify(y * 5 - fld(y + z, 2) * 10, (flm(y + z, 2) - z) * 5)
ck.verify(y * 5 - fld(y - z, 2) * 10, (flm(y - z, 2) + z) * 5)
ck.verify(fld(y + z, 3) * 6 - y * 2, (z - flm(y + z, 3)) * 2)
ck.verify(fld(y - z, 3) * 6 - y * 2, (0 - flm(y - z, 3) - z) * 2)
ck.verify(5 * y - fld(y + z, 2) * 10, (flm(y + z, 2) - z) * 5)
ck.verify(5 * y - 10 * fld(y - z, 2), (flm(y - z, 2) + z) * 5)
ck.verify(6 * fld(y + z, 3) - y * 2, (z - flm(y + z, 3)) * 2)
ck.verify(fld(y - z, 3) * 6 - 2 * y, (0 - flm(y - z, 3) - z) * 2)
def test_mul_index_simplify():
ck = RewriteChecker()
......@@ -316,6 +365,50 @@ def test_div_index_simplify():
ck.verify((y + z * x) / z, y / z + x)
def test_floordiv_index_simplify():
# short name for floordiv
fld = tvm.floordiv
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(fld(fld(x, 2), 3), fld(x, 6))
ck.verify(fld(fld(x, 2) + 1, 3), fld(x + 2, 6))
ck.verify(fld(x * 2, 4), fld(x, 2))
ck.verify(fld(x * 4, 2), x * 2)
ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2))
ck.verify(fld(tvm.min(x * 6, y), 2), tvm.min(x * 3, fld(y, 2)))
ck.verify(fld(tvm.max(x * 6, y), 2), tvm.max(x * 3, fld(y, 2)))
ck.verify(fld(y + x * 4, 2), fld(y, 2) + x * 2)
ck.verify(fld(tvm.min(y, x * 6), 2), tvm.min(fld(y, 2), x * 3))
ck.verify(fld(tvm.max(y, x * 6), 2), tvm.max(fld(y, 2), x * 3))
# 3-operands
ck.verify(fld(x * 6 + y + z, 2), x * 3 + fld(y + z, 2))
ck.verify(fld(x * 6 - y + (y + 3), 2), x * 3 + 1)
ck.verify(fld(x * 6 + (y + 3) - y, 2), x * 3 + 1)
ck.verify(fld(y + x * 6 + z, 2), x * 3 + fld(y + z, 2))
ck.verify(fld(x + 4, 2), fld(x, 2) + 2)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(fld(x + y, x), fld(y, x) + 1)
ck.verify(fld(y + x, x), fld(y, x) + 1)
ck.verify(fld((x + y) + z, x), fld(y + z, x) + 1)
ck.verify(fld((y + x) + z, x), fld(y + z, x) + 1)
ck.verify(fld(y + (x + z), x), fld(y + z, x) + 1)
ck.verify(fld(y + (z + x), x), fld(y + z, x) + 1)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(fld(x * y, y), x)
ck.verify(fld(y * x, y), x)
ck.verify(fld(x * z + y, z), x + fld(y, z))
ck.verify(fld(z * x + y, z), x + fld(y, z))
ck.verify(fld(y + x * z, z), fld(y, z) + x)
ck.verify(fld(y + z * x, z), fld(y, z) + x)
def test_mod_index_simplify():
ck = RewriteChecker()
......@@ -350,11 +443,34 @@ def test_mod_index_simplify():
ck.verify((nx * (-10) + y) % -2, y % 2)
ck.verify((x + ny * (-10)) % -2, x % 2)
def test_floormod_index_simplify():
# short name for floordiv
flm = tvm.floormod
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck = RewriteChecker()
x, y, nx, ny, z = tvm.var("x"), tvm.var("y"), tvm.var("nx"), tvm.var("ny"), tvm.var("z")
ck.verify(flm(x * 10, 2), 0)
ck.verify(flm(x * 10 + y, 2), flm(y, 2))
ck.verify(flm(x + 10, 2), flm(x, 2))
ck.verify(flm(x + y * 10, 2), flm(x, 2))
ck.verify(flm(x* 10 + 1 + y * 2 + 2, 2), 1)
ck.verify(flm(x * (-10), 2), 0)
ck.verify(flm(x * (-10) + y, 2), flm(y, 2))
ck.verify(flm(x + (-10), 2), flm(x, 2))
ck.verify(flm(x + y * (-10), 2), flm(x, 2))
def test_min_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
fld = tvm.floordiv
flm = tvm.floormod
# const int bound
ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2)
ck.verify(tvm.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2))
ck.verify(tvm.min(x + 1, x + 10), x + 1)
ck.verify(tvm.min(x + 111, x + 10), x + 10)
......@@ -363,13 +479,6 @@ def test_min_index_simplify():
ck.verify(tvm.min(1 - x, 2 - x), 1 - x)
ck.verify(tvm.min(3 - x, 2 - x), 2 - x)
ck.verify(tvm.min((x + 3) / 4 * 4, x), x)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4))
ck.verify(tvm.min(x, (x + 3) / 4 * 4), x)
ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.min(tvm.max(x, y), tvm.min(x, y)), tvm.min(x, y))
ck.verify(tvm.min(tvm.max(x, y), tvm.min(y, x)), tvm.min(x, y))
......@@ -406,17 +515,39 @@ def test_min_index_simplify():
ck.verify(tvm.min(tvm.min(x, 1), 10), tvm.min(x, 1))
ck.verify(tvm.min(tvm.min(x, 11), 10), tvm.min(x, 10))
ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
ck.verify(tvm.min(3 - x, 2), 3 - tvm.max(x, 1))
# DivMod rules
# truc div
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.verify(tvm.min((x + 3) / 4 * 4, x), x)
ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4))
ck.verify(tvm.min(x, (x + 3) / 4 * 4), x)
ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
# floor div
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.min(fld(x + 3, 4) * 4, x), x)
ck.verify(tvm.min(fld(x + 3, 4) * 4, tvm.max(x, 4)), tvm.max(x, 4))
ck.verify(tvm.min(x, fld(x + 3, 4) * 4), x)
ck.verify(tvm.min(x, fld(x, 4) * 4), fld(x, 4) * 4)
ck.verify(tvm.min(tvm.max(x, 4), fld(x + 3, 4) * 4), tvm.max(x, 4))
ck.verify(tvm.min(fld(x, 10), fld(y, 10)), fld(tvm.min(x, y), 10))
ck.verify(tvm.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.max(x, y), (-10)))
def test_max_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
flm = tvm.floormod
fld = tvm.floordiv
# const int bound
ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10)
ck.verify(tvm.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10)
ck.verify(tvm.max(x + 1, x + 10), x + 10)
ck.verify(tvm.max(x + 111, x + 10), x + 111)
......@@ -425,8 +556,6 @@ def test_max_index_simplify():
ck.verify(tvm.max(1 - x, 2 - x), 2 - x)
ck.verify(tvm.max(3 - x, 2 - x), 3 - x)
ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4)
ck.verify(tvm.max(tvm.min(x, y), tvm.max(x, y)), tvm.max(x, y))
ck.verify(tvm.max(tvm.min(x, y), tvm.max(y, x)), tvm.max(x, y))
......@@ -463,20 +592,36 @@ def test_max_index_simplify():
ck.verify(tvm.max(tvm.max(x, 1), 10), tvm.max(x, 10))
ck.verify(tvm.max(tvm.max(x, 11), 10), tvm.max(x, 11))
ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
ck.verify(tvm.max(3 - x, 1), 3 - tvm.min(x, 2))
# DivMod rules
# truc div
ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4)
# floordiv
ck.verify(tvm.max(fld(x, 10), fld(y, 10)), fld(tvm.max(x, y), 10))
ck.verify(tvm.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.min(x, y), (-10)))
ck.verify(tvm.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4)
ck.verify(tvm.max(fld(x, 4) * 4, x), x)
ck.verify(tvm.max(x, fld(x, 4) * 4), x)
def test_cmp_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
flm = tvm.floormod
fld = tvm.floordiv
# const int bound
ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool"))
ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool"))
ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool"))
ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool"))
ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool"))
ck.verify(flm(x, 2) + 10 <= 1, tvm.const(0, "bool"))
ck.verify(x * 3 + 10 == 0, tvm.const(0, "bool"))
ck.verify(x * 3 + 10 != 0, tvm.const(1, "bool"))
......@@ -504,10 +649,7 @@ def test_cmp_simplify():
ck.verify(x * 4 >= 2, tvm.expr.LE(1, x))
ck.verify(x * 2 >= 50, tvm.expr.LE(25, x))
ck.verify(x / 2 < 3, x < 6)
ck.verify(x * 4 <= 2, x <= 0)
ck.verify(3 < x / 2, tvm.expr.LT(7, x))
ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x))
ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x))
ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0))
ck.verify(2 * x <= 0, x <= 0)
......@@ -544,6 +686,11 @@ def test_cmp_simplify():
ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x))
ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x))
# DivMod rules
# truc div
ck.verify(x / 2 < 3, x < 6)
ck.verify(3 < x / 2, tvm.expr.LT(7, x))
ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x))
ck.verify(x / 2 >= 1, tvm.expr.LE(2, x))
ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x))
ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x))
......@@ -562,6 +709,29 @@ def test_cmp_simplify():
ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2))
ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y))
# floor div
ck.verify(fld(x, 2) < 3, x < 6)
ck.verify(3 < fld(x, 2), tvm.expr.LT(7, x))
ck.verify(-3 < fld(x, 2), tvm.expr.LT(-5, x))
ck.verify(fld(x, 3) >= 0, tvm.expr.LE(0, x))
ck.verify(fld(x, 2) >= 1, tvm.expr.LE(2, x))
ck.verify(fld(x, 2) >= 0, tvm.expr.LE(0, x))
ck.verify(fld(x, 2) >= -1, tvm.expr.LE(-2, x))
ck.verify(fld(x, 2) <= 1, tvm.expr.LE(x, 3))
ck.verify(fld(x, 2) <= 0, tvm.expr.LE(x, 1))
ck.verify(fld(x, 2) <= -1, tvm.expr.LE(x, -1))
ck.verify(fld(x, 4) * 4 < x, tvm.expr.LT(0, flm(x, 4)))
ck.verify(fld(x, 4) * 4 >= x, tvm.expr.LE(flm(x, 4), 0))
ck.verify(fld(x, 4) * 4 < x + y, tvm.expr.LT(0, flm(x, 4) + y))
ck.verify(fld(x, 4) * 4 < x - y, tvm.expr.LT(y, flm(x, 4)))
ck.verify(fld(x + 2, 4) * 4 >= x, tvm.expr.LE(flm(x + 2, 4), 2))
ck.verify(fld(x + 2, 4) * 4 >= x + y, tvm.expr.LE(flm(x + 2, 4) + y, 2))
ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.expr.LE(flm(x + 2, 4) + (-2), y))
# End DivMod Rules
ck.verify(tvm.min(x, 11) < 10, x < 10)
ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool"))
......@@ -630,6 +800,8 @@ def test_logical_simplify():
if __name__ == "__main__":
test_floordiv_index_simplify()
test_floormod_index_simplify()
test_cmp_simplify()
test_vector_simplify()
test_add_index_simplify()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment