@@ -8010,14 +8010,29 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
80108010 }
80118011 auto *SPI = SPIRVInstTemplateBase::create (OC);
80128012 std::vector<SPIRVWord> SPArgs;
8013+ std::vector<SPIRVValue *> SPArgsValues;
80138014 for (size_t I = 0 , E = Args.size (); I != E; ++I) {
80148015 assert ((!isFunctionPointerType (Args[I]->getType ()) ||
80158016 isa<Function>(Args[I])) &&
80168017 " Invalid function pointer argument" );
8017- SPArgs.push_back (SPI->isOperandLiteral (I)
8018- ? cast<ConstantInt>(Args[I])->getZExtValue ()
8019- : transValue (Args[I], BB)->getId ());
8018+ SPArgsValues.push_back (
8019+ !SPI->isOperandLiteral (I) ? transValue (Args[I], BB) : nullptr );
8020+ SPArgs.push_back (!SPI->isOperandLiteral (I)
8021+ ? SPArgsValues.back ()->getId ()
8022+ : cast<ConstantInt>(Args[I])->getZExtValue ());
80208023 }
8024+
8025+ // fix up potential int <-> uint argument type mismatch in atomics
8026+ if (isAtomicOpCode (OC)) {
8027+ const auto last_arg_idx = SPArgs.size () - 1 ;
8028+ const auto &last_arg = SPArgsValues[last_arg_idx];
8029+ if (last_arg != nullptr && last_arg->getType () != SPRetTy) {
8030+ const auto bc_arg =
8031+ BM->addUnaryInst (spv::OpBitcast, SPRetTy, last_arg, BB);
8032+ SPArgs[last_arg_idx] = bc_arg->getId ();
8033+ }
8034+ }
8035+
80218036 BM->addInstTemplate (SPI, SPArgs, BB, SPRetTy);
80228037 if (!SPRetTy || !SPRetTy->isTypeStruct ())
80238038 return SPI;
0 commit comments