|
| 1 | +#include "CIRGenFunction.h" |
| 2 | + |
| 3 | +using namespace cir; |
| 4 | +using namespace clang; |
| 5 | +using namespace clang::CIRGen; |
| 6 | + |
| 7 | +// vprintf takes two args: A format string, and a pointer to a buffer containing |
| 8 | +// the varargs. |
| 9 | +// |
| 10 | +// For example, the call |
| 11 | +// |
| 12 | +// printf("format string", arg1, arg2, arg3); |
| 13 | +// |
| 14 | +// is converted into something resembling |
| 15 | +// |
| 16 | +// struct Tmp { |
| 17 | +// Arg1 a1; |
| 18 | +// Arg2 a2; |
| 19 | +// Arg3 a3; |
| 20 | +// }; |
| 21 | +// char* buf = alloca(sizeof(Tmp)); |
| 22 | +// *(Tmp*)buf = {a1, a2, a3}; |
| 23 | +// vprintf("format string", buf); |
| 24 | +// |
| 25 | +// `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of |
| 26 | +// the args is itself aligned to its preferred alignment. |
| 27 | +// |
| 28 | +// Note that by the time this function runs, the arguments have already |
| 29 | +// undergone the standard C vararg promotion (short -> int, float -> double |
| 30 | +// etc). In this function we pack the arguments into the buffer described above. |
| 31 | +mlir::Value packArgsIntoNVPTXFormatBuffer(CIRGenFunction &cgf, |
| 32 | + const CallArgList &args, |
| 33 | + mlir::Location loc) { |
| 34 | + const CIRDataLayout &dataLayout = cgf.CGM.getDataLayout(); |
| 35 | + CIRGenBuilderTy &builder = cgf.getBuilder(); |
| 36 | + |
| 37 | + if (args.size() <= 1) |
| 38 | + // If there are no arguments other than the format string, |
| 39 | + // pass a nullptr to vprintf. |
| 40 | + return builder.getNullPtr(cgf.VoidPtrTy, loc); |
| 41 | + |
| 42 | + llvm::SmallVector<mlir::Type, 8> argTypes; |
| 43 | + for (auto arg : llvm::drop_begin(args)) |
| 44 | + argTypes.push_back(arg.getRValue(cgf, loc).getScalarVal().getType()); |
| 45 | + |
| 46 | + // We can directly store the arguments into a struct, and the alignment |
| 47 | + // would automatically be correct. That's because vprintf does not |
| 48 | + // accept aggregates. |
| 49 | + mlir::Type allocaTy = |
| 50 | + cir::StructType::get(&cgf.getMLIRContext(), argTypes, /*packed=*/false, |
| 51 | + /*padded=*/false, StructType::Struct); |
| 52 | + mlir::Value alloca = |
| 53 | + cgf.CreateTempAlloca(allocaTy, loc, "printf_args", nullptr); |
| 54 | + |
| 55 | + for (auto [i, arg] : llvm::enumerate(llvm::drop_begin(args))) { |
| 56 | + mlir::Value member = |
| 57 | + builder.createGetMember(loc, cir::PointerType::get(argTypes[i]), alloca, |
| 58 | + /*name=*/"", /*index=*/i); |
| 59 | + auto preferredAlign = clang::CharUnits::fromQuantity( |
| 60 | + dataLayout.getPrefTypeAlign(argTypes[i]).value()); |
| 61 | + builder.createAlignedStore(loc, arg.getRValue(cgf, loc).getScalarVal(), |
| 62 | + member, preferredAlign); |
| 63 | + } |
| 64 | + |
| 65 | + return builder.createBitcast(alloca, cgf.VoidPtrTy); |
| 66 | +} |
| 67 | + |
| 68 | +mlir::Value |
| 69 | +CIRGenFunction::emitNVPTXDevicePrintfCallExpr(const CallExpr *expr) { |
| 70 | + assert(CGM.getTriple().isNVPTX()); |
| 71 | + CallArgList args; |
| 72 | + emitCallArgs(args, |
| 73 | + expr->getDirectCallee()->getType()->getAs<FunctionProtoType>(), |
| 74 | + expr->arguments(), expr->getDirectCallee()); |
| 75 | + |
| 76 | + mlir::Location loc = getLoc(expr->getBeginLoc()); |
| 77 | + |
| 78 | + // Except the format string, no non-scalar arguments are allowed for |
| 79 | + // device-side printf. |
| 80 | + bool hasNonScalar = |
| 81 | + llvm::any_of(llvm::drop_begin(args), [&](const CallArg &A) { |
| 82 | + return !A.getRValue(*this, loc).isScalar(); |
| 83 | + }); |
| 84 | + if (hasNonScalar) { |
| 85 | + CGM.ErrorUnsupported(expr, "non-scalar args to printf"); |
| 86 | + return builder.getConstInt(loc, SInt32Ty, 0); |
| 87 | + } |
| 88 | + |
| 89 | + mlir::Value packedData = packArgsIntoNVPTXFormatBuffer(*this, args, loc); |
| 90 | + |
| 91 | + // int vprintf(char *format, void *packedData); |
| 92 | + auto vprintf = CGM.createRuntimeFunction( |
| 93 | + FuncType::get({cir::PointerType::get(SInt8Ty), VoidPtrTy}, SInt32Ty), |
| 94 | + "vprintf"); |
| 95 | + auto formatString = args[0].getRValue(*this, loc).getScalarVal(); |
| 96 | + return builder.createCallOp(loc, vprintf, {formatString, packedData}) |
| 97 | + .getResult(); |
| 98 | +} |
0 commit comments