Unverified Commit a422589c by pankratz Committed by GitHub

[Bugfix] Fixed bug where shifting by out-of-bounds value results in no compute…

[Bugfix] Fixed bug where shifting by out-of-bounds value results in no compute code being emitted. (#5115)

* Fixed bug where shifting by out-of-bounds RHS values results in LLVM to codegen nothing. Added regression testcase

* Updated testcase to be more precise.

* Fixed testcase
parent 9037f4ec
...@@ -469,6 +469,9 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { ...@@ -469,6 +469,9 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
"Shift amount must be non-negative and less than " << rtype.bits()
<< " for type " << rtype;
if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
if (pb) { if (pb) {
if (pb->value == 0) return a; if (pb->value == 0) return a;
...@@ -484,6 +487,9 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { ...@@ -484,6 +487,9 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
"Shift amount must be non-negative and less than " << rtype.bits()
<< " for type " << rtype;
if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
if (pb) { if (pb) {
if (pb->value == 0) return a; if (pb->value == 0) return a;
......
...@@ -207,6 +207,23 @@ def test_float_bitwise(): ...@@ -207,6 +207,23 @@ def test_float_bitwise():
pass pass
def test_shift_bounds():
x = te.var('x')
for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs]:
#negative case
for testcase in [(x,-1), (x,32)]:
try:
test(*testcase)
assert False
except tvm.TVMError:
pass
#positive case
for testcase in [(x,0), (x,16), (x,31)]:
test(*testcase)
def test_divide_by_zero(): def test_divide_by_zero():
for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs), for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs), lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
...@@ -293,6 +310,7 @@ if __name__ == "__main__": ...@@ -293,6 +310,7 @@ if __name__ == "__main__":
test_all() test_all()
test_bitwise() test_bitwise()
test_float_bitwise() test_float_bitwise()
test_shift_bounds()
test_divide_by_zero() test_divide_by_zero()
test_isnan() test_isnan()
test_equality() test_equality()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment