@@ -80,3 +80,96 @@ mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
8080 llvm_unreachable (" NYI" );
8181 }
8282}
83+
84+ // vprintf takes two args: A format string, and a pointer to a buffer containing
85+ // the varargs.
86+ //
87+ // For example, the call
88+ //
89+ // printf("format string", arg1, arg2, arg3);
90+ //
91+ // is converted into something resembling
92+ //
93+ // struct Tmp {
94+ // Arg1 a1;
95+ // Arg2 a2;
96+ // Arg3 a3;
97+ // };
98+ // char* buf = alloca(sizeof(Tmp));
99+ // *(Tmp*)buf = {a1, a2, a3};
100+ // vprintf("format string", buf);
101+ //
102+ // `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of
103+ // the args is itself aligned to its preferred alignment.
104+ //
105+ // Note that by the time this function runs, the arguments have already
106+ // undergone the standard C vararg promotion (short -> int, float -> double
107+ // etc). In this function we pack the arguments into the buffer described above.
108+ mlir::Value packArgsIntoNVPTXFormatBuffer (CIRGenFunction &cgf,
109+ const CallArgList &args,
110+ mlir::Location loc) {
111+ const CIRDataLayout &dataLayout = cgf.CGM .getDataLayout ();
112+ CIRGenBuilderTy &builder = cgf.getBuilder ();
113+
114+ if (args.size () <= 1 )
115+ // If there are no arguments other than the format string,
116+ // pass a nullptr to vprintf.
117+ return builder.getNullPtr (cgf.VoidPtrTy , loc);
118+
119+ llvm::SmallVector<mlir::Type, 8 > argTypes;
120+ for (auto arg : llvm::drop_begin (args))
121+ argTypes.push_back (arg.getRValue (cgf, loc).getScalarVal ().getType ());
122+
123+ // We can directly store the arguments into a struct, and the alignment
124+ // would automatically be correct. That's because vprintf does not
125+ // accept aggregates.
126+ mlir::Type allocaTy =
127+ cir::StructType::get (&cgf.getMLIRContext (), argTypes, /* packed=*/ false ,
128+ /* padded=*/ false , StructType::Struct);
129+ mlir::Value alloca =
130+ cgf.CreateTempAlloca (allocaTy, loc, " printf_args" , nullptr );
131+
132+ for (auto [i, arg] : llvm::enumerate (llvm::drop_begin (args))) {
133+ mlir::Value member =
134+ builder.createGetMember (loc, cir::PointerType::get (argTypes[i]), alloca,
135+ /* name=*/ " " , /* index=*/ i);
136+ auto preferredAlign = clang::CharUnits::fromQuantity (
137+ dataLayout.getPrefTypeAlign (argTypes[i]).value ());
138+ builder.createAlignedStore (loc, arg.getRValue (cgf, loc).getScalarVal (),
139+ member, preferredAlign);
140+ }
141+
142+ return builder.createBitcast (alloca, cgf.VoidPtrTy );
143+ }
144+
145+ mlir::Value
146+ CIRGenFunction::emitNVPTXDevicePrintfCallExpr (const CallExpr *expr) {
147+ assert (CGM.getTriple ().isNVPTX ());
148+ CallArgList args;
149+ emitCallArgs (args,
150+ expr->getDirectCallee ()->getType ()->getAs <FunctionProtoType>(),
151+ expr->arguments (), expr->getDirectCallee ());
152+
153+ mlir::Location loc = getLoc (expr->getBeginLoc ());
154+
155+ // Except the format string, no non-scalar arguments are allowed for
156+ // device-side printf.
157+ bool hasNonScalar =
158+ llvm::any_of (llvm::drop_begin (args), [&](const CallArg &A) {
159+ return !A.getRValue (*this , loc).isScalar ();
160+ });
161+ if (hasNonScalar) {
162+ CGM.ErrorUnsupported (expr, " non-scalar args to printf" );
163+ return builder.getConstInt (loc, SInt32Ty, 0 );
164+ }
165+
166+ mlir::Value packedData = packArgsIntoNVPTXFormatBuffer (*this , args, loc);
167+
168+ // int vprintf(char *format, void *packedData);
169+ auto vprintf = CGM.createRuntimeFunction (
170+ FuncType::get ({cir::PointerType::get (SInt8Ty), VoidPtrTy}, SInt32Ty),
171+ " vprintf" );
172+ auto formatString = args[0 ].getRValue (*this , loc).getScalarVal ();
173+ return builder.createCallOp (loc, vprintf, {formatString, packedData})
174+ .getResult ();
175+ }
0 commit comments