Unverified Commit c59a78e5 by Haichen Shen Committed by GitHub

[TVM][LANG] Add eager simplification for operations with FloatImm (#2615)

* Add eager simplication for FloatImm

* fix

* fix lint

* Fix gcc warning

* fix

* Add test case
parent 255c187b
...@@ -430,6 +430,34 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis); ...@@ -430,6 +430,34 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/ */
TVM_DLL Expr prod(Expr source, Array<IterVar> axis); TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
/*!
* \brief Calculate floor(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr floor(Expr x);
/*!
* \brief Calculate ceil(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr ceil(Expr x);
/*!
* \brief Calculate round(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr round(Expr x);
/*!
* \brief Calculate trunc(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr trunc(Expr x);
// Intrinsic operators // Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \ #define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \ inline Expr OpName(Expr x) { \
...@@ -441,10 +469,6 @@ TVM_DECLARE_INTRIN_UNARY(tanh); ...@@ -441,10 +469,6 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(popcount);
......
...@@ -94,4 +94,4 @@ def cast(src, dtype): ...@@ -94,4 +94,4 @@ def cast(src, dtype):
op : tvm.Expr op : tvm.Expr
The result Expr of divide operaton. The result Expr of divide operaton.
""" """
return _make.static_cast(dtype, src) return _make._cast(dtype, src)
...@@ -272,7 +272,7 @@ def floor(x): ...@@ -272,7 +272,7 @@ def floor(x):
y : Expr y : Expr
The result. The result.
""" """
return call_pure_intrin(x.dtype, "floor", x) return _make.floor(x)
def ceil(x): def ceil(x):
...@@ -288,7 +288,7 @@ def ceil(x): ...@@ -288,7 +288,7 @@ def ceil(x):
y : Expr y : Expr
The result. The result.
""" """
return call_pure_intrin(x.dtype, "ceil", x) return _make.ceil(x)
def trunc(x): def trunc(x):
...@@ -307,7 +307,7 @@ def trunc(x): ...@@ -307,7 +307,7 @@ def trunc(x):
y : Expr y : Expr
The result. The result.
""" """
return call_pure_intrin(x.dtype, "trunc", x) return _make.trunc(x)
def abs(x): def abs(x):
...@@ -339,7 +339,7 @@ def round(x): ...@@ -339,7 +339,7 @@ def round(x):
y : Expr y : Expr
The result. The result.
""" """
return call_pure_intrin(x.dtype, "round", x) return _make.round(x)
def power(x, y): def power(x, y):
......
...@@ -22,6 +22,31 @@ TVM_REGISTER_API("make.abs") ...@@ -22,6 +22,31 @@ TVM_REGISTER_API("make.abs")
*ret = tvm::abs(args[0]); *ret = tvm::abs(args[0]);
}); });
TVM_REGISTER_API("make.floor")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::floor(args[0]);
});
TVM_REGISTER_API("make.ceil")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::ceil(args[0]);
});
TVM_REGISTER_API("make.round")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::round(args[0]);
});
TVM_REGISTER_API("make.trunc")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::trunc(args[0]);
});
TVM_REGISTER_API("make._cast")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::cast(args[0], args[1]);
});
TVM_REGISTER_API("make._range_by_min_extent") TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]); *ret = Range::make_by_min_extent(args[0], args[1]);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
#include <cmath>
namespace tvm { namespace tvm {
...@@ -49,17 +50,17 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) ...@@ -49,17 +50,17 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
// and also help user to find potential type conversion problems. // and also help user to find potential type conversion problems.
if (!lhs.type().is_float() && rhs.type().is_float()) { if (!lhs.type().is_float() && rhs.type().is_float()) {
// int->float // int->float
lhs = ir::Cast::make(rhs.type(), lhs); lhs = cast(rhs.type(), lhs);
} else if (lhs.type().is_float() && !rhs.type().is_float()) { } else if (lhs.type().is_float() && !rhs.type().is_float()) {
// int->float // int->float
rhs = ir::Cast::make(lhs.type(), rhs); rhs = cast(lhs.type(), rhs);
} else if ((lhs.type().is_int() && rhs.type().is_int()) || } else if ((lhs.type().is_int() && rhs.type().is_int()) ||
(lhs.type().is_uint() && rhs.type().is_uint())) { (lhs.type().is_uint() && rhs.type().is_uint())) {
// promote int to higher bits // promote int to higher bits
if (lhs.type().bits() < rhs.type().bits()) { if (lhs.type().bits() < rhs.type().bits()) {
lhs = ir::Cast::make(rhs.type(), lhs); lhs = cast(rhs.type(), lhs);
} else { } else {
rhs = ir::Cast::make(lhs.type(), rhs); rhs = cast(lhs.type(), rhs);
} }
} else if ((lhs.type().is_int() && rhs.type().is_uint()) || } else if ((lhs.type().is_int() && rhs.type().is_uint()) ||
(lhs.type().is_uint() && rhs.type().is_int())) { (lhs.type().is_uint() && rhs.type().is_int())) {
...@@ -98,11 +99,14 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { ...@@ -98,11 +99,14 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
Expr cast(const Type& t, Expr value) { Expr cast(const Type& t, Expr value) {
using ir::IntImm; using ir::IntImm;
using ir::FloatImm;
if (value.type() == t) return value; if (value.type() == t) return value;
// const fold IntImm as they are used in index computations // const fold IntImm as they are used in index computations
if (t.lanes() == 1) { if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) { if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value); return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
return make_const(t, op->value);
} }
return ir::Cast::make(t, value); return ir::Cast::make(t, value);
} else { } else {
...@@ -112,6 +116,8 @@ Expr cast(const Type& t, Expr value) { ...@@ -112,6 +116,8 @@ Expr cast(const Type& t, Expr value) {
if (value.type() != vtype) { if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) { if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value); value = make_const(vtype, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
value = make_const(vtype, op->value);
} else { } else {
value = ir::Cast::make(vtype, value); value = ir::Cast::make(vtype, value);
} }
...@@ -129,7 +135,7 @@ Expr reinterpret(const Type& t, Expr value) { ...@@ -129,7 +135,7 @@ Expr reinterpret(const Type& t, Expr value) {
return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
} }
#define TVM_CONST_PROPAGATION(BODY) \ #define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImm; \ using ir::IntImm; \
using ir::UIntImm; \ using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \ const IntImm* pa = a.as<IntImm>(); \
...@@ -141,37 +147,60 @@ Expr reinterpret(const Type& t, Expr value) { ...@@ -141,37 +147,60 @@ Expr reinterpret(const Type& t, Expr value) {
} \ } \
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
using ir::FloatImm; \
BinaryOpMatchTypes(a, b); \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const FloatImm* fa = a.as<FloatImm>(); \
const FloatImm* fb = b.as<FloatImm>(); \
BODY;
Expr operator+(Expr a, Expr b) { Expr operator+(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return SimpleCast(rtype, b); if (pa && pa->value == 0) return SimpleCast(rtype, b);
if (pb && pb->value == 0) return SimpleCast(rtype, a); if (pb && pb->value == 0) return SimpleCast(rtype, a);
if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return SimpleCast(rtype, b);
if (fb && fb->value == 0) return SimpleCast(rtype, a);
}); });
return ir::Add::make(a, b); return ir::Add::make(a, b);
} }
Expr operator-(Expr a) { Expr operator-(Expr a) {
using ir::IntImm; using ir::IntImm;
using ir::FloatImm;
const IntImm* pa = a.as<IntImm>(); const IntImm* pa = a.as<IntImm>();
if (pa) { const FloatImm* fa = a.as<FloatImm>();
return ir::IntImm::make(a.type(), -pa->value); if (pa) return ir::IntImm::make(a.type(), -pa->value);
} if (fa) return ir::FloatImm::make(a.type(), -fa->value);
return make_zero(a.type()) - a; return make_zero(a.type()) - a;
} }
Expr operator-(Expr a, Expr b) { Expr operator-(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return SimpleCast(rtype, a); if (pb && pb->value == 0) return SimpleCast(rtype, a);
if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return SimpleCast(rtype, a);
}); });
return ir::Sub::make(a, b); return ir::Sub::make(a, b);
} }
Expr operator*(Expr a, Expr b) { Expr operator*(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) { if (pa) {
...@@ -182,12 +211,23 @@ Expr operator*(Expr a, Expr b) { ...@@ -182,12 +211,23 @@ Expr operator*(Expr a, Expr b) {
if (pb->value == 1) return SimpleCast(rtype, a); if (pb->value == 1) return SimpleCast(rtype, a);
if (pb->value == 0) return SimpleCast(rtype, b); if (pb->value == 0) return SimpleCast(rtype, b);
} }
if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
if (fa) {
if (fa->value == 1) return SimpleCast(rtype, b);
if (fa->value == 0) return SimpleCast(rtype, a);
}
if (fb) {
if (fb->value == 1) return SimpleCast(rtype, a);
if (fb->value == 0) return SimpleCast(rtype, b);
}
}); });
return ir::Mul::make(a, b); return ir::Mul::make(a, b);
} }
Expr operator/(Expr a, Expr b) { Expr operator/(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes // due to division and mod can have different modes
// only constant fold positive number where rule is fixed. // only constant fold positive number where rule is fixed.
...@@ -201,12 +241,22 @@ Expr operator/(Expr a, Expr b) { ...@@ -201,12 +241,22 @@ Expr operator/(Expr a, Expr b) {
if (pb->value == 1) return SimpleCast(rtype, a); if (pb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(pb->value, 0) << "Divide by zero"; CHECK_NE(pb->value, 0) << "Divide by zero";
} }
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, fa->value / fb->value);
}
if (fa && fa->value == 0) {
return SimpleCast(rtype, a);
}
if (fb) {
if (fb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(fb->value, 0) << "Divide by zero";
}
}); });
return ir::Div::make(a, b); return ir::Div::make(a, b);
} }
Expr operator%(Expr a, Expr b) { Expr operator%(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes // due to division and mod can have different modes
// only constant fold positive number where rule is fixed. // only constant fold positive number where rule is fixed.
...@@ -225,17 +275,23 @@ Expr operator%(Expr a, Expr b) { ...@@ -225,17 +275,23 @@ Expr operator%(Expr a, Expr b) {
} }
Expr min(Expr a, Expr b) { Expr min(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
}); });
return ir::Min::make(a, b); return ir::Min::make(a, b);
} }
Expr max(Expr a, Expr b) { Expr max(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
}); });
return ir::Max::make(a, b); return ir::Max::make(a, b);
} }
...@@ -272,43 +328,49 @@ Expr likely(Expr cond) { ...@@ -272,43 +328,49 @@ Expr likely(Expr cond) {
} }
Expr operator>(Expr a, Expr b) { Expr operator>(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value);
}); });
return ir::GT::make(a, b); return ir::GT::make(a, b);
} }
Expr operator>=(Expr a, Expr b) { Expr operator>=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value);
}); });
return ir::GE::make(a, b); return ir::GE::make(a, b);
} }
Expr operator<(Expr a, Expr b) { Expr operator<(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value);
}); });
return ir::LT::make(a, b); return ir::LT::make(a, b);
} }
Expr operator<=(Expr a, Expr b) { Expr operator<=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value);
}); });
return ir::LE::make(a, b); return ir::LE::make(a, b);
} }
Expr operator==(Expr a, Expr b) { Expr operator==(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value);
}); });
return ir::EQ::make(a, b); return ir::EQ::make(a, b);
} }
Expr operator!=(Expr a, Expr b) { Expr operator!=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value);
}); });
return ir::NE::make(a, b); return ir::NE::make(a, b);
} }
...@@ -349,7 +411,7 @@ Expr operator!(Expr a) { ...@@ -349,7 +411,7 @@ Expr operator!(Expr a) {
} }
Expr operator>>(Expr a, Expr b) { Expr operator>>(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
if (pb) { if (pb) {
...@@ -360,7 +422,7 @@ Expr operator>>(Expr a, Expr b) { ...@@ -360,7 +422,7 @@ Expr operator>>(Expr a, Expr b) {
} }
Expr operator<<(Expr a, Expr b) { Expr operator<<(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
if (pb) { if (pb) {
...@@ -371,7 +433,7 @@ Expr operator<<(Expr a, Expr b) { ...@@ -371,7 +433,7 @@ Expr operator<<(Expr a, Expr b) {
} }
Expr operator&(Expr a, Expr b) { Expr operator&(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
}); });
...@@ -379,7 +441,7 @@ Expr operator&(Expr a, Expr b) { ...@@ -379,7 +441,7 @@ Expr operator&(Expr a, Expr b) {
} }
Expr operator|(Expr a, Expr b) { Expr operator|(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
}); });
...@@ -387,7 +449,7 @@ Expr operator|(Expr a, Expr b) { ...@@ -387,7 +449,7 @@ Expr operator|(Expr a, Expr b) {
} }
Expr operator^(Expr a, Expr b) { Expr operator^(Expr a, Expr b) {
TVM_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
}); });
...@@ -414,6 +476,11 @@ Expr abs(Expr x) { ...@@ -414,6 +476,11 @@ Expr abs(Expr x) {
} }
return ir::Select::make(x >= make_zero(x.type()), x, -x); return ir::Select::make(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) { } else if (x.type().is_float()) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return ir::FloatImm::make(x.type(), std::fabs(fx->value));
}
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic); return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) { } else if (x.type().is_uint()) {
return x; return x;
...@@ -466,4 +533,35 @@ Expr fmod(Expr x, Expr y) { ...@@ -466,4 +533,35 @@ Expr fmod(Expr x, Expr y) {
return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic); return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
} }
Expr floor(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) return FloatImm::make(x.type(), std::floor(fx->value));
return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic);
}
Expr ceil(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) return FloatImm::make(x.type(), std::ceil(fx->value));
return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic);
}
Expr round(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value));
return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
}
Expr trunc(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return FloatImm::make(x.type(), (fx->value < 0 ? std::ceil(fx->value) :
std::floor(fx->value)));
}
return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic);
}
} // namespace tvm } // namespace tvm
...@@ -100,6 +100,22 @@ def test_modular(): ...@@ -100,6 +100,22 @@ def test_modular():
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
def test_const_propagation():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.5
assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
if __name__ == "__main__": if __name__ == "__main__":
test_simplify_div() test_simplify_div()
test_simplify_mod() test_simplify_mod()
...@@ -107,3 +123,4 @@ if __name__ == "__main__": ...@@ -107,3 +123,4 @@ if __name__ == "__main__":
test_simplify() test_simplify()
test_mul() test_mul()
test_simplify_minmax() test_simplify_minmax()
test_const_propagation()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment