Commit d4fb0a2d by Haichen Shen Committed by Tianqi Chen

[BugFix] Fix bug in cast to bool (#3207)

parent 1b359035
......@@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
if (value->getType() == target) return value;
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_uint() && to.bits() == 1) {
if (from.is_float()) {
llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
return builder_->CreateFCmpONE(value, zero);
} else {
llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
return builder_->CreateICmpNE(value, zero);
}
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
......
......@@ -19,6 +19,7 @@ import tvm
import topi
import topi.testing
from topi import util
from common import get_all_backend
def test_util():
......@@ -59,8 +60,7 @@ def test_ewise():
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel',
'aocl_sw_emu']:
for device in get_all_backend():
check_device(device)
......@@ -77,6 +77,46 @@ def test_ewise():
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
def test_cast():
def verify(from_dtype, to_dtype, low=-100, high=100):
shape = (5, 4)
A = tvm.placeholder(shape, dtype=from_dtype, name="A")
B = topi.cast(A, to_dtype)
if from_dtype == "bool":
a_np = np.random.choice([True, False], size=shape)
else:
a_np = np.random.uniform(low, high, size=shape).astype(from_dtype)
if to_dtype == "bool":
a_np = a_np - a_np[2, 3]
b_np = a_np.astype(to_dtype)
for device in get_all_backend():
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np)
verify("int32", "float32")
verify("int32", "float64")
verify("int32", "bool")
verify("float32", "int32")
verify("float32", "float64")
verify("float32", "bool")
verify("bool", "float32")
verify("bool", "int32")
if __name__ == "__main__":
test_util()
test_ewise()
test_cast()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment