Unverified Commit d1830964 by Tianqi Chen Committed by GitHub

[ARITH] Explicitly state truncdiv/mod in pattern matching. (#3986)

* [ARITH] Explicitly state truncdiv/mod in pattern matching.

* Fix the dependent cpp test
parent c48e1cc1
......@@ -333,6 +333,20 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute division in C semantics.
*
* a / b as in C/C++.
*
* When operands are integers, it directly corresponds to truncdiv.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr div(Expr a, Expr b);
/*!
* \brief compute trunc(a / b)
*
* This is the default integer division behavior in C.
......@@ -640,6 +654,21 @@ inline Expr make_zero(Type t) {
return make_const(t, 0);
}
/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always results in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}
// additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
......@@ -688,12 +717,17 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
......
......@@ -67,7 +67,7 @@ enum DivMode {
inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) {
return a % b;
return truncmod(a, b);
} else {
CHECK_EQ(mode, kFloorDiv);
return floormod(a, b);
......@@ -76,7 +76,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) {
return a / b;
return truncdiv(a, b);
} else {
CHECK_EQ(mode, kFloorDiv);
return floordiv(a, b);
......
......@@ -93,6 +93,26 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
}
/*!
* \brief Peform trunc division of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t truncdiv(int64_t x, int64_t y) {
return x / y;
}
/*!
* \brief Compute the truncdiv remainder of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t truncmod(int64_t x, int64_t y) {
return x % y;
}
/*!
* \brief Peform floor division of two integers.
* \param x The left operand.
* \param y The right operand.
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file modular_set.cc
* \brief Modular set analysis
*/
......@@ -111,7 +110,8 @@ class ModularSetAnalyzer::Impl :
PVar<Var> var;
PVar<Integer> coeff, base;
// pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) {
if ((truncmod(var, coeff) == base).Match(constraint) ||
(floormod(var, coeff) == base).Match(constraint)) {
Entry entry(coeff.Eval()->value, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry);
}
......
......@@ -300,31 +300,41 @@ class PConstWithTypeLike :
};
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
CheckStep; \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
CheckStep; \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
CheckStep; \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
// raise ambiguity error for operator overload of / and %
TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(div, ir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
......
......@@ -178,13 +178,19 @@ Expr operator*(Expr a, Expr b) {
return ir::Mul::make(a, b);
}
Expr truncdiv(Expr a, Expr b) {
Expr div(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Div>(a, b);
if (ret.defined()) return ret;
return ir::Div::make(a, b);
}
Expr truncdiv(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
return div(a, b);
}
Expr truncmod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Mod>(a, b);
......@@ -193,7 +199,7 @@ Expr truncmod(Expr a, Expr b) {
}
Expr operator/(Expr a, Expr b) {
return truncdiv(a, b);
return div(a, b);
}
Expr operator%(Expr a, Expr b) {
......
......@@ -47,9 +47,9 @@ TEST(Pattern, Basic) {
}
CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK((px + min(py, px)).Match(z + min(y, z)));
CHECK((px + py / (px * py)).Match(x + 2 / (x * 2)));
CHECK((px - py % (px * pz)).Match(x - 2 % (x * 2)));
CHECK((px - py % (px * PConst<Expr>(2))).Match(x - 2 % (x * 2)));
CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
CHECK((px - floormod(py, px * PConst<Expr>(2))).Match(x - floormod(2, x * 2)));
// logicals
CHECK((px == pz).Match(x == 1));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment