Unverified Commit 173b4fc4 by pankratz Committed by GitHub

Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)

parent ec86d7f1
...@@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) { ...@@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) { if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value); return IntImm(rtype, pa->value % pb->value);
} }
if (pa) { if (pa) {
...@@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) { ...@@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) { if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value)); return IntImm(rtype, floormod(pa->value, pb->value));
} }
if (pa) { if (pa) {
......
...@@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) { ...@@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
} }
PrimExpr floor(PrimExpr x) { PrimExpr floor(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode; using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
...@@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) { ...@@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) {
} }
PrimExpr ceil(PrimExpr x) { PrimExpr ceil(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode; using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
...@@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) { ...@@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) {
} }
PrimExpr round(PrimExpr x) { PrimExpr round(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode; using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
...@@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) { ...@@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) {
} }
PrimExpr nearbyint(PrimExpr x) { PrimExpr nearbyint(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode; using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
...@@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) { ...@@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) {
} }
PrimExpr trunc(PrimExpr x) { PrimExpr trunc(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode; using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) { if (fx) {
......
...@@ -187,14 +187,14 @@ def test_bitwise(): ...@@ -187,14 +187,14 @@ def test_bitwise():
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
def test_float_bitwise(): def test_float_bitwise():
t = tvm.tir.const(1.5,dtype='float32') t = tvm.tir.const(1.5,dtype='float32')
for test in [lambda lhs, rhs : lhs << rhs, for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs, lambda lhs, rhs : lhs >> rhs,
lambda lhs, rhs : lhs | rhs, lambda lhs, rhs : lhs | rhs,
lambda lhs, rhs : lhs ^ rhs, lambda lhs, rhs : lhs ^ rhs,
lambda lhs, rhs : lhs & rhs lambda lhs, rhs : lhs & rhs]:
]:
try: try:
test(t,10.0) test(t,10.0)
assert False assert False
...@@ -206,6 +206,20 @@ def test_float_bitwise(): ...@@ -206,6 +206,20 @@ def test_float_bitwise():
except RuntimeError: except RuntimeError:
pass pass
def test_divide_by_zero():
for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs),
lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs),
lambda lhs, rhs : tvm.tir.div(lhs,rhs)]:
try:
test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32'))
assert False
except tvm.TVMError:
pass
def test_isnan(): def test_isnan():
x = te.var('x', 'float32') x = te.var('x', 'float32')
assert str(tvm.tir.isnan(x)) == 'isnan(x)' assert str(tvm.tir.isnan(x)) == 'isnan(x)'
...@@ -250,6 +264,7 @@ if __name__ == "__main__": ...@@ -250,6 +264,7 @@ if __name__ == "__main__":
test_all() test_all()
test_bitwise() test_bitwise()
test_float_bitwise() test_float_bitwise()
test_divide_by_zero()
test_isnan() test_isnan()
test_equality() test_equality()
test_equality_string_imm() test_equality_string_imm()
...@@ -44,6 +44,16 @@ def test_nearbyint(): ...@@ -44,6 +44,16 @@ def test_nearbyint():
tvm.testing.assert_allclose( tvm.testing.assert_allclose(
a_rounded.asnumpy(), np.rint(a.asnumpy())) a_rounded.asnumpy(), np.rint(a.asnumpy()))
def test_round_intrinsics_on_int():
i = tvm.te.var("i", 'int32')
for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil,
tvm.tir.floor, tvm.tir.nearbyint]:
assert op(tvm.tir.const(10,'int32')).value == 10
assert op(tvm.tir.const(True,'bool')).value == True
assert op(i).same_as(i)
assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False
def test_unary_intrin(): def test_unary_intrin():
test_funcs = [ test_funcs = [
...@@ -75,3 +85,4 @@ def test_unary_intrin(): ...@@ -75,3 +85,4 @@ def test_unary_intrin():
if __name__ == "__main__": if __name__ == "__main__":
test_nearbyint() test_nearbyint()
test_unary_intrin() test_unary_intrin()
test_round_intrinsics_on_int()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment