Commit cb465e7e by PENGUINLIONG Committed by Tianqi Chen

Fix boolean type (#862)

parent 79d503fd
......@@ -311,17 +311,27 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
}
CHECK_LE(dtype.type.bits(), 64);
Value ret = NewValue(dtype, kConstant);
ib_.Begin(spv::OpConstant).AddSeq(dtype, ret);
uint64_t mask = 0xFFFFFFFFUL;
ib_.Add(static_cast<uint32_t>(pvalue[0] & mask));
if (dtype.type.bits() > 32) {
if (dtype.type.is_int()) {
int64_t sign_mask = 0xFFFFFFFFL;
const int64_t* sign_ptr =
reinterpret_cast<const int64_t*>(pvalue);
ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
if (1 == dtype.type.bits() && dtype.is_uint()) {
// Boolean types.
if (*pvalue) {
ib_.Begin(spv::OpConstantTrue).AddSeq(ret);
} else {
ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
ib_.Begin(spv::OpConstantFalse).AddSeq(ret);
}
} else {
// Integral/floating-point types.
ib_.Begin(spv::OpConstant).AddSeq(dtype, ret);
uint64_t mask = 0xFFFFFFFFUL;
ib_.Add(static_cast<uint32_t>(pvalue[0] & mask));
if (dtype.type.bits() > 32) {
if (dtype.type.is_int()) {
int64_t sign_mask = 0xFFFFFFFFL;
const int64_t* sign_ptr =
reinterpret_cast<const int64_t*>(pvalue);
ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
} else {
ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
}
}
}
ib_.Commit(&global_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment