diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 137ad75c6b10885dc019f359b98738e0e7fb5f5d..26f851031e5a5a68f1e86eb7a1c33509dd4c98da 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -311,17 +311,27 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } CHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); - ib_.Begin(spv::OpConstant).AddSeq(dtype, ret); - uint64_t mask = 0xFFFFFFFFUL; - ib_.Add(static_cast<uint32_t>(pvalue[0] & mask)); - if (dtype.type.bits() > 32) { - if (dtype.type.is_int()) { - int64_t sign_mask = 0xFFFFFFFFL; - const int64_t* sign_ptr = - reinterpret_cast<const int64_t*>(pvalue); - ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask)); + if (1 == dtype.type.bits() && dtype.is_uint()) { + // Boolean types. + if (*pvalue) { + ib_.Begin(spv::OpConstantTrue).AddSeq(ret); } else { - ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask)); + ib_.Begin(spv::OpConstantFalse).AddSeq(ret); + } + } else { + // Integral/floating-point types. + ib_.Begin(spv::OpConstant).AddSeq(dtype, ret); + uint64_t mask = 0xFFFFFFFFUL; + ib_.Add(static_cast<uint32_t>(pvalue[0] & mask)); + if (dtype.type.bits() > 32) { + if (dtype.type.is_int()) { + int64_t sign_mask = 0xFFFFFFFFL; + const int64_t* sign_ptr = + reinterpret_cast<const int64_t*>(pvalue); + ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask)); + } else { + ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask)); + } } } ib_.Commit(&global_);