Unverified Commit d1830964 by Tianqi Chen Committed by GitHub

[ARITH] Explicitly state truncdiv/mod in pattern matching. (#3986)

* [ARITH] Explicitly state truncdiv/mod in pattern matching.

* Fix the dependent cpp test
parent c48e1cc1
...@@ -333,6 +333,20 @@ TVM_DLL Expr operator||(Expr a, Expr b); ...@@ -333,6 +333,20 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/ */
TVM_DLL Expr operator!(Expr a); TVM_DLL Expr operator!(Expr a);
/*! /*!
* \brief compute division in C semantics.
*
* a / b as in C/C++.
*
* When operands are integers, it directly corresponds to truncdiv.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr div(Expr a, Expr b);
/*!
* \brief compute trunc(a / b) * \brief compute trunc(a / b)
* *
* This is the default integer division behavior in C. * This is the default integer division behavior in C.
...@@ -640,6 +654,21 @@ inline Expr make_zero(Type t) { ...@@ -640,6 +654,21 @@ inline Expr make_zero(Type t) {
return make_const(t, 0); return make_const(t, 0);
} }
/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always results in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}
// additional const expression overloading // additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \ inline Expr Name(Expr& a, Expr b) { \
...@@ -688,12 +717,17 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*); ...@@ -688,12 +717,17 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*) TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*) TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=); TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops // integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
......
...@@ -67,7 +67,7 @@ enum DivMode { ...@@ -67,7 +67,7 @@ enum DivMode {
inline Expr ModImpl(Expr a, Expr b, DivMode mode) { inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) { if (mode == kTruncDiv) {
return a % b; return truncmod(a, b);
} else { } else {
CHECK_EQ(mode, kFloorDiv); CHECK_EQ(mode, kFloorDiv);
return floormod(a, b); return floormod(a, b);
...@@ -76,7 +76,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) { ...@@ -76,7 +76,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
inline Expr DivImpl(Expr a, Expr b, DivMode mode) { inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) { if (mode == kTruncDiv) {
return a / b; return truncdiv(a, b);
} else { } else {
CHECK_EQ(mode, kFloorDiv); CHECK_EQ(mode, kFloorDiv);
return floordiv(a, b); return floordiv(a, b);
......
...@@ -93,6 +93,26 @@ inline bool WillOverflow<ir::Mod>(int64_t x, ...@@ -93,6 +93,26 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
} }
/*! /*!
* \brief Peform trunc division of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t truncdiv(int64_t x, int64_t y) {
return x / y;
}
/*!
* \brief Compute the truncdiv remainder of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t truncmod(int64_t x, int64_t y) {
return x % y;
}
/*!
* \brief Peform floor division of two integers. * \brief Peform floor division of two integers.
* \param x The left operand. * \param x The left operand.
* \param y The right operand. * \param y The right operand.
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file modular_set.cc * \file modular_set.cc
* \brief Modular set analysis * \brief Modular set analysis
*/ */
...@@ -111,7 +110,8 @@ class ModularSetAnalyzer::Impl : ...@@ -111,7 +110,8 @@ class ModularSetAnalyzer::Impl :
PVar<Var> var; PVar<Var> var;
PVar<Integer> coeff, base; PVar<Integer> coeff, base;
// pattern match interesting constraints // pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) { if ((truncmod(var, coeff) == base).Match(constraint) ||
(floormod(var, coeff) == base).Match(constraint)) {
Entry entry(coeff.Eval()->value, base.Eval()->value); Entry entry(coeff.Eval()->value, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry); return UpdateByIntersect(var.Eval(), entry);
} }
......
...@@ -300,31 +300,41 @@ class PConstWithTypeLike : ...@@ -300,31 +300,41 @@ class PConstWithTypeLike :
}; };
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ #define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \
template<typename TA, typename TB> \ template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \ inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \ FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
CheckStep; \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \ return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \ } \
template<typename TA> \ template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \ inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \ FuncName(const Pattern<TA>& a, int64_t b) { \
CheckStep; \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \ return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \ } \
template<typename TA> \ template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \ inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \ FuncName(int64_t b, const Pattern<TA>& a) { \
CheckStep; \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \ return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
} }
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
// raise ambiguity error for operator overload of / and %
TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
// arithmetic expressions // arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add); TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub); TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul); TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min); TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max); TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(div, ir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div); TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod); TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv); TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
......
...@@ -194,7 +194,7 @@ Mutate_(const Add* op, const Expr& self) { ...@@ -194,7 +194,7 @@ Mutate_(const Add* op, const Expr& self) {
// DivMod rules // DivMod rules
// truc div // truc div
TVM_TRY_REWRITE((x / c1) * c1 + x % c1, x); TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
// floor div // floor div
TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
...@@ -208,7 +208,7 @@ Mutate_(const Add* op, const Expr& self) { ...@@ -208,7 +208,7 @@ Mutate_(const Add* op, const Expr& self) {
// DivMod rules // DivMod rules
// truc div // truc div
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1)); TVM_TRY_RECURSIVE_REWRITE(truncmod(y, c1) + x * c1, x * c1 + truncmod(y, c1));
// floor div // floor div
TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1)); TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1));
} }
...@@ -314,48 +314,49 @@ Mutate_(const Sub* op, const Expr& self) { ...@@ -314,48 +314,49 @@ Mutate_(const Sub* op, const Expr& self) {
// DivMod rules // DivMod rules
// trucdiv // trucdiv
// NOTE: c*(x/c) + x % c == x is true all division mode. // NOTE: c*(x/c) + x % c == x is true all division mode.
TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1, TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1),
c1.Eval()->value != 0); c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1), TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1),
c1.Eval()->value != 0); c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y, TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y,
c1.Eval()->value != 0); c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1), TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1),
c1.Eval()->value != 0); c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y, TVM_TRY_REWRITE_IF(x - truncdiv(x - y, c1) * c1, truncmod(x - y, c1) + y,
c1.Eval()->value != 0); c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y, TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y,
c1.Eval()->value != 0); c1.Eval()->value != 0);
TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2, TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2, TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2, TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2, TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2, TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2, TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
// Proof in the case of floordiv, need positive condition. // Proof in the case of floordiv, need positive condition.
// let x = a * c3 + r // let x = a * c3 + r
// (x + c1) / c3 - x / c3 => (r + c1) / c3 // (x + c1) / c3 - x / c3 => (r + c1) / c3
TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, // NOTE: the use of floormod(c2, c3) was intentional to simplify the const.
((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3, TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3),
truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
c1.Eval()->value >= c2.Eval()->value && c1.Eval()->value >= c2.Eval()->value &&
c3.Eval()->value > 0); c3.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3, TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x, c3),
(x % c3 + c1) / c3, truncdiv(truncmod(x, c3) + c1, c3),
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c3.Eval()->value > 0); c3.Eval()->value > 0);
...@@ -478,14 +479,15 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -478,14 +479,15 @@ Mutate_(const Div* op, const Expr& self) {
// Vector rules // Vector rules
if (op->type.lanes() != 1) { if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes), // NOTE: use div as the pattern also works for float.
broadcast(x / y, lanes)); TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(div(x, y), lanes));
// ramp / bcast // ramp / bcast
if ((ramp(b1, c1, lanes) / broadcast(c2, lanes)).Match(ret)) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) {
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value; int64_t c2val = c2.Eval()->value;
if (c1val % c2val == 0) { if (c1val % c2val == 0) {
return ramp(b1 / c2, c1 / c2, lanes).Eval(); return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
} }
// If all possible indices in ramp are the same. // If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) { if (CanProveGreaterEqual(b1.Eval(), 0)) {
...@@ -493,7 +495,7 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -493,7 +495,7 @@ Mutate_(const Div* op, const Expr& self) {
int64_t ramp_min = bmod->base / c2val; int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
return broadcast(b1 / c2, lanes).Eval(); return broadcast(div(b1, c2), lanes).Eval();
} }
} }
} }
...@@ -508,73 +510,79 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -508,73 +510,79 @@ Mutate_(const Div* op, const Expr& self) {
// parts of tvm which still assume euclidean div. In this simplifier we assume that the division // parts of tvm which still assume euclidean div. In this simplifier we assume that the division
// is truncated, so perform const folding again. // is truncated, so perform const folding again.
// NOTE: trunc div required // NOTE: trunc div required
if ((c1 / c2).Match(ret)) { if (truncdiv(c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value; int64_t c2val = c2.Eval()->value;
return make_const(op->type, c1val / c2val); return make_const(op->type, truncdiv(c1val, c2val));
} }
// while it is always true for trunc div // while it is always true for trunc div
// restrict to common case(positive div) // restrict to common case(positive div)
TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2), TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2),
c1.Eval()->value > 0 && c2.Eval()->value > 0); c1.Eval()->value > 0 && c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x / c1 + c2) / c3, (x + c1 * c2) / (c1 * c3), TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3),
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value >= 0 && c2.Eval()->value >= 0 &&
c3.Eval()->value > 0 && c3.Eval()->value > 0 &&
CanProveGreaterEqual(x.Eval(), 0)); CanProveGreaterEqual(x.Eval(), 0));
if (((x * c1) / c2).Match(ret)) { if (truncdiv(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value; int64_t c2val = c2.Eval()->value;
if (c1val > 0 && c2val > 0) { if (c1val > 0 && c2val > 0) {
if (c1val % c2val == 0) return (x * (c1 / c2)).Eval(); if (c1val % c2val == 0) return (x * truncdiv(c1, c2)).Eval();
if (c2val % c1val == 0) return (x / (c2 / c1)).Eval(); if (c2val % c1val == 0) return truncdiv(x, truncdiv(c2, c1)).Eval();
} }
} }
TVM_TRY_REWRITE(x / x, OneWithTypeLike(x)); TVM_TRY_REWRITE(truncdiv(x, x), OneWithTypeLike(x));
TVM_TRY_REWRITE(x * c1 / x, c1); TVM_TRY_REWRITE(truncdiv(x * c1, x), c1);
TVM_TRY_REWRITE(c1 * x / x, c1); TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1);
// Rules involving 2-operands. // Rules involving 2-operands.
TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2, TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2),
x * truncdiv(c1, c2) + truncdiv(y, c2),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(min(x * c1, y) / c2, min(x * (c1 / c2), y / c2), TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2),
min(x * truncdiv(c1, c2), truncdiv(y, c2)),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(max(x * c1, y) / c2, max(x * (c1 / c2), y / c2), TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2),
max(x * truncdiv(c1, c2), truncdiv(y, c2)),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF((y + x * c1) / c2, y / c2 + x * (c1 / c2), TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2),
truncdiv(y, c2) + x * truncdiv(c1, c2),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(min(y, x * c1) / c2, min(y / c2, x * (c1 / c2)), TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2),
min(truncdiv(y, c2), x * truncdiv(c1, c2)),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(max(y, x * c1) / c2, max(y / c2, x * (c1 / c2)), TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2),
max(truncdiv(y, c2), x * truncdiv(c1, c2)),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
...@@ -582,80 +590,89 @@ Mutate_(const Div* op, const Expr& self) { ...@@ -582,80 +590,89 @@ Mutate_(const Div* op, const Expr& self) {
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
// Rules involving 3-operands. // Rules involving 3-operands.
TVM_TRY_REWRITE_IF((x * c1 + y + z) / c2, x * (c1 / c2) + (y + z)/ c2, TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2),
x * truncdiv(c1, c2) + truncdiv(y + z, c2),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y + z).Eval(), 0)); CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF((x * c1 - y + z) / c2, x * (c1 / c2) + (z - y)/ c2, TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2),
x * truncdiv(c1, c2) + truncdiv(z - y, c2),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((z - y).Eval(), 0)); CanProveGreaterEqual((z - y).Eval(), 0));
TVM_TRY_REWRITE_IF((x * c1 + y - z) / c2, x * (c1 / c2) + (y - z)/ c2, TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2),
x * truncdiv(c1, c2) + truncdiv(y - z, c2),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y - z).Eval(), 0)); CanProveGreaterEqual((y - z).Eval(), 0));
TVM_TRY_REWRITE_IF((y + x * c1 + z) / c2, x * (c1 / c2) + (y + z) / c2, TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2),
x * truncdiv(c1, c2) + truncdiv(y + z, c2),
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y + z).Eval(), 0)); CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF((x + c1) / c2, x / c2 + c1 / c2, TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2),
truncdiv(x, c2) + truncdiv(c1, c2),
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0)); CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF((x + y) / x, y / x + 1, TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF((y + x) / x, y / x + 1, TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF(((x + y) + z) / x, (y + z) / x + 1, TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x),
truncdiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y + z).Eval(), 0)); CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF(((y + x) + z) / x, (y + z) / x + 1, TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x),
truncdiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y + z).Eval(), 0)); CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF((y + (z + x)) / x, (y + z) / x + 1, TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x),
truncdiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y + z).Eval(), 0)); CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF((y + (x + z)) / x, (y + z) / x + 1, TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x),
truncdiv(y + z, x) + 1,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual((y + z).Eval(), 0)); CanProveGreaterEqual((y + z).Eval(), 0));
TVM_TRY_REWRITE_IF((x * y) / y, x, TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF((y * x) / y, x, TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF((x * z + y) / z, x + y / z, TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z),
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
CanProveGreaterEqual(z.Eval(), 0)); CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF((z * x + y) / z, x + y / z, TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z),
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
CanProveGreaterEqual(z.Eval(), 0)); CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF((y + x * z) / z, y / z + x, TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
CanProveGreaterEqual(z.Eval(), 0)); CanProveGreaterEqual(z.Eval(), 0));
TVM_TRY_REWRITE_IF((y + z * x) / z, y / z + x, TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
CanProveGreaterEqual(z.Eval(), 0)); CanProveGreaterEqual(z.Eval(), 0));
...@@ -679,15 +696,15 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -679,15 +696,15 @@ Mutate_(const Mod* op, const Expr& self) {
// Vector rules // Vector rules
if (op->type.lanes() != 1) { if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) % broadcast(y, lanes), TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(x % y, lanes)); broadcast(truncmod(x, y), lanes));
// ramp % bcast // ramp % bcast
if ((ramp(b1, c1, lanes) % broadcast(c2, lanes)).Match(ret)) { if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value; int64_t c2val = c2.Eval()->value;
if (c1val % c2val == 0) { if (c1val % c2val == 0) {
return broadcast(b1 % c2, lanes).Eval(); return broadcast(truncmod(b1, c2), lanes).Eval();
} }
// If all possible indices in ramp are the same. // If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) { if (CanProveGreaterEqual(b1.Eval(), 0)) {
...@@ -696,9 +713,10 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -696,9 +713,10 @@ Mutate_(const Mod* op, const Expr& self) {
int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0) { if (bmod->coeff % c2val == 0) {
if (ramp_min == ramp_max) { if (ramp_min == ramp_max) {
return ramp(bmod->base % c2, c1, lanes).Eval(); return ramp(truncmod(bmod->base, c2), c1, lanes).Eval();
} else { } else {
return (ramp(bmod->base % c2, c1, lanes) % broadcast(c2, lanes)).Eval(); return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes),
broadcast(c2, lanes)).Eval();
} }
} }
} }
...@@ -709,23 +727,23 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -709,23 +727,23 @@ Mutate_(const Mod* op, const Expr& self) {
// Be-aware of the division rules: // Be-aware of the division rules:
// We adopt the default C division uses truncation instead of floordiv. // We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands. // This means most rules need to check non-negativeness of the operands.
TVM_TRY_REWRITE_IF((x * c1) % c2, ZeroWithTypeLike(x), TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x),
c2.Eval()->value != 0 && c2.Eval()->value != 0 &&
c1.Eval()->value % c2.Eval()->value == 0); c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2, TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2),
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual((x * c1).Eval(), 0) && CanProveGreaterEqual((x * c1).Eval(), 0) &&
CanProveGreaterEqual(y.Eval(), 0)); CanProveGreaterEqual(y.Eval(), 0));
TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2, TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2),
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0)); CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF((x + y * c1) % c2, x % c2, TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2),
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value % c2.Eval()->value == 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
...@@ -733,18 +751,18 @@ Mutate_(const Mod* op, const Expr& self) { ...@@ -733,18 +751,18 @@ Mutate_(const Mod* op, const Expr& self) {
// canonicalization: x % c == x % (-c) for truncated division // canonicalization: x % c == x % (-c) for truncated division
// NOTE: trunc div required // NOTE: trunc div required
TVM_TRY_RECURSIVE_REWRITE_IF(x % c1, TVM_TRY_RECURSIVE_REWRITE_IF(truncmod(x, c1),
x % PConst<Expr>(make_const(op->type, -c1.Eval()->value)), truncmod(x, PConst<Expr>(make_const(op->type, -c1.Eval()->value))),
c1.Eval()->value < 0); c1.Eval()->value < 0);
// try modular analysis // try modular analysis
if ((x % c1).Match(ret)) { if (truncmod(x, c1).Match(ret)) {
ModularSet mod = analyzer_->modular_set(x.Eval()); ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value; int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && if (mod->coeff % c1val == 0 &&
c1val > 0 && c1val > 0 &&
CanProveGreaterEqual(x.Eval(), 0)) { CanProveGreaterEqual(x.Eval(), 0)) {
return (mod->base % c1).Eval(); return truncmod(mod->base, c1).Eval();
} }
} }
} }
...@@ -798,7 +816,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { ...@@ -798,7 +816,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
int64_t c2val = c2.Eval()->value; int64_t c2val = c2.Eval()->value;
if (c1val > 0 && c2val > 0) { if (c1val > 0 && c2val > 0) {
if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval(); if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval();
if (c2val % c1val == 0) return (floordiv(x, floordiv(c2, c1))).Eval(); if (c2val % c1val == 0) return floordiv(x, floordiv(c2, c1)).Eval();
} }
} }
...@@ -1025,18 +1043,18 @@ Mutate_(const Min* op, const Expr& self) { ...@@ -1025,18 +1043,18 @@ Mutate_(const Min* op, const Expr& self) {
// DivMod rules // DivMod rules
// Divide up rounding: truc div // Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y) // NOTE: trucdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x, TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x,
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value); c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, max(x, c2)), max(x, c2), TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value && c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0)); CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(min(x, ((x + c1) / c2) * c2), x, TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x,
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value); c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), ((x + c1) / c2) * c2), max(x, c2), TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2),
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value && c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0)); CanProveGreaterEqual(x.Eval(), 0));
...@@ -1104,11 +1122,11 @@ Mutate_(const Min* op, const Expr& self) { ...@@ -1104,11 +1122,11 @@ Mutate_(const Min* op, const Expr& self) {
TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2))); TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));
// scaling rule // scaling rule
if (min(x / c1, y / c1).Match(ret)) { if (min(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
if (c1.Eval()->value > 0) { if (c1.Eval()->value > 0) {
return (min(x, y) / c1).Eval(); return truncdiv(min(x, y), c1).Eval();
} else { } else {
return (max(x, y) / c1).Eval(); return truncdiv(max(x, y), c1).Eval();
} }
} }
if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) { if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
...@@ -1210,10 +1228,12 @@ Mutate_(const Max* op, const Expr& self) { ...@@ -1210,10 +1228,12 @@ Mutate_(const Max* op, const Expr& self) {
// DivMod rules // DivMod rules
// Divide up rounding: truc div // Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y) // NOTE: trucdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2, TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x),
truncdiv(x + c1, c2) * c2,
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value); c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, ((x + c1) / c2) * c2), ((x + c1) / c2) * c2, TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2),
truncdiv(x + c1, c2) * c2,
c2.Eval()->value > 0 && c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value); c1.Eval()->value + 1 == c2.Eval()->value);
...@@ -1276,11 +1296,11 @@ Mutate_(const Max* op, const Expr& self) { ...@@ -1276,11 +1296,11 @@ Mutate_(const Max* op, const Expr& self) {
TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2))); TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));
// scaling rule // scaling rule
if (max(x / c1, y / c1).Match(ret)) { if (max(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
if (c1.Eval()->value > 0) { if (c1.Eval()->value > 0) {
return (max(x, y) / c1).Eval(); return truncdiv(max(x, y), c1).Eval();
} else { } else {
return (min(x, y) / c1).Eval(); return truncdiv(min(x, y), c1).Eval();
} }
} }
if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) { if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
...@@ -1425,70 +1445,70 @@ Mutate_(const LT* op, const Expr& self) { ...@@ -1425,70 +1445,70 @@ Mutate_(const LT* op, const Expr& self) {
// constant cancelation: only need to make use of one mod // constant cancelation: only need to make use of one mod
// truc div // truc div
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1, TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1,
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
// NOTE: trunc div required // NOTE: trunc div required
TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2, TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2),
c1.Eval()->value <= 0 && c1.Eval()->value <= 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
// NOTE: trunc div required (euclidean is ok too, floored is not) // NOTE: trunc div required (euclidean is ok too, floored is not)
TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x, TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x,
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value < 0); c2.Eval()->value < 0);
// NOTE: trunc div required (floored is ok too, euclidean is not) // NOTE: trunc div required (floored is ok too, euclidean is not)
TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x, TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x,
c1.Eval()->value <= 0 && c1.Eval()->value <= 0 &&
c2.Eval()->value < 0); c2.Eval()->value < 0);
// NOTE: trunc div required // NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x, TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x,
c1.Eval()->value < 0 && c1.Eval()->value < 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x, TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x,
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
// NOTE: trunc div required (floored is ok too, euclidean is not) // NOTE: trunc div required (floored is ok too, euclidean is not)
TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1, TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1,
c1.Eval()->value < 0 && c1.Eval()->value < 0 &&
c2.Eval()->value < 0); c2.Eval()->value < 0);
// NOTE: trunc div required (euclidean is ok too, floored is not) // NOTE: trunc div required (euclidean is ok too, floored is not)
TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2, TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2),
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value < 0); c2.Eval()->value < 0);
// DivMod rules // DivMod rules
// trucdiv // trucdiv
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2,
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
// NOTE: trunc div required // NOTE: trunc div required
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1, TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1,
c1.Eval()->value > 0 && c1.Eval()->value > 0 &&
c2.Eval()->value <= 0); c2.Eval()->value <= 0);
TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x, TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x,
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
// NOTE: trunc div required // NOTE: trunc div required
TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x, TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x,
c1.Eval()->value < 0 && c1.Eval()->value < 0 &&
c2.Eval()->value > 0); c2.Eval()->value > 0);
// invariance for any div mod: x - (x / c1) * c1 == x % c1 // invariance for any div mod: x - (x / c1) * c1 == x % c1
TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1, TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1),
c1.Eval()->value > 0); c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 < x + y, 0 < x % c1 + y, TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y,
c1.Eval()->value > 0); c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1, TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1),
c1.Eval()->value > 0); c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x, TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x,
c2 < (x + c2) % c1, c2 < truncmod(x + c2, c1),
c1.Eval()->value > 0); c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x + y, TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y,
c2 < (x + c2) % c1 + y, c2 < truncmod(x + c2, c1) + y,
c1.Eval()->value > 0); c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2) / c1) * c1 < x - y, TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y,
y < (x + c2) % c1 + (0 - c2), y < truncmod(x + c2, c1) + (0 - c2),
c1.Eval()->value > 0); c1.Eval()->value > 0);
// floordiv // floordiv
......
...@@ -178,13 +178,19 @@ Expr operator*(Expr a, Expr b) { ...@@ -178,13 +178,19 @@ Expr operator*(Expr a, Expr b) {
return ir::Mul::make(a, b); return ir::Mul::make(a, b);
} }
Expr truncdiv(Expr a, Expr b) { Expr div(Expr a, Expr b) {
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Div>(a, b); Expr ret = arith::TryConstFold<ir::Div>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
return ir::Div::make(a, b); return ir::Div::make(a, b);
} }
Expr truncdiv(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
return div(a, b);
}
Expr truncmod(Expr a, Expr b) { Expr truncmod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Mod>(a, b); Expr ret = arith::TryConstFold<ir::Mod>(a, b);
...@@ -193,7 +199,7 @@ Expr truncmod(Expr a, Expr b) { ...@@ -193,7 +199,7 @@ Expr truncmod(Expr a, Expr b) {
} }
Expr operator/(Expr a, Expr b) { Expr operator/(Expr a, Expr b) {
return truncdiv(a, b); return div(a, b);
} }
Expr operator%(Expr a, Expr b) { Expr operator%(Expr a, Expr b) {
......
...@@ -47,9 +47,9 @@ TEST(Pattern, Basic) { ...@@ -47,9 +47,9 @@ TEST(Pattern, Basic) {
} }
CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK((px + min(py, px)).Match(z + min(y, z))); CHECK((px + min(py, px)).Match(z + min(y, z)));
CHECK((px + py / (px * py)).Match(x + 2 / (x * 2))); CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2))); CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
CHECK((px - py % (px * PConst<Expr>(2))).Match(x - 2 % (x * 2))); CHECK((px - floormod(py, px * PConst<Expr>(2))).Match(x - floormod(2, x * 2)));
// logicals // logicals
CHECK((px == pz).Match(x == 1)); CHECK((px == pz).Match(x == 1));
......
...@@ -56,24 +56,26 @@ def test_vector_simplify(): ...@@ -56,24 +56,26 @@ def test_vector_simplify():
tvm.expr.Ramp(x * 2, 8, 4)) tvm.expr.Ramp(x * 2, 8, 4))
## DivMod rules ## DivMod rules
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# truc div # truc div
ck.verify(y.astype("int32x2") / x.astype("int32x2"), ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")),
(y / x).astype("int32x2")) tdiv(y, x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) / 2, ck.verify(tdiv(tvm.expr.Ramp(x, 4, 4), 2),
tvm.expr.Ramp(x/ 2, 2, 4)) tvm.expr.Ramp(tdiv(x, 2), 2, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) / 8, ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
(x).astype("int32x4")) (x).astype("int32x4"))
ck.verify(tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8, ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8),
tvm.expr.Ramp(x * 8 + 15, 1, 4) / 8) tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8))
ck.verify(y.astype("int32x2") % x.astype("int32x2"), ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
(y % x).astype("int32x2")) tmod(y, x).astype("int32x2"))
ck.verify(tvm.expr.Ramp(x, 4, 4) % 2, ck.verify(tmod(tvm.expr.Ramp(x, 4, 4), 2),
tvm.expr.Broadcast(x % 2, 4)) tvm.expr.Broadcast(tmod(x, 2), 4))
ck.verify(tvm.expr.Ramp(x * 8 + 1, 1, 4) % 8, ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
tvm.expr.Ramp(1, 1, 4)) tvm.expr.Ramp(1, 1, 4))
ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8, ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8),
tvm.expr.Ramp(1, 15, 4) % 8) tmod(tvm.expr.Ramp(1, 15, 4), 8))
# floor div # floor div
fld = tvm.floordiv fld = tvm.floordiv
...@@ -187,10 +189,12 @@ def test_add_index_simplify(): ...@@ -187,10 +189,12 @@ def test_add_index_simplify():
ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9); ck.verify(x + 2 + 3 + 4 + x * 3, x * 4 + 9);
# DivMod rules # DivMod rules
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# truc div # truc div
ck.verify(y * (x % 8) + 10 * (x % 8), (x % 8) * (y + 10)) ck.verify(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True) ck.analyzer.update(x, tvm.arith.ConstIntBound(-1, 1000), override=True)
ck.verify((x / 8) * 8 + x % 8, x) ck.verify(tdiv(x, 8) * 8 + tmod(x, 8), x)
# floor div # floor div
fld = tvm.floordiv fld = tvm.floordiv
...@@ -256,31 +260,33 @@ def test_sub_index_simplify(): ...@@ -256,31 +260,33 @@ def test_sub_index_simplify():
# DivMod patterns # DivMod patterns
# truc div # truc div
tdiv = tvm.truncdiv
tmod = tvm.truncmod
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x - (x / 3) * 3, x % 3) ck.verify(x - tdiv(x, 3) * 3, tmod(x, 3))
ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5)/ 3) ck.verify(tdiv(x + 5, 3) - tdiv(x, 3), tdiv(tmod(x, 3) + 5, 3))
ck.verify((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4)/ 3) ck.verify(tdiv(x + 5, 3) - tdiv(x + 1, 3), tdiv(tmod(x + 1, 3) + 4, 3))
ck.verify(y - (y / (-5)) * (-5), y % 5) ck.verify(y - tdiv(y, (-5)) * (-5), tmod(y, 5))
ck.verify((y / 3) * 3 - y, 0 - y % 3) ck.verify(tdiv(y, 3) * 3 - y, 0 - tmod(y, 3))
ck.verify(y - ((y - 6) / 5) * 5, (y + (-6)) % 5 + 6) ck.verify(y - tdiv(y - 6, 5) * 5, tmod(y + (-6), 5) + 6)
ck.verify(((y - 6) / 5) * 5 - y, (-6) - (y + (-6)) % 5) ck.verify(tdiv(y - 6, 5) * 5 - y, (-6) - tmod(y + (-6), 5))
ck.verify(y - ((y + z) / 5) * 5, (y + z) % 5 - z) ck.verify(y - tdiv(y + z, 5) * 5, tmod(y + z, 5) - z)
ck.verify(((y + z) / 5) * 5 - y, z - (y + z) % 5) ck.verify(tdiv(y + z, 5) * 5 - y, z - tmod(y + z, 5))
ck.verify(y - ((y - z) / 5) * 5, (y - z) % 5 + z) ck.verify(y - tdiv(y - z, 5) * 5, tmod(y - z, 5) + z)
ck.verify(((y - z) / 5) * 5 - y, 0 - (y - z) % 5 - z) ck.verify(tdiv(y - z, 5) * 5 - y, 0 - tmod(y - z, 5) - z)
ck.verify(y * 3 - (y / 2) * 6, (y % 2) * 3) ck.verify(y * 3 - tdiv(y, 2) * 6, tmod(y, 2) * 3)
ck.verify((y / 3) * 6 - y * 2, (y % 3) * (-2)) ck.verify(tdiv(y, 3) * 6 - y * 2, tmod(y, 3) * (-2))
ck.verify(y * 5 - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) ck.verify(y * 5 - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5)
ck.verify(y * 5 - ((y - z) / 2) * 10, ((y - z) % 2 + z) * 5) ck.verify(y * 5 - tdiv(y - z, 2) * 10, (tmod(y - z, 2) + z) * 5)
ck.verify(((y + z) / 3) * 6 - y * 2, (z - (y + z) % 3) * 2) ck.verify(tdiv(y + z, 3) * 6 - y * 2, (z - tmod(y + z, 3)) * 2)
ck.verify(((y - z) / 3) * 6 - y * 2, (0 - (y - z) % 3 - z) * 2) ck.verify(tdiv(y - z, 3) * 6 - y * 2, (0 - tmod(y - z, 3) - z) * 2)
ck.verify(5 * y - ((y + z) / 2) * 10, ((y + z) % 2 - z) * 5) ck.verify(5 * y - tdiv(y + z, 2) * 10, (tmod(y + z, 2) - z) * 5)
ck.verify(5 * y - 10 * ((y - z) / 2), ((y - z) % 2 + z) * 5) ck.verify(5 * y - 10 * tdiv(y - z, 2), (tmod(y - z, 2) + z) * 5)
ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2) ck.verify(6 * tdiv(y + z, 3) - y * 2, (z - tmod(y + z, 3)) * 2)
ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2) ck.verify(tdiv(y - z, 3) * 6 - 2 * y, (0 - tmod(y - z, 3) - z) * 2)
# floor div # floor div
fld = tvm.floordiv fld = tvm.floordiv
...@@ -323,46 +329,48 @@ def test_mul_index_simplify(): ...@@ -323,46 +329,48 @@ def test_mul_index_simplify():
def test_div_index_simplify(): def test_div_index_simplify():
ck = RewriteChecker() ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
tdiv = tvm.truncdiv
tmod = tvm.truncmod
ck.verify(x / x, 1) ck.verify(tdiv(x, x), 1)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x / 2 / 3, x / 6) ck.verify(tdiv(tdiv(x, 2), 3), tdiv(x, 6))
ck.verify((x / 2 + 1) / 3, (x + 2) / 6) ck.verify(tdiv(tdiv(x, 2) + 1, 3), tdiv(x + 2, 6))
ck.verify(x * 2 / 4, x / 2) ck.verify(tdiv(x * 2, 4), tdiv(x, 2))
ck.verify(x * 4 / 2, x * 2) ck.verify(tdiv(x * 4, 2), x * 2)
ck.verify((x * 4 + y) / 2, x * 2 + y / 2) ck.verify(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2))
ck.verify(tvm.min(x * 6, y) / 2, tvm.min(x * 3, y / 2)) ck.verify(tdiv(tvm.min(x * 6, y), 2), tvm.min(x * 3, tdiv(y, 2)))
ck.verify(tvm.max(x * 6, y) / 2, tvm.max(x * 3, y / 2)) ck.verify(tdiv(tvm.max(x * 6, y), 2), tvm.max(x * 3, tdiv(y, 2)))
ck.verify((y + x * 4) / 2, y / 2 + x * 2) ck.verify(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2)
ck.verify(tvm.min(y, x * 6) / 2, tvm.min(y / 2, x * 3)) ck.verify(tdiv(tvm.min(y, x * 6), 2), tvm.min(tdiv(y, 2), x * 3))
ck.verify(tvm.max(y, x * 6) / 2, tvm.max(y / 2, x * 3)) ck.verify(tdiv(tvm.max(y, x * 6), 2), tvm.max(tdiv(y, 2), x * 3))
# 3-operands # 3-operands
ck.verify((x * 6 + y + z) / 2, x * 3 + (y + z) / 2) ck.verify(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2))
ck.verify((x * 6 - y + (y + 3)) / 2, x * 3 + 1) ck.verify(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1)
ck.verify((x * 6 + (y + 3) - y) / 2, x * 3 + 1) ck.verify(tdiv(x * 6 + (y + 3) - y, 2), x * 3 + 1)
ck.verify((y + x * 6 + z) / 2, x * 3 + (y + z) / 2) ck.verify(tdiv(y + x * 6 + z, 2), x * 3 + tdiv(y + z, 2))
ck.verify((x + 4) / 2, x / 2 + 2) ck.verify(tdiv(x + 4, 2), tdiv(x, 2) + 2)
ck.verify((x + y) / x, y / x + 1) ck.verify(tdiv(x + y, x), tdiv(y, x) + 1)
ck.verify((y + x) / x, y / x + 1) ck.verify(tdiv(y + x, x), tdiv(y, x) + 1)
ck.verify(((x + y) + z) / x, (y + z) / x + 1) ck.verify(tdiv((x + y) + z, x), tdiv(y + z, x) + 1)
ck.verify(((y + x) + z) / x, (y + z) / x + 1) ck.verify(tdiv((y + x) + z, x), tdiv(y + z, x) + 1)
ck.verify((y + (x + z)) / x, (y + z) / x + 1) ck.verify(tdiv(y + (x + z), x), tdiv(y + z, x) + 1)
ck.verify((y + (z + x)) / x, (y + z) / x + 1) ck.verify(tdiv(y + (z + x), x), tdiv(y + z, x) + 1)
ck.verify((x * y) / y, x) ck.verify(tdiv(x * y, y), x)
ck.verify((y * x) / y, x) ck.verify(tdiv(y * x, y), x)
ck.verify((x * z + y) / z, x + y / z) ck.verify(tdiv(x * z + y, z), x + tdiv(y, z))
ck.verify((z * x + y) / z, x + y / z) ck.verify(tdiv(z * x + y, z), x + tdiv(y, z))
ck.verify((y + x * z) / z, y / z + x) ck.verify(tdiv(y + x * z, z), tdiv(y, z) + x)
ck.verify((y + z * x) / z, y / z + x) ck.verify(tdiv(y + z * x, z), tdiv(y, z) + x)
def test_floordiv_index_simplify(): def test_floordiv_index_simplify():
...@@ -417,31 +425,33 @@ def test_mod_index_simplify(): ...@@ -417,31 +425,33 @@ def test_mod_index_simplify():
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True) ck.analyzer.update(nx, tvm.arith.ConstIntBound(-1000, 0), override=True)
ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True) ck.analyzer.update(ny, tvm.arith.ConstIntBound(-1000, 0), override=True)
tdiv = tvm.truncdiv
ck.verify(x * 10 % 2, 0) tmod = tvm.truncmod
ck.verify((x * 10 + y) % 2, y % 2)
ck.verify((x + 10) % 2, x % 2) ck.verify(tmod(x * 10, 2), 0)
ck.verify((x + y * 10) % 2, x % 2) ck.verify(tmod(x * 10 + y, 2), tmod(y, 2))
ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1) ck.verify(tmod(x + 10, 2), tmod(x, 2))
ck.verify(x * 10 % -2, 0) ck.verify(tmod(x + y * 10, 2), tmod(x, 2))
ck.verify((x * 10 + y) % -2, y % 2) ck.verify(tmod(x* 10 + 1 + y * 2 + 2, 2), 1)
ck.verify((x + 10) % -2, x % 2) ck.verify(tmod(x * 10, -2), 0)
ck.verify((x + y * 10) % -2, x % 2) ck.verify(tmod(x * 10 + y, -2), tmod(y, 2))
ck.verify((x* 10 + 1 + y * 2 + 2) % -2, 1) ck.verify(tmod(x + 10, -2), tmod(x, 2))
ck.verify(tmod(x + y * 10, -2), tmod(x, 2))
ck.verify(x * (-10) % 2, 0) ck.verify(tmod(x* 10 + 1 + y * 2 + 2, -2), 1)
ck.verify((x * (-10) + y) % 2, (x * (-10) + y) % 2)
ck.verify((x + (-10)) % 2, (x + (-10)) % 2) ck.verify(tmod(x * (-10), 2), 0)
ck.verify((x + y * (-10)) % 2, (x + y * (-10)) % 2) ck.verify(tmod(x * (-10) + y, 2), tmod(x * (-10) + y, 2))
ck.verify(x * (-10) % -2, 0) ck.verify(tmod(x + (-10), 2), tmod(x + (-10), 2))
ck.verify(tmod(x + y * (-10), 2), tmod(x + y * (-10), 2))
ck.verify(nx * 10 % 2, 0) ck.verify(tmod(x * (-10), -2), 0)
ck.verify((nx * (-10) + y) % 2, y % 2)
ck.verify((x + ny * (-10)) % 2, x % 2) ck.verify(tmod(nx * 10, 2), 0)
ck.verify((nx * (-10) + 1 + ny * (-2) + 2) % 2, 1) ck.verify(tmod(nx * (-10) + y, 2), tmod(y, 2))
ck.verify(nx * 10 % -2, 0) ck.verify(tmod(x + ny * (-10), 2), tmod(x, 2))
ck.verify((nx * (-10) + y) % -2, y % 2) ck.verify(tmod(nx * (-10) + 1 + ny * (-2) + 2, 2), 1)
ck.verify((x + ny * (-10)) % -2, x % 2) ck.verify(tmod(nx * 10, -2), 0)
ck.verify(tmod(nx * (-10) + y, -2), tmod(y, 2))
ck.verify(tmod(x + ny * (-10), -2), tmod(x, 2))
def test_floormod_index_simplify(): def test_floormod_index_simplify():
...@@ -468,8 +478,10 @@ def test_min_index_simplify(): ...@@ -468,8 +478,10 @@ def test_min_index_simplify():
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
fld = tvm.floordiv fld = tvm.floordiv
flm = tvm.floormod flm = tvm.floormod
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# const int bound # const int bound
ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2) ck.verify(tvm.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2))
ck.verify(tvm.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)) ck.verify(tvm.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2))
ck.verify(tvm.min(x + 1, x + 10), x + 1) ck.verify(tvm.min(x + 1, x + 10), x + 1)
...@@ -521,13 +533,14 @@ def test_min_index_simplify(): ...@@ -521,13 +533,14 @@ def test_min_index_simplify():
# DivMod rules # DivMod rules
# truc div # truc div
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000)) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.verify(tvm.min((x + 3) / 4 * 4, x), x) ck.verify(tvm.min(tdiv(x + 3, 4) * 4, x), x)
ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4)) ck.verify(tvm.min(tdiv(x + 3, 4) * 4, tvm.max(x, 4)), tvm.max(x, 4))
ck.verify(tvm.min(x, (x + 3) / 4 * 4), x) ck.verify(tvm.min(x, tdiv(x + 3, 4) * 4), x)
ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4)) ck.verify(tvm.min(tvm.max(x, 4), tdiv(x + 3, 4) * 4), tvm.max(x, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10) ck.verify(tvm.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.min(x, y), 10))
ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10)) ck.verify(tvm.min(tdiv(x, (-10)), tdiv(y, (-10))),
tdiv(tvm.max(x, y), (-10)))
# floor div # floor div
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True) ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
...@@ -545,8 +558,10 @@ def test_max_index_simplify(): ...@@ -545,8 +558,10 @@ def test_max_index_simplify():
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
flm = tvm.floormod flm = tvm.floormod
fld = tvm.floordiv fld = tvm.floordiv
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# const int bound # const int bound
ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10) ck.verify(tvm.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10)
ck.verify(tvm.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10) ck.verify(tvm.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10)
ck.verify(tvm.max(x + 1, x + 10), x + 10) ck.verify(tvm.max(x + 1, x + 10), x + 10)
...@@ -597,9 +612,9 @@ def test_max_index_simplify(): ...@@ -597,9 +612,9 @@ def test_max_index_simplify():
# DivMod rules # DivMod rules
# truc div # truc div
ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10) ck.verify(tvm.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.max(x, y), 10))
ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10)) ck.verify(tvm.max(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.min(x, y), (-10)))
ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4) ck.verify(tvm.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4)
# floordiv # floordiv
ck.verify(tvm.max(fld(x, 10), fld(y, 10)), fld(tvm.max(x, y), 10)) ck.verify(tvm.max(fld(x, 10), fld(y, 10)), fld(tvm.max(x, y), 10))
...@@ -614,11 +629,13 @@ def test_cmp_simplify(): ...@@ -614,11 +629,13 @@ def test_cmp_simplify():
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
flm = tvm.floormod flm = tvm.floormod
fld = tvm.floordiv fld = tvm.floordiv
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# const int bound # const int bound
ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool")) ck.verify((tmod(x, 2) + 10).equal(0), tvm.const(0, "bool"))
ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool")) ck.verify(tvm.expr.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool"))
ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool")) ck.verify(tmod(x, 2) + 10 > 1, tvm.const(1, "bool"))
ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool")) ck.verify(tmod(x, 2) + 10 <= 1, tvm.const(0, "bool"))
ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool")) ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool"))
ck.verify(flm(x, 2) + 10 <= 1, tvm.const(0, "bool")) ck.verify(flm(x, 2) + 10 <= 1, tvm.const(0, "bool"))
...@@ -688,26 +705,26 @@ def test_cmp_simplify(): ...@@ -688,26 +705,26 @@ def test_cmp_simplify():
# DivMod rules # DivMod rules
# truc div # truc div
ck.verify(x / 2 < 3, x < 6) ck.verify(tdiv(x, 2) < 3, x < 6)
ck.verify(3 < x / 2, tvm.expr.LT(7, x)) ck.verify(3 < tdiv(x, 2), tvm.expr.LT(7, x))
ck.verify(x / 3 >= 0, tvm.expr.LE(-2, x)) ck.verify(tdiv(x, 3) >= 0, tvm.expr.LE(-2, x))
ck.verify(x / 2 >= 1, tvm.expr.LE(2, x)) ck.verify(tdiv(x, 2) >= 1, tvm.expr.LE(2, x))
ck.verify(x / 2 >= 0, tvm.expr.LE(-1, x)) ck.verify(tdiv(x, 2) >= 0, tvm.expr.LE(-1, x))
ck.verify(x / 2 >= -1, tvm.expr.LE(-3, x)) ck.verify(tdiv(x, 2) >= -1, tvm.expr.LE(-3, x))
ck.verify(x / 2 <= 1, tvm.expr.LE(x, 3)) ck.verify(tdiv(x, 2) <= 1, tvm.expr.LE(x, 3))
ck.verify(x / 2 <= 0, tvm.expr.LE(x, 1)) ck.verify(tdiv(x, 2) <= 0, tvm.expr.LE(x, 1))
ck.verify(x / 2 <= -1, tvm.expr.LE(x, -2)) ck.verify(tdiv(x, 2) <= -1, tvm.expr.LE(x, -2))
ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4)) ck.verify(tdiv(x, 4) * 4 < x, tvm.expr.LT(0, tmod(x, 4)))
ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0)) ck.verify(tdiv(x, 4) * 4 >= x, tvm.expr.LE(tmod(x, 4), 0))
ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y)) ck.verify(tdiv(x, 4) * 4 < x + y, tvm.expr.LT(0, tmod(x, 4) + y))
ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4)) ck.verify(tdiv(x, 4) * 4 < x - y, tvm.expr.LT(y, tmod(x, 4)))
ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2)) ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.expr.LE(tmod(x + 2, 4), 2))
ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2)) ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.expr.LE(tmod(x + 2, 4) + y, 2))
ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y)) ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.expr.LE(tmod(x + 2, 4) + (-2), y))
# floor div # floor div
ck.verify(fld(x, 2) < 3, x < 6) ck.verify(fld(x, 2) < 3, x < 6)
...@@ -753,7 +770,7 @@ def test_cmp_simplify(): ...@@ -753,7 +770,7 @@ def test_cmp_simplify():
ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool")) ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
ck.verify(y*y >= 0, tvm.const(1, "bool")) ck.verify(y*y >= 0, tvm.const(1, "bool"))
ck.verify(x*6 <= -3, tvm.const(0, "bool")) ck.verify(x*6 <= -3, tvm.const(0, "bool"))
ck.verify((y - 1) % 3 == 0, (y + (-1)) % 3 == 0) ck.verify(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0)
def test_logical_simplify(): def test_logical_simplify():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment