Unverified Commit 173b4fc4 by pankratz Committed by GitHub

Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)

parent ec86d7f1
......@@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
}
if (pa) {
......@@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
}
if (pa) {
......
......@@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
}
PrimExpr floor(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
......@@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) {
}
PrimExpr ceil(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
......@@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) {
}
PrimExpr round(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
......@@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) {
}
PrimExpr nearbyint(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
......@@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) {
}
PrimExpr trunc(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
......
......@@ -187,14 +187,14 @@ def test_bitwise():
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
def test_float_bitwise():
t = tvm.tir.const(1.5,dtype='float32')
for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs,
lambda lhs, rhs : lhs | rhs,
lambda lhs, rhs : lhs ^ rhs,
lambda lhs, rhs : lhs & rhs
]:
lambda lhs, rhs : lhs & rhs]:
try:
test(t,10.0)
assert False
......@@ -206,6 +206,20 @@ def test_float_bitwise():
except RuntimeError:
pass
def test_divide_by_zero():
for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs),
lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs),
lambda lhs, rhs : tvm.tir.div(lhs,rhs)]:
try:
test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32'))
assert False
except tvm.TVMError:
pass
def test_isnan():
x = te.var('x', 'float32')
assert str(tvm.tir.isnan(x)) == 'isnan(x)'
......@@ -250,6 +264,7 @@ if __name__ == "__main__":
test_all()
test_bitwise()
test_float_bitwise()
test_divide_by_zero()
test_isnan()
test_equality()
test_equality_string_imm()
......@@ -44,6 +44,16 @@ def test_nearbyint():
tvm.testing.assert_allclose(
a_rounded.asnumpy(), np.rint(a.asnumpy()))
def test_round_intrinsics_on_int():
i = tvm.te.var("i", 'int32')
for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil,
tvm.tir.floor, tvm.tir.nearbyint]:
assert op(tvm.tir.const(10,'int32')).value == 10
assert op(tvm.tir.const(True,'bool')).value == True
assert op(i).same_as(i)
assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False
def test_unary_intrin():
test_funcs = [
......@@ -75,3 +85,4 @@ def test_unary_intrin():
if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
test_round_intrinsics_on_int()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment