Unverified Commit c59a78e5 by Haichen Shen Committed by GitHub

[TVM][LANG] Add eager simplification for operations with FloatImm (#2615)

* Add eager simplication for FloatImm

* fix

* fix lint

* Fix gcc warning

* fix

* Add test case
parent 255c187b
......@@ -430,6 +430,34 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
/*!
* \brief Calculate floor(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr floor(Expr x);
/*!
* \brief Calculate ceil(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr ceil(Expr x);
/*!
* \brief Calculate round(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr round(Expr x);
/*!
* \brief Calculate trunc(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
......@@ -441,10 +469,6 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount);
......
......@@ -94,4 +94,4 @@ def cast(src, dtype):
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.static_cast(dtype, src)
return _make._cast(dtype, src)
......@@ -272,7 +272,7 @@ def floor(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "floor", x)
return _make.floor(x)
def ceil(x):
......@@ -288,7 +288,7 @@ def ceil(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "ceil", x)
return _make.ceil(x)
def trunc(x):
......@@ -307,7 +307,7 @@ def trunc(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "trunc", x)
return _make.trunc(x)
def abs(x):
......@@ -339,7 +339,7 @@ def round(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "round", x)
return _make.round(x)
def power(x, y):
......
......@@ -22,6 +22,31 @@ TVM_REGISTER_API("make.abs")
*ret = tvm::abs(args[0]);
});
TVM_REGISTER_API("make.floor")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::floor(args[0]);
});
TVM_REGISTER_API("make.ceil")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::ceil(args[0]);
});
TVM_REGISTER_API("make.round")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::round(args[0]);
});
TVM_REGISTER_API("make.trunc")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::trunc(args[0]);
});
TVM_REGISTER_API("make._cast")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::cast(args[0], args[1]);
});
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
......
......@@ -100,6 +100,22 @@ def test_modular():
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
def test_const_propagation():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.5
assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
if __name__ == "__main__":
test_simplify_div()
test_simplify_mod()
......@@ -107,3 +123,4 @@ if __name__ == "__main__":
test_simplify()
test_mul()
test_simplify_minmax()
test_const_propagation()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment