Unverified Commit c59a78e5 by Haichen Shen Committed by GitHub

[TVM][LANG] Add eager simplification for operations with FloatImm (#2615)

* Add eager simplication for FloatImm

* fix

* fix lint

* Fix gcc warning

* fix

* Add test case
parent 255c187b
......@@ -430,6 +430,34 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
/*!
* \brief Calculate floor(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr floor(Expr x);
/*!
* \brief Calculate ceil(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr ceil(Expr x);
/*!
* \brief Calculate round(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr round(Expr x);
/*!
* \brief Calculate trunc(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
......@@ -441,10 +469,6 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount);
......
......@@ -94,4 +94,4 @@ def cast(src, dtype):
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.static_cast(dtype, src)
return _make._cast(dtype, src)
......@@ -272,7 +272,7 @@ def floor(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "floor", x)
return _make.floor(x)
def ceil(x):
......@@ -288,7 +288,7 @@ def ceil(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "ceil", x)
return _make.ceil(x)
def trunc(x):
......@@ -307,7 +307,7 @@ def trunc(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "trunc", x)
return _make.trunc(x)
def abs(x):
......@@ -339,7 +339,7 @@ def round(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "round", x)
return _make.round(x)
def power(x, y):
......
......@@ -22,6 +22,31 @@ TVM_REGISTER_API("make.abs")
*ret = tvm::abs(args[0]);
});
TVM_REGISTER_API("make.floor")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::floor(args[0]);
});
TVM_REGISTER_API("make.ceil")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::ceil(args[0]);
});
TVM_REGISTER_API("make.round")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::round(args[0]);
});
TVM_REGISTER_API("make.trunc")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::trunc(args[0]);
});
TVM_REGISTER_API("make._cast")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::cast(args[0], args[1]);
});
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
......
......@@ -5,6 +5,7 @@
#include <tvm/base.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <cmath>
namespace tvm {
......@@ -49,17 +50,17 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
// and also help user to find potential type conversion problems.
if (!lhs.type().is_float() && rhs.type().is_float()) {
// int->float
lhs = ir::Cast::make(rhs.type(), lhs);
lhs = cast(rhs.type(), lhs);
} else if (lhs.type().is_float() && !rhs.type().is_float()) {
// int->float
rhs = ir::Cast::make(lhs.type(), rhs);
rhs = cast(lhs.type(), rhs);
} else if ((lhs.type().is_int() && rhs.type().is_int()) ||
(lhs.type().is_uint() && rhs.type().is_uint())) {
// promote int to higher bits
if (lhs.type().bits() < rhs.type().bits()) {
lhs = ir::Cast::make(rhs.type(), lhs);
lhs = cast(rhs.type(), lhs);
} else {
rhs = ir::Cast::make(lhs.type(), rhs);
rhs = cast(lhs.type(), rhs);
}
} else if ((lhs.type().is_int() && rhs.type().is_uint()) ||
(lhs.type().is_uint() && rhs.type().is_int())) {
......@@ -98,11 +99,14 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
Expr cast(const Type& t, Expr value) {
using ir::IntImm;
using ir::FloatImm;
if (value.type() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
return make_const(t, op->value);
}
return ir::Cast::make(t, value);
} else {
......@@ -112,6 +116,8 @@ Expr cast(const Type& t, Expr value) {
if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
value = make_const(vtype, op->value);
} else {
value = ir::Cast::make(vtype, value);
}
......@@ -129,7 +135,7 @@ Expr reinterpret(const Type& t, Expr value) {
return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
}
#define TVM_CONST_PROPAGATION(BODY) \
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
......@@ -141,37 +147,60 @@ Expr reinterpret(const Type& t, Expr value) {
} \
BinaryOpMatchTypes(a, b);
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
using ir::FloatImm; \
BinaryOpMatchTypes(a, b); \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const FloatImm* fa = a.as<FloatImm>(); \
const FloatImm* fb = b.as<FloatImm>(); \
BODY;
Expr operator+(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return SimpleCast(rtype, b);
if (pb && pb->value == 0) return SimpleCast(rtype, a);
if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return SimpleCast(rtype, b);
if (fb && fb->value == 0) return SimpleCast(rtype, a);
});
return ir::Add::make(a, b);
}
Expr operator-(Expr a) {
using ir::IntImm;
using ir::FloatImm;
const IntImm* pa = a.as<IntImm>();
if (pa) {
return ir::IntImm::make(a.type(), -pa->value);
}
const FloatImm* fa = a.as<FloatImm>();
if (pa) return ir::IntImm::make(a.type(), -pa->value);
if (fa) return ir::FloatImm::make(a.type(), -fa->value);
return make_zero(a.type()) - a;
}
Expr operator-(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return SimpleCast(rtype, a);
if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return SimpleCast(rtype, a);
});
return ir::Sub::make(a, b);
}
Expr operator*(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) {
......@@ -182,12 +211,23 @@ Expr operator*(Expr a, Expr b) {
if (pb->value == 1) return SimpleCast(rtype, a);
if (pb->value == 0) return SimpleCast(rtype, b);
}
if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
if (fa) {
if (fa->value == 1) return SimpleCast(rtype, b);
if (fa->value == 0) return SimpleCast(rtype, a);
}
if (fb) {
if (fb->value == 1) return SimpleCast(rtype, a);
if (fb->value == 0) return SimpleCast(rtype, b);
}
});
return ir::Mul::make(a, b);
}
Expr operator/(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
......@@ -201,12 +241,22 @@ Expr operator/(Expr a, Expr b) {
if (pb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, fa->value / fb->value);
}
if (fa && fa->value == 0) {
return SimpleCast(rtype, a);
}
if (fb) {
if (fb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return ir::Div::make(a, b);
}
Expr operator%(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
......@@ -225,17 +275,23 @@ Expr operator%(Expr a, Expr b) {
}
Expr min(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
return ir::Min::make(a, b);
}
Expr max(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
const Type& ta = a.type();
const Type& tb = b.type();
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
return ir::Max::make(a, b);
}
......@@ -272,43 +328,49 @@ Expr likely(Expr cond) {
}
Expr operator>(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value);
});
return ir::GT::make(a, b);
}
Expr operator>=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value);
});
return ir::GE::make(a, b);
}
Expr operator<(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value);
});
return ir::LT::make(a, b);
}
Expr operator<=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value);
});
return ir::LE::make(a, b);
}
Expr operator==(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value);
});
return ir::EQ::make(a, b);
}
Expr operator!=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value);
});
return ir::NE::make(a, b);
}
......@@ -349,7 +411,7 @@ Expr operator!(Expr a) {
}
Expr operator>>(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
if (pb) {
......@@ -360,7 +422,7 @@ Expr operator>>(Expr a, Expr b) {
}
Expr operator<<(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
if (pb) {
......@@ -371,7 +433,7 @@ Expr operator<<(Expr a, Expr b) {
}
Expr operator&(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
});
......@@ -379,7 +441,7 @@ Expr operator&(Expr a, Expr b) {
}
Expr operator|(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
});
......@@ -387,7 +449,7 @@ Expr operator|(Expr a, Expr b) {
}
Expr operator^(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
});
......@@ -414,6 +476,11 @@ Expr abs(Expr x) {
}
return ir::Select::make(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return ir::FloatImm::make(x.type(), std::fabs(fx->value));
}
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
......@@ -466,4 +533,35 @@ Expr fmod(Expr x, Expr y) {
return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
}
Expr floor(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) return FloatImm::make(x.type(), std::floor(fx->value));
return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic);
}
Expr ceil(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) return FloatImm::make(x.type(), std::ceil(fx->value));
return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic);
}
Expr round(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value));
return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
}
Expr trunc(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return FloatImm::make(x.type(), (fx->value < 0 ? std::ceil(fx->value) :
std::floor(fx->value)));
}
return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic);
}
} // namespace tvm
......@@ -100,6 +100,22 @@ def test_modular():
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
def test_const_propagation():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.5
assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
if __name__ == "__main__":
test_simplify_div()
test_simplify_mod()
......@@ -107,3 +123,4 @@ if __name__ == "__main__":
test_simplify()
test_mul()
test_simplify_minmax()
test_const_propagation()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment