Unverified Commit 75892d2b by Tianqi Chen Committed by GitHub

[ARITH][IR] Introduce FloorDiv/Mod (#3479)

* [ARITH][IR] Introduce FloorDiv/Mod

* Address review comments

* address review comments, fix div sub rule
parent 45878ff2
......@@ -332,6 +332,26 @@ TVM_DLL Expr operator||(Expr a, Expr b);
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute floor(a / b)
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr floordiv(Expr a, Expr b);
/*!
* \brief compute the remainder of floordiv
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr floormod(Expr a, Expr b);
/*!
* \brief take maximum of two values
*
* \param a left operand
......
......@@ -178,6 +178,18 @@ class Mod : public BinaryOpNode<Mod> {
static constexpr const char* _type_key = "Mod";
};
/*! \brief Floor division, floor(a/b) */
class FloorDiv : public BinaryOpNode<FloorDiv> {
public:
static constexpr const char* _type_key = "FloorDiv";
};
/*! \brief The remainder of the floordiv */
class FloorMod : public BinaryOpNode<FloorMod> {
public:
static constexpr const char* _type_key = "FloorMod";
};
/*! \brief min(a, b) */
class Min : public BinaryOpNode<Min> {
public:
......
......@@ -140,6 +140,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
......@@ -180,6 +182,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(Mul);
IR_EXPR_FUNCTOR_DISPATCH(Div);
IR_EXPR_FUNCTOR_DISPATCH(Mod);
IR_EXPR_FUNCTOR_DISPATCH(FloorDiv);
IR_EXPR_FUNCTOR_DISPATCH(FloorMod);
IR_EXPR_FUNCTOR_DISPATCH(Min);
IR_EXPR_FUNCTOR_DISPATCH(Max);
IR_EXPR_FUNCTOR_DISPATCH(EQ);
......
......@@ -98,6 +98,8 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const Mul* op, const Expr& e);
virtual Expr Mutate_(const Div* op, const Expr& e);
virtual Expr Mutate_(const Mod* op, const Expr& e);
virtual Expr Mutate_(const FloorDiv* op, const Expr& e);
virtual Expr Mutate_(const FloorMod* op, const Expr& e);
virtual Expr Mutate_(const Min* op, const Expr& e);
virtual Expr Mutate_(const Max* op, const Expr& e);
virtual Expr Mutate_(const EQ* op, const Expr& e);
......
......@@ -114,6 +114,8 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const FloorDiv* op);
virtual void Visit_(const FloorMod* op);
virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op);
......
......@@ -888,7 +888,46 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer
def floordiv(a, b):
"""Compute the floordiv of two expressions.
Parameters
----------
a : Expr
The left hand operand
b : Expr
The right hand operand
Returns
-------
res : Expr
The result expression.
"""
return _make._OpFloorDiv(a, b)
def floormod(a, b):
"""Compute the floormod of two expressions.
Parameters
----------
a : Expr
The left hand operand
b : Expr
The right hand operand
Returns
-------
res : Expr
The result expression.
"""
return _make._OpFloorMod(a, b)
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
......
......@@ -463,6 +463,40 @@ class Mod(BinaryOpExpr):
@register_node
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.FloorDiv, a, b)
@register_node
class FloorMod(BinaryOpExpr):
"""FloorMod node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.FloorMod, a, b)
@register_node
class Min(BinaryOpExpr):
"""Min node.
......
......@@ -126,6 +126,8 @@ REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul);
REGISTER_MAKE(Div);
REGISTER_MAKE(Mod);
REGISTER_MAKE(FloorDiv);
REGISTER_MAKE(FloorMod);
REGISTER_MAKE(Min);
REGISTER_MAKE(Max);
REGISTER_MAKE(EQ);
......@@ -192,6 +194,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
......
......@@ -29,6 +29,7 @@
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <algorithm>
#include "int_operator.h"
namespace tvm {
namespace arith {
......@@ -184,9 +185,7 @@ template<>
inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
if (pa && pb) {
return IntImm::make(rtype, pa->value % pb->value);
}
if (pa) {
......@@ -201,6 +200,51 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
}
template<>
inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, std::floor(fa->value / fb->value));
}
if (fa && fa->value == 0) return a;
if (fb) {
if (fb->value == 1) return a;
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) {
return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
......
......@@ -24,7 +24,7 @@
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <algorithm>
#include "int_op_overflow.h"
#include "int_operator.h"
#include "pattern_match.h"
namespace tvm {
......@@ -215,6 +215,37 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitExpr_(const FloorDiv* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
// assume no division by 0
if (b.min_value == 0) b.min_value = 1;
if (b.max_value == 0) b.max_value = -1;
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
}
Entry VisitExpr_(const FloorMod* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
if (a.min_value >= 0) {
// 0 <= [a_min, a_max] < b_min
if (a.max_value < b.min_value) return a;
// other case, we can get close to 0
return MakeBound(0, std::min(a.max_value, b_max_cap));
} else {
return MakeBound(0, b_max_cap);
}
} else {
CHECK(!b.is_const(0)) << "floormod by zero";
// mod by negative value is rare,
// and we just use the simpliest rule.
return Everything(op->type);
}
}
Entry VisitExpr_(const Min* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
......@@ -376,6 +407,20 @@ class ConstIntBoundAnalyzer::Impl :
return x / y;
}
/*!
* \brief Compute floodiv(x, y), aware of inf.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
static int64_t InfAwareFloorDiv(int64_t x, int64_t y) {
CHECK_NE(y, 0);
if (x == kPosInf || x == kNegInf) {
if (y > 0) return x;
return -x;
}
return floordiv(x, y);
}
/*!
* \brief Compute x / y, aware of inf.
* \param x The left operand.
* \param y The right operand.
......
......@@ -19,11 +19,11 @@
/*!
* Copyright (c) 2019 by Contributors
* \file int_op_overflow.h
* \brief Utility functions to detect if an integer op will overflow.
* \file int_operator.h
* \brief Additional useful operators for integer.
*/
#ifndef TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#define TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#ifndef TVM_ARITHMETIC_INT_OPERATOR_H_
#define TVM_ARITHMETIC_INT_OPERATOR_H_
#include <limits>
......@@ -48,30 +48,30 @@ inline bool WillOverflow(int64_t x,
}
template<>
bool WillOverflow<ir::Add>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Add>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
if ((y > 0) && (x > max_value - y)) return true;
if ((y < 0) && (x < min_value - y)) return true;
return false;
}
template<>
bool WillOverflow<ir::Sub>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Sub>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
if ((y > 0) && (x < min_value + y)) return true;
if ((y < 0) && (x > max_value + y)) return true;
return false;
}
template<>
bool WillOverflow<ir::Mul>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Mul>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
if (y == 0) return false;
if (y > 0) {
if (x < min_value / y) return true;
......@@ -85,13 +85,42 @@ bool WillOverflow<ir::Mul>(int64_t x,
}
template<>
bool WillOverflow<ir::Mod>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
inline bool WillOverflow<ir::Mod>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
return y == 0;
}
/*!
* \brief Peform floor division of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t floordiv(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x / y) : (x / y - 1);
}
/*!
* \brief Compute the floordiv remainder of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t floormod(int64_t x, int64_t y) {
bool round_down =
(x >= 0 && y >= 0) ||
(x <= 0 && y <= 0) ||
(x % y == 0);
return round_down ? (x % y) : (x % y + y);
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_INT_OP_OVERFLOW_H_
#endif // TVM_ARITHMETIC_INT_OPERATOR_H_
......@@ -252,6 +252,68 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
if (is_zero(b->min_value)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
Expr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
Expr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
Expr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
Expr e1 = floordiv(a->min_value, b->min_value);
Expr e2 = floordiv(a->max_value, b->min_value);
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Div";
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
const Expr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
return IntervalSet(make_zero(divisor.type()), divisor - 1);
} else {
Expr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Mod";
return IntervalSet::Everything();
}
template<>
inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
IntervalSet a,
......@@ -361,6 +423,14 @@ class IntervalSetEvaluator :
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const FloorDiv* op) final {
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const FloorMod* op) final {
return VisitBinaryExpr_(op);
}
IntervalSet VisitExpr_(const Min* op) final {
return VisitBinaryExpr_(op);
}
......
......@@ -179,6 +179,14 @@ class ModularSetAnalyzer::Impl :
return Everything();
}
Entry VisitExpr_(const FloorDiv* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, true);
}
return Everything();
}
Entry VisitExpr_(const Min* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -326,6 +326,8 @@ TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
// logical expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GT);
......
......@@ -53,6 +53,8 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Mul* op, const Expr& self) override;
Expr Mutate_(const Div* op, const Expr& self) override;
Expr Mutate_(const Mod* op, const Expr& self) override;
Expr Mutate_(const FloorDiv* op, const Expr& self) override;
Expr Mutate_(const FloorMod* op, const Expr& self) override;
Expr Mutate_(const Min* op, const Expr& self) override;
Expr Mutate_(const Max* op, const Expr& self) override;
Expr Mutate_(const EQ* op, const Expr& self) override;
......
......@@ -188,6 +188,19 @@ Expr operator%(Expr a, Expr b) {
return ir::Mod::make(a, b);
}
Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
if (ret.defined()) return ret;
return ir::FloorDiv::make(a, b);
}
Expr floormod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
if (ret.defined()) return ret;
return ir::FloorMod::make(a, b);
}
Expr min(Expr a, Expr b) {
// inf-aware simplificaiton
......
......@@ -59,14 +59,12 @@ Expr StringImm::make(std::string value) {
Expr Cast::make(DataType t, Expr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.type().lanes());
NodePtr<Cast> node = make_node<Cast>();
node->type = t;
node->value = std::move(value);
return Expr(node);
}
Expr And::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
......@@ -700,6 +698,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorDiv>([](const FloorDiv* op, IRPrinter *p) {
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloorMod>([](const FloorMod* op, IRPrinter *p) {
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<And>([](const And* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
......@@ -1098,7 +1106,6 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm);
......@@ -1110,6 +1117,8 @@ TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(FloorDiv);
TVM_REGISTER_NODE_TYPE(FloorMod);
TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ);
......
......@@ -302,6 +302,8 @@ class IRDeepCompare :
DEFINE_BIOP_EXPR_CMP_(Mul)
DEFINE_BIOP_EXPR_CMP_(Div)
DEFINE_BIOP_EXPR_CMP_(Mod)
DEFINE_BIOP_EXPR_CMP_(FloorDiv)
DEFINE_BIOP_EXPR_CMP_(FloorMod)
DEFINE_BIOP_EXPR_CMP_(Min)
DEFINE_BIOP_EXPR_CMP_(Max)
DEFINE_BIOP_EXPR_CMP_(EQ)
......
......@@ -401,6 +401,8 @@ DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(FloorDiv)
DEFINE_BIOP_EXPR_MUTATE_(FloorMod)
DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max)
DEFINE_BIOP_EXPR_MUTATE_(EQ)
......@@ -506,6 +508,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Mul)
.DISPATCH_TO_MUTATE_EXPR(Div)
.DISPATCH_TO_MUTATE_EXPR(Mod)
.DISPATCH_TO_MUTATE_EXPR(FloorDiv)
.DISPATCH_TO_MUTATE_EXPR(FloorMod)
.DISPATCH_TO_MUTATE_EXPR(Min)
.DISPATCH_TO_MUTATE_EXPR(Max)
.DISPATCH_TO_MUTATE_EXPR(EQ)
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -139,6 +139,8 @@ DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(FloorDiv)
DEFINE_BINOP_VISIT_(FloorMod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
......@@ -250,6 +252,8 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(FloorDiv)
.DISPATCH_TO_VISIT(FloorMod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
......
......@@ -151,6 +151,12 @@ class Vectorizer : public IRMutator {
Expr Mutate_(const Mod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const FloorDiv* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const FloorMod* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Min* op, const Expr &e) final {
return BinaryVec(op, e);
}
......
......@@ -28,21 +28,32 @@ class CanonicalChecker:
def test_mul_sum_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(2 + (3 * x + z + y + 1) * 4 + x,
x * 13 + z * 4 + y * 4 +6)
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
# trucdiv
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
# floordiv
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2))
ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x)
ck.verify(flm(x + y + x + y * 3, 2), 0)
ck.verify(fld(x + x + y * 3, 2), fld(y * 3, 2) + x)
def test_split_index_simplify():
ck = CanonicalChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
# trucdiv
# split div const
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
ck.verify((x % 2) / 8, 0)
ck.verify((x % 2) / 7, 0)
......@@ -59,11 +70,29 @@ def test_split_index_simplify():
# complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
# floordiv
fld = tvm.floordiv
flm = tvm.floormod
ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 4), fld(flm(x, 16), 4))
ck.verify(fld(flm(x, 2), 8), 0)
ck.verify(fld(flm(x, 2), 7), 0)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
# cannot simplify mixed case, unless we canonicalize into one mode.
ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2)
def test_div_simplify():
ck = CanonicalChecker()
x = tvm.var("x")
# truc div
ck.verify((16+48*x)/16, x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
......@@ -74,6 +103,22 @@ def test_div_simplify():
# Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
# floordiv
fld = tvm.floordiv
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
ck.verify(fld(16+48*x, 16), x*3 + 1)
ck.verify(fld(17+48*x, 16), x * 3 + 1)
ck.verify(fld(17+47*x, 16), fld(x * 47 + 17, 16))
def test_floormod_simplify():
ck = CanonicalChecker()
flm = tvm.floormod
x, y = tvm.var("x"), tvm.var("y")
ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512, 16),
flm((x*4) + y + 12, 16))
def test_canonical_mixed():
ck = CanonicalChecker()
......@@ -86,6 +131,10 @@ def test_canonical_mixed():
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
ck.verify(x * x - x * x, 0)
fld = tvm.floordiv
ck.verify(fld(x, (z*z)) - fld(x, (z*z)), 0)
ck.verify(fld(x, (z+z)) - fld(x, (z+z)), 0)
def test_reduce_combiner_simplify():
ck = CanonicalChecker()
......@@ -218,11 +267,13 @@ def test_complex_cases():
if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
test_simplify_if_then_else()
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()
test_mul_sum_simplify()
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
......@@ -143,6 +143,52 @@ def test_mod_bound():
assert bd.max_value == 9
def test_floordiv_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
fld = tvm.floordiv
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -9 // 4
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == -4
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
bd = analyzer.const_int_bound(fld(x, y))
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF
def test_floormod_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
flm = tvm.floormod
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(flm(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
def test_min_max_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
......@@ -229,6 +275,8 @@ if __name__ == "__main__":
test_mul_bound()
test_div_bound()
test_mod_bound()
test_floordiv_bound()
test_floormod_bound()
test_min_max_bound()
test_select_bound()
test_shift_and_bound()
......
......@@ -64,9 +64,14 @@ def test_mul_div():
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.floordiv
ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
ck.verify(fld(x, 2), {x : tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
def test_mod():
ck = IntSetChecker()
......@@ -75,6 +80,10 @@ def test_mod():
ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
def test_max_min():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
......@@ -99,4 +108,3 @@ if __name__ == "__main__":
test_max_min()
test_select()
test_mod()
......@@ -61,6 +61,10 @@ def test_div_shift():
m = analyzer.modular_set((x * 4 + 2) >> 1)
assert m.coeff == 2
assert m.base == 1
fld = tvm.floordiv
m = analyzer.modular_set(fld(x * 4 + 2, 2))
assert m.coeff == 2
assert m.base == 1
# x is non-negative
analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
m = analyzer.modular_set((x * 4 + 2) / 2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment