Unverified Commit 75892d2b by Tianqi Chen Committed by GitHub

[ARITH][IR] Introduce FloorDiv/Mod (#3479)

* [ARITH][IR] Introduce FloorDiv/Mod

* Address review comments

* address review comments, fix div sub rule
parent 45878ff2
...@@ -332,6 +332,26 @@ TVM_DLL Expr operator||(Expr a, Expr b); ...@@ -332,6 +332,26 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/ */
TVM_DLL Expr operator!(Expr a); TVM_DLL Expr operator!(Expr a);
/*! /*!
* \brief compute floor(a / b)
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr floordiv(Expr a, Expr b);
/*!
* \brief compute the remainder of floordiv
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr floormod(Expr a, Expr b);
/*!
* \brief take maximum of two values * \brief take maximum of two values
* *
* \param a left operand * \param a left operand
......
...@@ -178,6 +178,18 @@ class Mod : public BinaryOpNode<Mod> { ...@@ -178,6 +178,18 @@ class Mod : public BinaryOpNode<Mod> {
static constexpr const char* _type_key = "Mod"; static constexpr const char* _type_key = "Mod";
}; };
/*! \brief Floor division, floor(a/b) */
class FloorDiv : public BinaryOpNode<FloorDiv> {
public:
static constexpr const char* _type_key = "FloorDiv";
};
/*! \brief The remainder of the floordiv */
class FloorMod : public BinaryOpNode<FloorMod> {
public:
static constexpr const char* _type_key = "FloorMod";
};
/*! \brief min(a, b) */ /*! \brief min(a, b) */
class Min : public BinaryOpNode<Min> { class Min : public BinaryOpNode<Min> {
public: public:
......
...@@ -140,6 +140,8 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -140,6 +140,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
...@@ -180,6 +182,8 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -180,6 +182,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(Mul); IR_EXPR_FUNCTOR_DISPATCH(Mul);
IR_EXPR_FUNCTOR_DISPATCH(Div); IR_EXPR_FUNCTOR_DISPATCH(Div);
IR_EXPR_FUNCTOR_DISPATCH(Mod); IR_EXPR_FUNCTOR_DISPATCH(Mod);
IR_EXPR_FUNCTOR_DISPATCH(FloorDiv);
IR_EXPR_FUNCTOR_DISPATCH(FloorMod);
IR_EXPR_FUNCTOR_DISPATCH(Min); IR_EXPR_FUNCTOR_DISPATCH(Min);
IR_EXPR_FUNCTOR_DISPATCH(Max); IR_EXPR_FUNCTOR_DISPATCH(Max);
IR_EXPR_FUNCTOR_DISPATCH(EQ); IR_EXPR_FUNCTOR_DISPATCH(EQ);
......
...@@ -98,6 +98,8 @@ class TVM_DLL IRMutator { ...@@ -98,6 +98,8 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const Mul* op, const Expr& e); virtual Expr Mutate_(const Mul* op, const Expr& e);
virtual Expr Mutate_(const Div* op, const Expr& e); virtual Expr Mutate_(const Div* op, const Expr& e);
virtual Expr Mutate_(const Mod* op, const Expr& e); virtual Expr Mutate_(const Mod* op, const Expr& e);
virtual Expr Mutate_(const FloorDiv* op, const Expr& e);
virtual Expr Mutate_(const FloorMod* op, const Expr& e);
virtual Expr Mutate_(const Min* op, const Expr& e); virtual Expr Mutate_(const Min* op, const Expr& e);
virtual Expr Mutate_(const Max* op, const Expr& e); virtual Expr Mutate_(const Max* op, const Expr& e);
virtual Expr Mutate_(const EQ* op, const Expr& e); virtual Expr Mutate_(const EQ* op, const Expr& e);
......
...@@ -114,6 +114,8 @@ class TVM_DLL IRVisitor { ...@@ -114,6 +114,8 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const Mul* op); virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op); virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op); virtual void Visit_(const Mod* op);
virtual void Visit_(const FloorDiv* op);
virtual void Visit_(const FloorMod* op);
virtual void Visit_(const Min* op); virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op); virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op); virtual void Visit_(const EQ* op);
......
...@@ -888,7 +888,46 @@ def comm_reducer(fcombine, fidentity, name="reduce"): ...@@ -888,7 +888,46 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer return reducer
def floordiv(a, b):
"""Compute the floordiv of two expressions.
Parameters
----------
a : Expr
The left hand operand
b : Expr
The right hand operand
Returns
-------
res : Expr
The result expression.
"""
return _make._OpFloorDiv(a, b)
def floormod(a, b):
"""Compute the floormod of two expressions.
Parameters
----------
a : Expr
The left hand operand
b : Expr
The right hand operand
Returns
-------
res : Expr
The result expression.
"""
return _make._OpFloorMod(a, b)
_init_api("tvm.api") _init_api("tvm.api")
#pylint: disable=unnecessary-lambda #pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min') min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
......
...@@ -463,6 +463,40 @@ class Mod(BinaryOpExpr): ...@@ -463,6 +463,40 @@ class Mod(BinaryOpExpr):
@register_node @register_node
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.FloorDiv, a, b)
@register_node
class FloorMod(BinaryOpExpr):
"""FloorMod node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.FloorMod, a, b)
@register_node
class Min(BinaryOpExpr): class Min(BinaryOpExpr):
"""Min node. """Min node.
......
...@@ -126,6 +126,8 @@ REGISTER_MAKE(Sub); ...@@ -126,6 +126,8 @@ REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul); REGISTER_MAKE(Mul);
REGISTER_MAKE(Div); REGISTER_MAKE(Div);
REGISTER_MAKE(Mod); REGISTER_MAKE(Mod);
REGISTER_MAKE(FloorDiv);
REGISTER_MAKE(FloorMod);
REGISTER_MAKE(Min); REGISTER_MAKE(Min);
REGISTER_MAKE(Max); REGISTER_MAKE(Max);
REGISTER_MAKE(EQ); REGISTER_MAKE(EQ);
...@@ -192,6 +194,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-); ...@@ -192,6 +194,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*); REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/); REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%); REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpMin, min); REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max); REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==); REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <algorithm> #include <algorithm>
#include "int_operator.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -184,9 +185,7 @@ template<> ...@@ -184,9 +185,7 @@ template<>
inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) { inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type(); const Type& rtype = a.type();
// due to division and mod can have different modes if (pa && pb) {
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value % pb->value); return IntImm::make(rtype, pa->value % pb->value);
} }
if (pa) { if (pa) {
...@@ -201,6 +200,51 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) { ...@@ -201,6 +200,51 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
} }
template<> template<>
inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, std::floor(fa->value / fb->value));
}
if (fa && fa->value == 0) return a;
if (fb) {
if (fb->value == 1) return a;
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) {
return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) { inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type(); const Type& rtype = a.type();
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <algorithm> #include <algorithm>
#include "int_op_overflow.h" #include "int_operator.h"
#include "pattern_match.h" #include "pattern_match.h"
namespace tvm { namespace tvm {
...@@ -215,6 +215,37 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -215,6 +215,37 @@ class ConstIntBoundAnalyzer::Impl :
} }
} }
Entry VisitExpr_(const FloorDiv* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
}
Entry VisitExpr_(const FloorMod* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
// 0 <= [a_min, a_max] < b_min
if (a.max_value < b.min_value) return a;
// other case, we can get close to 0
return MakeBound(0, std::min(a.max_value, b_max_cap));
} else {
return MakeBound(0, b_max_cap);
}
} else {
CHECK(!b.is_const(0)) << "floormod by zero";
// mod by negative value is rare,
// and we just use the simpliest rule.
return Everything(op->type);
}
}
Entry VisitExpr_(const Min* op) final { Entry VisitExpr_(const Min* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
...@@ -376,6 +407,20 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -376,6 +407,20 @@ class ConstIntBoundAnalyzer::Impl :
return x / y; return x / y;
} }
/*! /*!
* \brief Compute floodiv(x, y), aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
static int64_t InfAwareFloorDiv(int64_t x, int64_t y) {
CHECK_NE(y, 0);
if (x == kPosInf || x == kNegInf) {
if (y > 0) return x;
return -x;
}
return floordiv(x, y);
}
/*!
* \brief Compute x / y, aware of inf. * \brief Compute x / y, aware of inf.
* \param x The left operand. * \param x The left operand.
* \param y The right operand. * \param y The right operand.
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file int_op_overflow.h * \file int_operator.h
* \brief Utility functions to detect if an integer op will overflow. * \brief Additional useful operators for integer.
*/ */
#ifndef TVM_ARITHMETIC_INT_OP_OVERFLOW_H_ #ifndef TVM_ARITHMETIC_INT_OPERATOR_H_
#define TVM_ARITHMETIC_INT_OP_OVERFLOW_H_ #define TVM_ARITHMETIC_INT_OPERATOR_H_
#include <limits> #include <limits>
...@@ -48,30 +48,30 @@ inline bool WillOverflow(int64_t x, ...@@ -48,30 +48,30 @@ inline bool WillOverflow(int64_t x,
} }
template<> template<>
bool WillOverflow<ir::Add>(int64_t x, inline bool WillOverflow<ir::Add>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
if ((y > 0) && (x > max_value - y)) return true; if ((y > 0) && (x > max_value - y)) return true;
if ((y < 0) && (x < min_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true;
return false; return false;
} }
template<> template<>
bool WillOverflow<ir::Sub>(int64_t x, inline bool WillOverflow<ir::Sub>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
if ((y > 0) && (x < min_value + y)) return true; if ((y > 0) && (x < min_value + y)) return true;
if ((y < 0) && (x > max_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true;
return false; return false;
} }
template<> template<>
bool WillOverflow<ir::Mul>(int64_t x, inline bool WillOverflow<ir::Mul>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
if (y == 0) return false; if (y == 0) return false;
if (y > 0) { if (y > 0) {
if (x < min_value / y) return true; if (x < min_value / y) return true;
...@@ -85,13 +85,42 @@ bool WillOverflow<ir::Mul>(int64_t x, ...@@ -85,13 +85,42 @@ bool WillOverflow<ir::Mul>(int64_t x,
} }
template<> template<>
bool WillOverflow<ir::Mod>(int64_t x, inline bool WillOverflow<ir::Mod>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
return y == 0; return y == 0;
} }
/*!
* \brief Peform floor division of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t floordiv(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x / y) : (x / y - 1);
}
/*!
* \brief Compute the floordiv remainder of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t floormod(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x % y) : (x % y + y);
}
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
#endif // TVM_ARITHMETIC_INT_OP_OVERFLOW_H_ #endif // TVM_ARITHMETIC_INT_OPERATOR_H_
...@@ -252,6 +252,68 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer, ...@@ -252,6 +252,68 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
return IntervalSet::Everything(); return IntervalSet::Everything();
} }
template<>
inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
if (is_zero(b->min_value)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
Expr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
Expr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
Expr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
Expr e1 = floordiv(a->min_value, b->min_value);
Expr e2 = floordiv(a->max_value, b->min_value);
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Div";
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
const Expr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
return IntervalSet(make_zero(divisor.type()), divisor - 1);
} else {
Expr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Mod";
return IntervalSet::Everything();
}
template<> template<>
inline IntervalSet Combine<ir::Max>(Analyzer* analzyer, inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
IntervalSet a, IntervalSet a,
...@@ -361,6 +423,14 @@ class IntervalSetEvaluator : ...@@ -361,6 +423,14 @@ class IntervalSetEvaluator :
return VisitBinaryExpr_(op); return VisitBinaryExpr_(op);
} }
IntervalSet VisitExpr_(const FloorDiv* op) final {
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const FloorMod* op) final {
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const Min* op) final { IntervalSet VisitExpr_(const Min* op) final {
return VisitBinaryExpr_(op); return VisitBinaryExpr_(op);
} }
......
...@@ -179,6 +179,14 @@ class ModularSetAnalyzer::Impl : ...@@ -179,6 +179,14 @@ class ModularSetAnalyzer::Impl :
return Everything(); return Everything();
} }
Entry VisitExpr_(const FloorDiv* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, true);
}
return Everything();
}
Entry VisitExpr_(const Min* op) final { Entry VisitExpr_(const Min* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -326,6 +326,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div); ...@@ -326,6 +326,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod); TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min); TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max); TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
// logical expressions // logical expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GT); TVM_PATTERN_BINARY_OP(operator>, ir::GT);
......
...@@ -53,6 +53,8 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -53,6 +53,8 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Mul* op, const Expr& self) override; Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override; Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override; Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const FloorDiv* op, const Expr& self) override;
Expr Mutate_(const FloorMod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override; Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override; Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override; Expr Mutate_(const EQ* op, const Expr& self) override;
......
...@@ -188,6 +188,19 @@ Expr operator%(Expr a, Expr b) { ...@@ -188,6 +188,19 @@ Expr operator%(Expr a, Expr b) {
return ir::Mod::make(a, b); return ir::Mod::make(a, b);
} }
Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
if (ret.defined()) return ret;
return ir::FloorDiv::make(a, b);
}
Expr floormod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
if (ret.defined()) return ret;
return ir::FloorMod::make(a, b);
}
Expr min(Expr a, Expr b) { Expr min(Expr a, Expr b) {
// inf-aware simplificaiton // inf-aware simplificaiton
......
...@@ -59,14 +59,12 @@ Expr StringImm::make(std::string value) { ...@@ -59,14 +59,12 @@ Expr StringImm::make(std::string value) {
Expr Cast::make(DataType t, Expr value) { Expr Cast::make(DataType t, Expr value) {
CHECK(value.defined()); CHECK(value.defined());
CHECK_EQ(t.lanes(), value.type().lanes()); CHECK_EQ(t.lanes(), value.type().lanes());
NodePtr<Cast> node = make_node<Cast>(); NodePtr<Cast> node = make_node<Cast>();
node->type = t; node->type = t;
node->value = std::move(value); node->value = std::move(value);
return Expr(node); return Expr(node);
} }
Expr And::make(Expr a, Expr b) { Expr And::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(b.defined()) << "ValueError: b is undefined";
...@@ -700,6 +698,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -700,6 +698,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorDiv>([](const FloorDiv* op, IRPrinter *p) {
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorMod>([](const FloorMod* op, IRPrinter *p) {
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<And>([](const And* op, IRPrinter* p) { .set_dispatch<And>([](const And* op, IRPrinter* p) {
p->stream << '('; p->stream << '(';
p->Print(op->a); p->Print(op->a);
...@@ -1098,7 +1106,6 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode); ...@@ -1098,7 +1106,6 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce); TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any); TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt); TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm); TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm); TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm); TVM_REGISTER_NODE_TYPE(UIntImm);
...@@ -1110,6 +1117,8 @@ TVM_REGISTER_NODE_TYPE(Sub); ...@@ -1110,6 +1117,8 @@ TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul); TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div); TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod); TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(FloorDiv);
TVM_REGISTER_NODE_TYPE(FloorMod);
TVM_REGISTER_NODE_TYPE(Min); TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max); TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ); TVM_REGISTER_NODE_TYPE(EQ);
......
...@@ -302,6 +302,8 @@ class IRDeepCompare : ...@@ -302,6 +302,8 @@ class IRDeepCompare :
DEFINE_BIOP_EXPR_CMP_(Mul) DEFINE_BIOP_EXPR_CMP_(Mul)
DEFINE_BIOP_EXPR_CMP_(Div) DEFINE_BIOP_EXPR_CMP_(Div)
DEFINE_BIOP_EXPR_CMP_(Mod) DEFINE_BIOP_EXPR_CMP_(Mod)
DEFINE_BIOP_EXPR_CMP_(FloorDiv)
DEFINE_BIOP_EXPR_CMP_(FloorMod)
DEFINE_BIOP_EXPR_CMP_(Min) DEFINE_BIOP_EXPR_CMP_(Min)
DEFINE_BIOP_EXPR_CMP_(Max) DEFINE_BIOP_EXPR_CMP_(Max)
DEFINE_BIOP_EXPR_CMP_(EQ) DEFINE_BIOP_EXPR_CMP_(EQ)
......
...@@ -401,6 +401,8 @@ DEFINE_BIOP_EXPR_MUTATE_(Sub) ...@@ -401,6 +401,8 @@ DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul) DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div) DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod) DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(FloorDiv)
DEFINE_BIOP_EXPR_MUTATE_(FloorMod)
DEFINE_BIOP_EXPR_MUTATE_(Min) DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max) DEFINE_BIOP_EXPR_MUTATE_(Max)
DEFINE_BIOP_EXPR_MUTATE_(EQ) DEFINE_BIOP_EXPR_MUTATE_(EQ)
...@@ -506,6 +508,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -506,6 +508,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Mul) .DISPATCH_TO_MUTATE_EXPR(Mul)
.DISPATCH_TO_MUTATE_EXPR(Div) .DISPATCH_TO_MUTATE_EXPR(Div)
.DISPATCH_TO_MUTATE_EXPR(Mod) .DISPATCH_TO_MUTATE_EXPR(Mod)
.DISPATCH_TO_MUTATE_EXPR(FloorDiv)
.DISPATCH_TO_MUTATE_EXPR(FloorMod)
.DISPATCH_TO_MUTATE_EXPR(Min) .DISPATCH_TO_MUTATE_EXPR(Min)
.DISPATCH_TO_MUTATE_EXPR(Max) .DISPATCH_TO_MUTATE_EXPR(Max)
.DISPATCH_TO_MUTATE_EXPR(EQ) .DISPATCH_TO_MUTATE_EXPR(EQ)
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -139,6 +139,8 @@ DEFINE_BINOP_VISIT_(Sub) ...@@ -139,6 +139,8 @@ DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul) DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div) DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod) DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(FloorDiv)
DEFINE_BINOP_VISIT_(FloorMod)
DEFINE_BINOP_VISIT_(Min) DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max) DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ) DEFINE_BINOP_VISIT_(EQ)
...@@ -250,6 +252,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -250,6 +252,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Mul) .DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div) .DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod) .DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(FloorDiv)
.DISPATCH_TO_VISIT(FloorMod)
.DISPATCH_TO_VISIT(Min) .DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max) .DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ) .DISPATCH_TO_VISIT(EQ)
......
...@@ -151,6 +151,12 @@ class Vectorizer : public IRMutator { ...@@ -151,6 +151,12 @@ class Vectorizer : public IRMutator {
Expr Mutate_(const Mod* op, const Expr &e) final { Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e); return BinaryVec(op, e);
} }
Expr Mutate_(const FloorDiv* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const FloorMod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final { Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e); return BinaryVec(op, e);
} }
......
...@@ -28,21 +28,32 @@ class CanonicalChecker: ...@@ -28,21 +28,32 @@ class CanonicalChecker:
def test_mul_sum_simplify(): def test_mul_sum_simplify():
ck = CanonicalChecker() ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(2 + (3 * x + z + y + 1) * 4 + x, ck.verify(2 + (3 * x + z + y + 1) * 4 + x,
x * 13 + z * 4 + y * 4 +6) x * 13 + z * 4 + y * 4 +6)
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
ck.verify(x * 3 - 4 * x + 1, 1 - x) ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2) ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
# trucdiv
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
# floordiv
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2))
ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x)
ck.verify(flm(x + y + x + y * 3, 2), 0)
ck.verify(fld(x + x + y * 3, 2), fld(y * 3, 2) + x)
def test_split_index_simplify(): def test_split_index_simplify():
ck = CanonicalChecker() ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
# trucdiv
# split div const # split div const
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4) ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
ck.verify((x % 2) / 8, 0) ck.verify((x % 2) / 8, 0)
ck.verify((x % 2) / 7, 0) ck.verify((x % 2) / 7, 0)
...@@ -59,11 +70,29 @@ def test_split_index_simplify(): ...@@ -59,11 +70,29 @@ def test_split_index_simplify():
# complex fold # complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y) ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
# floordiv
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 4), fld(flm(x, 16), 4))
ck.verify(fld(flm(x, 2), 8), 0)
ck.verify(fld(flm(x, 2), 7), 0)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
# cannot simplify mixed case, unless we canonicalize into one mode.
ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2)
def test_div_simplify(): def test_div_simplify():
ck = CanonicalChecker() ck = CanonicalChecker()
x = tvm.var("x") x = tvm.var("x")
# truc div
ck.verify((16+48*x)/16, x*3 + 1) ck.verify((16+48*x)/16, x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0 # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x # (17+48*x)/16 != 1+3*x
...@@ -74,6 +103,22 @@ def test_div_simplify(): ...@@ -74,6 +103,22 @@ def test_div_simplify():
# Trying expressions that are not simplifiable for any values of the variables # Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16) ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
# floordiv
fld = tvm.floordiv
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
ck.verify(fld(16+48*x, 16), x*3 + 1)
ck.verify(fld(17+48*x, 16), x * 3 + 1)
ck.verify(fld(17+47*x, 16), fld(x * 47 + 17, 16))
def test_floormod_simplify():
ck = CanonicalChecker()
flm = tvm.floormod
x, y = tvm.var("x"), tvm.var("y")
ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512, 16),
flm((x*4) + y + 12, 16))
def test_canonical_mixed(): def test_canonical_mixed():
ck = CanonicalChecker() ck = CanonicalChecker()
...@@ -86,6 +131,10 @@ def test_canonical_mixed(): ...@@ -86,6 +131,10 @@ def test_canonical_mixed():
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0) ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
ck.verify(x * x - x * x, 0) ck.verify(x * x - x * x, 0)
fld = tvm.floordiv
ck.verify(fld(x, (z*z)) - fld(x, (z*z)), 0)
ck.verify(fld(x, (z+z)) - fld(x, (z+z)), 0)
def test_reduce_combiner_simplify(): def test_reduce_combiner_simplify():
ck = CanonicalChecker() ck = CanonicalChecker()
...@@ -218,11 +267,13 @@ def test_complex_cases(): ...@@ -218,11 +267,13 @@ def test_complex_cases():
if __name__ == "__main__": if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
test_simplify_if_then_else() test_simplify_if_then_else()
test_div_simplify() test_div_simplify()
test_reduce_simplify() test_reduce_simplify()
test_reduce_combiner_simplify() test_reduce_combiner_simplify()
test_mul_sum_simplify()
test_split_index_simplify() test_split_index_simplify()
test_canonical_mixed() test_canonical_mixed()
test_complex_cases() test_complex_cases()
...@@ -143,6 +143,52 @@ def test_mod_bound(): ...@@ -143,6 +143,52 @@ def test_mod_bound():
assert bd.max_value == 9 assert bd.max_value == 9
def test_floordiv_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
fld = tvm.floordiv
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -9 // 4
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -4
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF
def test_floormod_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
flm = tvm.floormod
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
def test_min_max_bound(): def test_min_max_bound():
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
...@@ -229,6 +275,8 @@ if __name__ == "__main__": ...@@ -229,6 +275,8 @@ if __name__ == "__main__":
test_mul_bound() test_mul_bound()
test_div_bound() test_div_bound()
test_mod_bound() test_mod_bound()
test_floordiv_bound()
test_floormod_bound()
test_min_max_bound() test_min_max_bound()
test_select_bound() test_select_bound()
test_shift_and_bound() test_shift_and_bound()
......
...@@ -64,9 +64,14 @@ def test_mul_div(): ...@@ -64,9 +64,14 @@ def test_mul_div():
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20)) ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2)) ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y)) ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5)) ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.floordiv
ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
ck.verify(fld(x, 2), {x : tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
def test_mod(): def test_mod():
ck = IntSetChecker() ck = IntSetChecker()
...@@ -75,6 +80,10 @@ def test_mod(): ...@@ -75,6 +80,10 @@ def test_mod():
ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9)) ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
def test_max_min(): def test_max_min():
ck = IntSetChecker() ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
...@@ -99,4 +108,3 @@ if __name__ == "__main__": ...@@ -99,4 +108,3 @@ if __name__ == "__main__":
test_max_min() test_max_min()
test_select() test_select()
test_mod() test_mod()
...@@ -61,6 +61,10 @@ def test_div_shift(): ...@@ -61,6 +61,10 @@ def test_div_shift():
m = analyzer.modular_set((x * 4 + 2) >> 1) m = analyzer.modular_set((x * 4 + 2) >> 1)
assert m.coeff == 2 assert m.coeff == 2
assert m.base == 1 assert m.base == 1
fld = tvm.floordiv
m = analyzer.modular_set(fld(x * 4 + 2, 2))
assert m.coeff == 2
assert m.base == 1
# x is non-negative # x is non-negative
analyzer.update(x, tvm.arith.ConstIntBound(0, 100)) analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
m = analyzer.modular_set((x * 4 + 2) / 2) m = analyzer.modular_set((x * 4 + 2) / 2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment