Unverified Commit 976c08ad by pankratz Committed by GitHub

Fixed bugs that occured when using bitwise operators on floating point type…

Fixed bugs that occured when using bitwise operators on floating point type expressions. Further crash when using ops <<, >>, %. Finally added regression tests for both types of bug. (#4892)
parent 08338dd5
......@@ -52,6 +52,11 @@ def _dtype_is_int(value):
return (isinstance(value, ExprOp) and
DataType(value.dtype).type_code == TypeCode.INT)
def _dtype_is_float(value):
if isinstance(value, float):
return True
return (isinstance(value, ExprOp) and
DataType(value.dtype).type_code == TypeCode.FLOAT)
class ExprOp(object):
"""Operator overloading for Expr like expressions."""
......@@ -102,6 +107,9 @@ class ExprOp(object):
def __mod__(self, other):
return _ffi_api._OpFloorMod(self, other)
def __rmod__(self, other):
return _ffi_api._OpFloorMod(other, self)
def __neg__(self):
neg_one = const(-1, self.dtype)
return self.__mul__(neg_one)
......@@ -109,9 +117,15 @@ class ExprOp(object):
def __lshift__(self, other):
return _ffi_api.left_shift(self, other)
def __rlshift__(self, other):
return _ffi_api.left_shift(other, self)
def __rshift__(self, other):
return _ffi_api.right_shift(self, other)
def __rrshift__(self, other):
return _ffi_api.right_shift(other, self)
def __and__(self, other):
return _ffi_api.bitwise_and(self, other)
......@@ -131,6 +145,8 @@ class ExprOp(object):
return _ffi_api.bitwise_xor(other, self)
def __invert__(self):
if _dtype_is_float(self):
raise RuntimeError("Cannot use ~ operator on float type Expr.")
return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
def __lt__(self, other):
......
......@@ -417,6 +417,8 @@ PrimExpr operator!(PrimExpr a) {
}
PrimExpr operator>>(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
......@@ -430,6 +432,8 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
}
PrimExpr operator<<(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
......@@ -443,6 +447,8 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
}
PrimExpr operator&(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
......@@ -453,6 +459,8 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) {
}
PrimExpr operator|(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
......@@ -463,6 +471,8 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) {
}
PrimExpr operator^(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
......
......@@ -178,11 +178,32 @@ def test_bitwise():
assert str(10 & x) == 'bitwise_and(10, x)'
assert str(10 | x) == 'bitwise_or(10, x)'
assert str(10 ^ x) == 'bitwise_xor(10, x)'
assert str(10 >> x) == 'shift_right(10, x)'
assert str(10 << x) == 'shift_left(10, x)'
assert str(10 % x) == 'floormod(10, x)'
assert str(~x) == 'bitwise_not(x)'
assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2"
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"
def test_float_bitwise():
t = tvm.const(1.5,dtype='float32')
for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs,
lambda lhs, rhs : lhs | rhs,
lambda lhs, rhs : lhs ^ rhs,
lambda lhs, rhs : lhs & rhs
]:
try:
test(t,10.0)
assert False
except tvm.TVMError:
pass
try:
~t
assert False
except RuntimeError:
pass
def test_isnan():
x = tvm.var('x', 'float32')
......@@ -227,6 +248,7 @@ if __name__ == "__main__":
test_any()
test_all()
test_bitwise()
test_float_bitwise()
test_isnan()
test_equality()
test_equality_string_imm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment