Skip to content

Commit 408ca37

Browse files
committed
[CIR][CUDA] Support device-side printf
1 parent fd3f5f8 commit 408ca37

File tree

5 files changed

+140
-3
lines changed

5 files changed

+140
-3
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#include "CIRGenFunction.h"
2+
3+
using namespace cir;
4+
using namespace clang;
5+
using namespace clang::CIRGen;
6+
7+
// vprintf takes two args: A format string, and a pointer to a buffer containing
8+
// the varargs.
9+
//
10+
// For example, the call
11+
//
12+
// printf("format string", arg1, arg2, arg3);
13+
//
14+
// is converted into something resembling
15+
//
16+
// struct Tmp {
17+
// Arg1 a1;
18+
// Arg2 a2;
19+
// Arg3 a3;
20+
// };
21+
// char* buf = alloca(sizeof(Tmp));
22+
// *(Tmp*)buf = {a1, a2, a3};
23+
// vprintf("format string", buf);
24+
//
25+
// `buf` is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of
26+
// the args is itself aligned to its preferred alignment.
27+
//
28+
// Note that by the time this function runs, the arguments have already
29+
// undergone the standard C vararg promotion (short -> int, float -> double
30+
// etc). In this function we pack the arguments into the buffer described above.
31+
mlir::Value packArgsIntoNVPTXFormatBuffer(CIRGenFunction &cgf,
32+
const CallArgList &args,
33+
mlir::Location loc) {
34+
const CIRDataLayout &dataLayout = cgf.CGM.getDataLayout();
35+
CIRGenBuilderTy &builder = cgf.getBuilder();
36+
37+
if (args.size() <= 1)
38+
// If there are no arguments other than the format string,
39+
// pass a nullptr to vprintf.
40+
return builder.getNullPtr(cgf.VoidPtrTy, loc);
41+
42+
llvm::SmallVector<mlir::Type, 8> argTypes;
43+
for (auto arg : llvm::drop_begin(args))
44+
argTypes.push_back(arg.getRValue(cgf, loc).getScalarVal().getType());
45+
46+
// We can directly store the arguments into a struct, and the alignment
47+
// would automatically be correct. That's because vprintf does not
48+
// accept aggregates.
49+
mlir::Type allocaTy =
50+
cir::StructType::get(&cgf.getMLIRContext(), argTypes, /*packed=*/false,
51+
/*padded=*/false, StructType::Struct);
52+
mlir::Value alloca =
53+
cgf.CreateTempAlloca(allocaTy, loc, "printf_args", nullptr);
54+
55+
for (auto [i, arg] : llvm::enumerate(llvm::drop_begin(args))) {
56+
mlir::Value member =
57+
builder.createGetMember(loc, cir::PointerType::get(argTypes[i]), alloca,
58+
/*name=*/"", /*index=*/i);
59+
auto preferredAlign = clang::CharUnits::fromQuantity(
60+
dataLayout.getPrefTypeAlign(argTypes[i]).value());
61+
builder.createAlignedStore(loc, arg.getRValue(cgf, loc).getScalarVal(),
62+
member, preferredAlign);
63+
}
64+
65+
return builder.createBitcast(alloca, cgf.VoidPtrTy);
66+
}
67+
68+
mlir::Value
69+
CIRGenFunction::emitNVPTXDevicePrintfCallExpr(const CallExpr *expr) {
70+
assert(CGM.getTriple().isNVPTX());
71+
CallArgList args;
72+
emitCallArgs(args,
73+
expr->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
74+
expr->arguments(), expr->getDirectCallee());
75+
76+
mlir::Location loc = getLoc(expr->getBeginLoc());
77+
78+
// Except the format string, no non-scalar arguments are allowed for
79+
// device-side printf.
80+
bool hasNonScalar =
81+
llvm::any_of(llvm::drop_begin(args), [&](const CallArg &A) {
82+
return !A.getRValue(*this, loc).isScalar();
83+
});
84+
if (hasNonScalar) {
85+
CGM.ErrorUnsupported(expr, "non-scalar args to printf");
86+
return builder.getConstInt(loc, SInt32Ty, 0);
87+
}
88+
89+
mlir::Value packedData = packArgsIntoNVPTXFormatBuffer(*this, args, loc);
90+
91+
// int vprintf(char *format, void *packedData);
92+
auto vprintf = CGM.createRuntimeFunction(
93+
FuncType::get({cir::PointerType::get(SInt8Ty), VoidPtrTy}, SInt32Ty),
94+
"vprintf");
95+
auto formatString = args[0].getRValue(*this, loc).getScalarVal();
96+
return builder.createCallOp(loc, vprintf, {formatString, packedData})
97+
.getResult();
98+
}

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2340,12 +2340,14 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
23402340
llvm_unreachable("BI__builtin_load_halff NYI");
23412341

23422342
case Builtin::BI__builtin_printf:
2343-
llvm_unreachable("BI__builtin_printf NYI");
23442343
case Builtin::BIprintf:
2345-
if (getTarget().getTriple().isNVPTX() ||
2346-
getTarget().getTriple().isAMDGCN()) {
2344+
assert(E->getNumArgs() >= 1);
2345+
if (getTarget().getTriple().isAMDGCN()) {
23472346
llvm_unreachable("BIprintf NYI");
23482347
}
2348+
if (getTarget().getTriple().isNVPTX()) {
2349+
return RValue::get(emitNVPTXDevicePrintfCallExpr(E));
2350+
}
23492351
break;
23502352

23512353
case Builtin::BI__builtin_canonicalize:

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,8 @@ class CIRGenFunction : public CIRGenTypeCache {
14781478
mlir::Value emitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
14791479
mlir::Value emitNVPTXBuiltinExpr(unsigned builtinID, const CallExpr *expr);
14801480

1481+
mlir::Value emitNVPTXDevicePrintfCallExpr(const CallExpr *expr);
1482+
14811483
/// Given an expression with a pointer type, emit the value and compute our
14821484
/// best estimate of the alignment of the pointee.
14831485
///

clang/lib/CIR/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ add_clang_library(clangCIR
4343
CIRGenTypes.cpp
4444
CIRGenVTables.cpp
4545
CIRGenerator.cpp
46+
CIRGPUBuiltin.cpp
4647
CIRPasses.cpp
4748
CIRRecordLayoutBuilder.cpp
4849
ConstantInitBuilder.cpp
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "../Inputs/cuda.h"
2+
3+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
4+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
6+
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
7+
8+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
9+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.ll
11+
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s
12+
13+
14+
__device__ void printer() {
15+
printf("%d", 0);
16+
}
17+
18+
// CIR-DEVICE: cir.func @_Z7printerv() extra({{.*}}) {
19+
// CIR-DEVICE: %[[#Packed:]] = cir.alloca !ty_anon_struct
20+
// CIR-DEVICE: %[[#Zero:]] = cir.const #cir.int<0> : !s32i loc(#loc5)
21+
// CIR-DEVICE: %[[#Field0:]] = cir.get_member %0[0]
22+
// CIR-DEVICE: cir.store align(4) %[[#Zero]], %[[#Field0]]
23+
// CIR-DEVICE: %[[#Output:]] = cir.cast(bitcast, %[[#Packed]] : !cir.ptr<!ty_anon_struct>)
24+
// CIR-DEVICE: cir.call @vprintf(%{{.+}}, %[[#Output]])
25+
// CIR-DEVICE: cir.return
26+
// CIR-DEVICE: }
27+
28+
// LLVM-DEVICE: define dso_local void @_Z7printerv() {{.*}} {
29+
// LLVM-DEVICE: %[[#LLVMPacked:]] = alloca { i32 }, i64 1, align 8
30+
// LLVM-DEVICE: %[[#LLVMField0:]] = getelementptr { i32 }, ptr %[[#LLVMPacked]], i32 0, i32 0
31+
// LLVM-DEVICE: store i32 0, ptr %[[#LLVMField0]], align 4
32+
// LLVM-DEVICE: call i32 @vprintf(ptr @.str, ptr %[[#LLVMPacked]])
33+
// LLVM-DEVICE: ret void
34+
// LLVM-DEVICE: }

0 commit comments

Comments
 (0)