From 90833a4b4adc2f54c1110740cc545675c8f829cd Mon Sep 17 00:00:00 2001 From: Morris Hafner Date: Thu, 17 Apr 2025 19:04:45 +0200 Subject: [PATCH 1/8] [CIR] Always zero-extend shift amounts (#1568) Negative shift amounts are undefined behavior in C and C++. Because of that we can always zero-extend the shift amount which is slightly faster on certain architectures (e. g. x86). This also matches the behavior of the original clang Codegen. Backported from https://github.com/llvm/llvm-project/pull/133405 Co-authored-by: Morris Hafner --- clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 5 +++-- clang/test/CIR/Lowering/shift.cir | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index bb1a43893e55..4840637171ed 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2967,10 +2967,11 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite( // behavior might occur in the casts below as per [C99 6.5.7.3]. // Vector type shift amount needs no cast as type consistency is expected to // be already be enforced at CIRGen. + // Negative shift amounts are undefined behavior so we can always zero extend + // the integer here. if (cirAmtTy) amt = getLLVMIntCast(rewriter, amt, mlir::cast(llvmTy), - !cirAmtTy.isSigned(), cirAmtTy.getWidth(), - cirValTy.getWidth()); + true, cirAmtTy.getWidth(), cirValTy.getWidth()); // Lower to the proper LLVM shift operation. if (op.getIsShiftleft()) diff --git a/clang/test/CIR/Lowering/shift.cir b/clang/test/CIR/Lowering/shift.cir index 78a7f89e13d0..f47d5955dcee 100644 --- a/clang/test/CIR/Lowering/shift.cir +++ b/clang/test/CIR/Lowering/shift.cir @@ -16,7 +16,7 @@ module { // Should allow shift with signed smaller amount type. %2 = cir.shift(left, %arg1 : !s32i, %arg0 : !s16i) -> !s32i - // CHECK: %[[#CAST:]] = llvm.sext %{{.+}} : i16 to i32 + // CHECK: %[[#CAST:]] = llvm.zext %{{.+}} : i16 to i32 // CHECK: llvm.shl %{{.+}}, %[[#CAST]] : i32 // Should allow shift with unsigned smaller amount type. From 4c4a7628275fbe3a07fa0c3b1bc3428d99b5b639 Mon Sep 17 00:00:00 2001 From: terapines open source contributor 2 Date: Fri, 18 Apr 2025 02:17:23 +0800 Subject: [PATCH 2/8] [CIR][ThroughMLIR] Lower TrapOp (#1561) `cir.trap` corresponds to two operations, `call @llvm.trap` and `unreachable`. See the test case `Lowering/intrinsics.cir`. Co-authored-by: Yue Huang --- .../Lowering/ThroughMLIR/LowerCIRToMLIR.cpp | 22 ++++++++++++++++--- .../CIR/Lowering/ThroughMLIR/unreachable.cir | 9 ++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index ee0f33ae4428..326152783980 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -1248,12 +1248,28 @@ class CIRUnreachableOpLowering mlir::LogicalResult matchAndRewrite(cir::UnreachableOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - // match and rewrite. rewriter.replaceOpWithNewOp(op); return mlir::success(); } }; +class CIRTrapOpLowering : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::TrapOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.setInsertionPointAfter(op); + auto trapIntrinsicName = rewriter.getStringAttr("llvm.trap"); + rewriter.create(op.getLoc(), trapIntrinsicName, + /*args=*/mlir::ValueRange()); + rewriter.create(op.getLoc()); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); @@ -1274,8 +1290,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering, - CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering>( - converter, patterns.getContext()); + CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering, + CIRTrapOpLowering>(converter, patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { diff --git a/clang/test/CIR/Lowering/ThroughMLIR/unreachable.cir b/clang/test/CIR/Lowering/ThroughMLIR/unreachable.cir index 1e94300b1b60..843f9ce41607 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/unreachable.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/unreachable.cir @@ -7,4 +7,13 @@ module { // MLIR: func.func @test_unreachable() // MLIR-NEXT: llvm.unreachable + + cir.func @test_trap() { + cir.trap + } + + // MLIR: func.func @test_trap() { + // MLIR-NEXT: llvm.call_intrinsic "llvm.trap"() : () -> () + // MLIR-NEXT: llvm.unreachable + // MLIR-NEXT: } } From 9f6742f5a8fe5ded48f8dae8551de08829aa4767 Mon Sep 17 00:00:00 2001 From: "Chibuoyim (Wilson) Ogbonna" Date: Thu, 17 Apr 2025 21:19:45 +0300 Subject: [PATCH 3/8] [CIR][CodeGen] Fix crash during exception cleanup (#1566) Currently, the following code snippet fails with a crash during CodeGen ``` class C { public: ~C(); void operator=(C); }; void d() { C a, b; a = b; } ``` with error: ``` mlir::Block* clang::CIRGen::CIRGenFunction::getEHResumeBlock(bool, cir::TryOp): Assertion `tryOp && "expected available cir.try"' failed. ``` in CIRGenCleanup [these lines](https://github.com/llvm/clangir/blob/204c03efbe898c9f64e477937d869767fdfb1310/clang/lib/CIR/CodeGen/CIRGenCleanup.cpp#L615C1-L617C6) don't check if there is a TryOp when at the end of the scope chain before [getEHResumeBlock](https://github.com/llvm/clangir/blob/204c03efbe898c9f64e477937d869767fdfb1310/clang/lib/CIR/CodeGen/CIRGenException.cpp#L764) is called causing the crash, because it contains an assertion. This PR fixes this and adds a simple test for a case like this. --- clang/lib/CIR/CodeGen/CIRGenCleanup.cpp | 4 +++ clang/test/CIR/CodeGen/try-catch-dtors.cpp | 34 ++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/clang/lib/CIR/CodeGen/CIRGenCleanup.cpp b/clang/lib/CIR/CodeGen/CIRGenCleanup.cpp index 5d4199fb42bb..242aee079f22 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCleanup.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCleanup.cpp @@ -613,6 +613,10 @@ void CIRGenFunction::PopCleanupBlock(bool FallthroughIsBranchThrough) { // Emit the EH cleanup if required. if (RequiresEHCleanup) { cir::TryOp tryOp = ehEntry->getParentOp()->getParentOfType(); + + if (EHParent == EHStack.stable_end() && !tryOp) + return; + auto *nextAction = getEHDispatchBlock(EHParent, tryOp); (void)nextAction; diff --git a/clang/test/CIR/CodeGen/try-catch-dtors.cpp b/clang/test/CIR/CodeGen/try-catch-dtors.cpp index c24240b72605..36d4430e764c 100644 --- a/clang/test/CIR/CodeGen/try-catch-dtors.cpp +++ b/clang/test/CIR/CodeGen/try-catch-dtors.cpp @@ -339,3 +339,37 @@ void bar() { // CIR: cir.store %[[V3]], %[[V1]] : !s32i, !cir.ptr // CIR: cir.call @_ZN1AD2Ev(%[[V0]]) : (!cir.ptr) -> () extra(#fn_attr) // CIR: cir.return + +class C { +public: + ~C(); + void operator=(C); +}; + +void d() { + C a, b; + a = b; +} + +// CIR: %[[V0:.*]] = cir.alloca !ty_C, !cir.ptr, ["a"] {alignment = 1 : i64} +// CIR: %[[V1:.*]] = cir.alloca !ty_C, !cir.ptr, ["b"] {alignment = 1 : i64} +// CIR: cir.scope { +// CIR: %[[V2:.*]] = cir.alloca !ty_C, !cir.ptr, ["agg.tmp0"] {alignment = 1 : i64} +// CIR: cir.call @_ZN1CC2ERKS_(%[[V2]], %[[V1]]) : (!cir.ptr, !cir.ptr) -> () extra(#fn_attr) +// CIR: %[[V3:.*]] = cir.load %[[V2]] : !cir.ptr, !ty_C +// CIR: cir.try synthetic cleanup { +// CIR: cir.call exception @_ZN1CaSES_(%[[V0]], %[[V3]]) : (!cir.ptr, !ty_C) -> () cleanup { +// CIR: cir.call @_ZN1CD1Ev(%[[V2]]) : (!cir.ptr) -> () extra(#fn_attr) +// CIR: cir.call @_ZN1CD1Ev(%[[V1]]) : (!cir.ptr) -> () extra(#fn_attr) +// CIR: cir.yield +// CIR: } +// CIR: cir.yield +// CIR: } catch [#cir.unwind { +// CIR: cir.resume +// CIR: }] +// CIR: cir.call @_ZN1CD1Ev(%[[V2]]) : (!cir.ptr) -> () extra(#fn_attr) +// CIR: cir.call @_ZN1CD1Ev(%[[V1]]) : (!cir.ptr) -> () extra(#fn_attr) +// CIR: } +// CIR: cir.call @_ZN1CD1Ev(%[[V1]]) : (!cir.ptr) -> () extra(#fn_attr) +// CIR: cir.call @_ZN1CD1Ev(%[[V0]]) : (!cir.ptr) -> () extra(#fn_attr) +// CIR: cir.return From 49914211d7b379b9e07d52dfa7164aabf3419238 Mon Sep 17 00:00:00 2001 From: gitoleg Date: Fri, 18 Apr 2025 20:39:56 +0300 Subject: [PATCH 4/8] [CIR][CodeGen] Supports const array user in the globals replacement (#1567) This is a just small fix that cover the case when the global union is declared with `static` keyword and one of the its users is an array --- clang/lib/CIR/CodeGen/CIRGenModule.cpp | 14 ++++++++++++-- clang/test/CIR/CodeGen/union-array.c | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index 7d7f9e636f1b..a63efce20967 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -1009,8 +1009,9 @@ void CIRGenModule::replaceGlobal(cir::GlobalOp oldSym, cir::GlobalOp newSym) { if (oldSymUses.has_value()) { for (auto use : *oldSymUses) { auto *userOp = use.getUser(); - assert((isa(userOp) || isa(userOp)) && - "GlobalOp symbol user is neither a GetGlobalOp nor a GlobalOp"); + assert( + (isa(userOp)) && + "GlobalOp symbol user is neither a GetGlobalOp nor a GlobalOp"); if (auto ggo = dyn_cast(use.getUser())) { auto useOpResultValue = ggo.getAddr(); @@ -1028,6 +1029,15 @@ void CIRGenModule::replaceGlobal(cir::GlobalOp oldSym, cir::GlobalOp newSym) { auto nw = getNewInitValue(*this, newSym, oldTy, glob, init.value()); glob.setInitialValueAttr(nw); } + } else if (auto c = dyn_cast(userOp)) { + mlir::Attribute init = + getNewInitValue(*this, newSym, oldTy, glob, c.getValue()); + auto ar = cast(init); + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(c); + auto newUser = + builder.create(c.getLoc(), ar.getType(), ar); + c.replaceAllUsesWith(newUser.getOperation()); } } } diff --git a/clang/test/CIR/CodeGen/union-array.c b/clang/test/CIR/CodeGen/union-array.c index 92ec655cf019..3788b6e23741 100644 --- a/clang/test/CIR/CodeGen/union-array.c +++ b/clang/test/CIR/CodeGen/union-array.c @@ -16,7 +16,24 @@ typedef union { S_2 b; } U; +typedef union { + int f0; + int f1; +} U1; + +static U1 g = {5}; +// LLVM: @__const.bar.x = private constant [2 x ptr] [ptr @g, ptr @g] +// LLVM: @g = internal global { i32 } { i32 5 } +// FIXME: LLVM output should be: @g = internal global %union.U { i32 5 } + + void foo() { U arr[2] = {{.b = {1, 2}}, {.a = {1}}}; } // CIR: cir.const #cir.const_record<{#cir.const_record<{#cir.const_record<{#cir.int<1> : !s64i, #cir.int<2> : !s64i}> : {{.*}}}> : {{.*}}, #cir.const_record<{#cir.const_record<{#cir.int<1> : !s8i}> : {{.*}}, #cir.const_array<[#cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i, #cir.zero : !u8i]> : !cir.array}> // LLVM: store { { %struct.S_2 }, { %struct.S_1, [15 x i8] } } { { %struct.S_2 } { %struct.S_2 { i64 1, i64 2 } }, { %struct.S_1, [15 x i8] } { %struct.S_1 { i8 1 }, [15 x i8] zeroinitializer } } + +void bar(void) { + int *x[2] = { &g.f0, &g.f0 }; +} +// CIR: cir.global "private" internal dsolocal @g = #cir.const_record<{#cir.int<5> : !s32i}> : !ty_anon_struct +// CIR: cir.const #cir.const_array<[#cir.global_view<@g> : !cir.ptr, #cir.global_view<@g> : !cir.ptr]> : !cir.array x 2> From 0bae1fd51953d4b22d8d8df7109a6a5ebd102d6f Mon Sep 17 00:00:00 2001 From: Henrich Lauko Date: Fri, 18 Apr 2025 19:40:38 +0200 Subject: [PATCH 5/8] [CIR] Infer MLIR context in type builders when possible (#1570) Add `TypeBuilderWithInferredContext` to each CIR type that supports MLIR context inference from its parameters. --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 2 +- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 38 ++++++++ clang/lib/CIR/CodeGen/CIRGenBuilder.h | 13 +-- .../lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp | 91 +++++++------------ clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp | 3 +- clang/lib/CIR/CodeGen/CIRGenExprConst.cpp | 11 +-- clang/lib/CIR/CodeGen/CIRGenTypes.cpp | 16 ++-- clang/lib/CIR/CodeGen/CIRGenVTables.cpp | 9 +- .../CIR/CodeGen/CIRRecordLayoutBuilder.cpp | 6 +- clang/lib/CIR/CodeGen/ConstantInitBuilder.cpp | 5 +- clang/lib/CIR/CodeGen/TargetInfo.cpp | 2 +- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 2 +- .../TargetLowering/Targets/AArch64.cpp | 6 +- 13 files changed, 96 insertions(+), 108 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index 224f4863554d..4744183b583e 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -302,7 +302,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag) { - auto resultComplexTy = cir::ComplexType::get(getContext(), real.getType()); + auto resultComplexTy = cir::ComplexType::get(real.getType()); return create(loc, resultComplexTy, real, imag); } diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 194fb5360ee3..a66d522128a2 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -228,6 +228,12 @@ def CIR_ComplexType : CIR_Type<"Complex", "complex", let parameters = (ins "mlir::Type":$elementTy); + let builders = [ + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementTy), [{ + return $_get(elementTy.getContext(), elementTy); + }]>, + ]; + let assemblyFormat = [{ `<` $elementTy `>` }]; @@ -301,6 +307,14 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member", let parameters = (ins "mlir::Type":$memberTy, "cir::RecordType":$clsTy); + let builders = [ + TypeBuilderWithInferredContext<(ins + "mlir::Type":$memberTy, "cir::RecordType":$clsTy + ), [{ + return $_get(memberTy.getContext(), memberTy, clsTy); + }]>, + ]; + let assemblyFormat = [{ `<` $memberTy `in` $clsTy `>` }]; @@ -338,6 +352,14 @@ def CIR_ArrayType : CIR_Type<"Array", "array", let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size); + let builders = [ + TypeBuilderWithInferredContext<(ins + "mlir::Type":$eltType, "uint64_t":$size + ), [{ + return $_get(eltType.getContext(), eltType, size); + }]>, + ]; + let assemblyFormat = [{ `<` $eltType `x` $size `>` }]; @@ -358,6 +380,14 @@ def CIR_VectorType : CIR_Type<"Vector", "vector", let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size); + let builders = [ + TypeBuilderWithInferredContext<(ins + "mlir::Type":$eltType, "uint64_t":$size + ), [{ + return $_get(eltType.getContext(), eltType, size); + }]>, + ]; + let assemblyFormat = [{ `<` $eltType `x` $size `>` }]; @@ -452,6 +482,14 @@ def CIR_MethodType : CIR_Type<"Method", "method", let parameters = (ins "cir::FuncType":$memberFuncTy, "cir::RecordType":$clsTy); + let builders = [ + TypeBuilderWithInferredContext<(ins + "cir::FuncType":$memberFuncTy, "cir::RecordType":$clsTy + ), [{ + return $_get(memberFuncTy.getContext(), memberFuncTy, clsTy); + }]>, + ]; + let assemblyFormat = [{ `<` qualified($memberFuncTy) `in` $clsTy `>` }]; diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index 39ca13bcdf26..a694e6a46f80 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -168,7 +168,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { // If the string is full of null bytes, emit a #cir.zero rather than // a #cir.const_array. if (lastNonZeroPos == llvm::StringRef::npos) { - auto arrayTy = cir::ArrayType::get(getContext(), eltTy, finalSize); + auto arrayTy = cir::ArrayType::get(eltTy, finalSize); return getZeroAttr(arrayTy); } // We will use trailing zeros only if there are more than one zero @@ -176,8 +176,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { int trailingZerosNum = finalSize > lastNonZeroPos + 2 ? finalSize - lastNonZeroPos - 1 : 0; auto truncatedArrayTy = - cir::ArrayType::get(getContext(), eltTy, finalSize - trailingZerosNum); - auto fullArrayTy = cir::ArrayType::get(getContext(), eltTy, finalSize); + cir::ArrayType::get(eltTy, finalSize - trailingZerosNum); + auto fullArrayTy = cir::ArrayType::get(eltTy, finalSize); return cir::ConstArrayAttr::get( getContext(), fullArrayTy, mlir::StringAttr::get(str.drop_back(trailingZerosNum), @@ -407,8 +407,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { bool isSigned = false) { auto elementTy = mlir::dyn_cast_or_null(vt.getEltType()); assert(elementTy && "expected int vector"); - return cir::VectorType::get(getContext(), - isExtended + return cir::VectorType::get(isExtended ? getExtendedIntTy(elementTy, isSigned) : getTruncatedIntTy(elementTy, isSigned), vt.getSize()); @@ -530,10 +529,6 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { return getCompleteRecordTy(members, name, packed, padded, ast); } - cir::ArrayType getArrayType(mlir::Type eltType, unsigned size) { - return cir::ArrayType::get(getContext(), eltType, size); - } - bool isSized(mlir::Type ty) { if (mlir::isa(ty)) diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp index f40ca94aae13..422872fd8d99 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp @@ -1863,14 +1863,12 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags, switch (TypeFlags.getEltType()) { case NeonTypeFlags::Int8: case NeonTypeFlags::Poly8: - return cir::VectorType::get(CGF->getBuilder().getContext(), - TypeFlags.isUnsigned() ? CGF->UInt8Ty + return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt8Ty : CGF->SInt8Ty, V1Ty ? 1 : (8 << IsQuad)); case NeonTypeFlags::Int16: case NeonTypeFlags::Poly16: - return cir::VectorType::get(CGF->getBuilder().getContext(), - TypeFlags.isUnsigned() ? CGF->UInt16Ty + return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt16Ty : CGF->SInt16Ty, V1Ty ? 1 : (4 << IsQuad)); case NeonTypeFlags::BFloat16: @@ -1884,14 +1882,12 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags, else llvm_unreachable("NeonTypeFlags::Float16 NYI"); case NeonTypeFlags::Int32: - return cir::VectorType::get(CGF->getBuilder().getContext(), - TypeFlags.isUnsigned() ? CGF->UInt32Ty + return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt32Ty : CGF->SInt32Ty, V1Ty ? 1 : (2 << IsQuad)); case NeonTypeFlags::Int64: case NeonTypeFlags::Poly64: - return cir::VectorType::get(CGF->getBuilder().getContext(), - TypeFlags.isUnsigned() ? CGF->UInt64Ty + return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt64Ty : CGF->SInt64Ty, V1Ty ? 1 : (1 << IsQuad)); case NeonTypeFlags::Poly128: @@ -1900,12 +1896,10 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags, // so we use v16i8 to represent poly128 and get pattern matched. llvm_unreachable("NeonTypeFlags::Poly128 NYI"); case NeonTypeFlags::Float32: - return cir::VectorType::get(CGF->getBuilder().getContext(), - CGF->getCIRGenModule().FloatTy, + return cir::VectorType::get(CGF->getCIRGenModule().FloatTy, V1Ty ? 1 : (2 << IsQuad)); case NeonTypeFlags::Float64: - return cir::VectorType::get(CGF->getBuilder().getContext(), - CGF->getCIRGenModule().DoubleTy, + return cir::VectorType::get(CGF->getCIRGenModule().DoubleTy, V1Ty ? 1 : (1 << IsQuad)); } llvm_unreachable("Unknown vector element type!"); @@ -2102,7 +2096,7 @@ static cir::VectorType getSignChangedVectorType(CIRGenBuilderTy &builder, auto elemTy = mlir::cast(vecTy.getEltType()); elemTy = elemTy.isSigned() ? builder.getUIntNTy(elemTy.getWidth()) : builder.getSIntNTy(elemTy.getWidth()); - return cir::VectorType::get(builder.getContext(), elemTy, vecTy.getSize()); + return cir::VectorType::get(elemTy, vecTy.getSize()); } static cir::VectorType @@ -2111,19 +2105,16 @@ getHalfEltSizeTwiceNumElemsVecType(CIRGenBuilderTy &builder, auto elemTy = mlir::cast(vecTy.getEltType()); elemTy = elemTy.isSigned() ? builder.getSIntNTy(elemTy.getWidth() / 2) : builder.getUIntNTy(elemTy.getWidth() / 2); - return cir::VectorType::get(builder.getContext(), elemTy, - vecTy.getSize() * 2); + return cir::VectorType::get(elemTy, vecTy.getSize() * 2); } static cir::VectorType castVecOfFPTypeToVecOfIntWithSameWidth(CIRGenBuilderTy &builder, cir::VectorType vecTy) { if (mlir::isa(vecTy.getEltType())) - return cir::VectorType::get(builder.getContext(), builder.getSInt32Ty(), - vecTy.getSize()); + return cir::VectorType::get(builder.getSInt32Ty(), vecTy.getSize()); if (mlir::isa(vecTy.getEltType())) - return cir::VectorType::get(builder.getContext(), builder.getSInt64Ty(), - vecTy.getSize()); + return cir::VectorType::get(builder.getSInt64Ty(), vecTy.getSize()); llvm_unreachable( "Unsupported element type in getVecOfIntTypeWithSameEltWidth"); } @@ -2315,8 +2306,7 @@ static mlir::Value emitCommonNeonVecAcrossCall(CIRGenFunction &cgf, const clang::CallExpr *e) { CIRGenBuilderTy &builder = cgf.getBuilder(); mlir::Value op = cgf.emitScalarExpr(e->getArg(0)); - cir::VectorType vTy = - cir::VectorType::get(&cgf.getMLIRContext(), eltTy, vecLen); + cir::VectorType vTy = cir::VectorType::get(eltTy, vecLen); llvm::SmallVector args{op}; return emitNeonCall(builder, {vTy}, args, intrincsName, eltTy, cgf.getLoc(e->getExprLoc())); @@ -2447,8 +2437,7 @@ mlir::Value CIRGenFunction::emitCommonNeonBuiltinExpr( cir::VectorType resTy = (builtinID == NEON::BI__builtin_neon_vqdmulhq_lane_v || builtinID == NEON::BI__builtin_neon_vqrdmulhq_lane_v) - ? cir::VectorType::get(&getMLIRContext(), vTy.getEltType(), - vTy.getSize() * 2) + ? cir::VectorType::get(vTy.getEltType(), vTy.getSize() * 2) : vTy; cir::VectorType mulVecT = GetNeonType(this, NeonTypeFlags(neonType.getEltType(), false, @@ -2888,10 +2877,8 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr( llvm_unreachable(" neon_vqmovnh_u16 NYI "); case NEON::BI__builtin_neon_vqmovns_s32: { mlir::Location loc = cgf.getLoc(expr->getExprLoc()); - cir::VectorType argVecTy = - cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt32Ty, 4); - cir::VectorType resVecTy = - cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt16Ty, 4); + cir::VectorType argVecTy = cir::VectorType::get(cgf.SInt32Ty, 4); + cir::VectorType resVecTy = cir::VectorType::get(cgf.SInt16Ty, 4); vecExtendIntValue(cgf, argVecTy, ops[0], loc); mlir::Value result = emitNeonCall(builder, {argVecTy}, ops, "aarch64.neon.sqxtn", resVecTy, loc); @@ -3706,88 +3693,74 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E, case NEON::BI__builtin_neon_vset_lane_f64: { Ops.push_back(emitScalarExpr(E->getArg(2))); - Ops[1] = builder.createBitcast( - Ops[1], cir::VectorType::get(&getMLIRContext(), DoubleTy, 1)); + Ops[1] = builder.createBitcast(Ops[1], cir::VectorType::get(DoubleTy, 1)); return builder.create(getLoc(E->getExprLoc()), Ops[1], Ops[0], Ops[2]); } case NEON::BI__builtin_neon_vsetq_lane_f64: { Ops.push_back(emitScalarExpr(E->getArg(2))); - Ops[1] = builder.createBitcast( - Ops[1], cir::VectorType::get(&getMLIRContext(), DoubleTy, 2)); + Ops[1] = builder.createBitcast(Ops[1], cir::VectorType::get(DoubleTy, 2)); return builder.create(getLoc(E->getExprLoc()), Ops[1], Ops[0], Ops[2]); } case NEON::BI__builtin_neon_vget_lane_i8: case NEON::BI__builtin_neon_vdupb_lane_i8: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt8Ty, 8)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt8Ty, 8)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vgetq_lane_i8: case NEON::BI__builtin_neon_vdupb_laneq_i8: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt8Ty, 16)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt8Ty, 16)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vget_lane_i16: case NEON::BI__builtin_neon_vduph_lane_i16: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt16Ty, 4)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt16Ty, 4)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vgetq_lane_i16: case NEON::BI__builtin_neon_vduph_laneq_i16: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt16Ty, 8)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt16Ty, 8)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vget_lane_i32: case NEON::BI__builtin_neon_vdups_lane_i32: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt32Ty, 2)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt32Ty, 2)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vget_lane_f32: case NEON::BI__builtin_neon_vdups_lane_f32: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), FloatTy, 2)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(FloatTy, 2)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vgetq_lane_i32: case NEON::BI__builtin_neon_vdups_laneq_i32: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt32Ty, 4)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt32Ty, 4)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vget_lane_i64: case NEON::BI__builtin_neon_vdupd_lane_i64: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt64Ty, 1)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt64Ty, 1)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vdupd_lane_f64: case NEON::BI__builtin_neon_vget_lane_f64: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), DoubleTy, 1)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(DoubleTy, 1)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vgetq_lane_i64: case NEON::BI__builtin_neon_vdupd_laneq_i64: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), UInt64Ty, 2)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt64Ty, 2)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vgetq_lane_f32: case NEON::BI__builtin_neon_vdups_laneq_f32: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), FloatTy, 4)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(FloatTy, 4)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vgetq_lane_f64: case NEON::BI__builtin_neon_vdupd_laneq_f64: - Ops[0] = builder.createBitcast( - Ops[0], cir::VectorType::get(&getMLIRContext(), DoubleTy, 2)); + Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(DoubleTy, 2)); return builder.create(getLoc(E->getExprLoc()), Ops[0], emitScalarExpr(E->getArg(1))); case NEON::BI__builtin_neon_vaddh_f16: { @@ -4318,7 +4291,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E, [[fallthrough]]; case NEON::BI__builtin_neon_vaddv_s16: { cir::IntType eltTy = usgn ? UInt16Ty : SInt16Ty; - cir::VectorType vTy = cir::VectorType::get(builder.getContext(), eltTy, 4); + cir::VectorType vTy = cir::VectorType::get(eltTy, 4); Ops.push_back(emitScalarExpr(E->getArg(0))); // This is to add across the vector elements, so wider result type needed. Ops[0] = emitNeonCall(builder, {vTy}, Ops, @@ -4427,8 +4400,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E, usgn = true; [[fallthrough]]; case NEON::BI__builtin_neon_vaddlvq_s16: { - mlir::Type argTy = cir::VectorType::get(builder.getContext(), - usgn ? UInt16Ty : SInt16Ty, 8); + mlir::Type argTy = cir::VectorType::get(usgn ? UInt16Ty : SInt16Ty, 8); llvm::SmallVector argOps = {emitScalarExpr(E->getArg(0))}; return emitNeonCall(builder, {argTy}, argOps, usgn ? "aarch64.neon.uaddlv" : "aarch64.neon.saddlv", @@ -4441,8 +4413,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E, usgn = true; [[fallthrough]]; case NEON::BI__builtin_neon_vaddlv_s16: { - mlir::Type argTy = cir::VectorType::get(builder.getContext(), - usgn ? UInt16Ty : SInt16Ty, 4); + mlir::Type argTy = cir::VectorType::get(usgn ? UInt16Ty : SInt16Ty, 4); llvm::SmallVector argOps = {emitScalarExpr(E->getArg(0))}; return emitNeonCall(builder, {argTy}, argOps, usgn ? "aarch64.neon.uaddlv" : "aarch64.neon.saddlv", diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp b/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp index e76f20026bf6..5cb59f1690c9 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp @@ -129,8 +129,7 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf, // we need to pass it as `void *args[2] = { &a, &b }`. auto loc = fn.getLoc(); - auto voidPtrArrayTy = - cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size()); + auto voidPtrArrayTy = cir::ArrayType::get(cgm.VoidPtrTy, args.size()); mlir::Value kernelArgs = builder.createAlloca( loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args", CharUnits::fromQuantity(16)); diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp index d54a19cbbe6a..c05a9099c95b 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConst.cpp @@ -50,7 +50,7 @@ static mlir::TypedAttr computePadding(CIRGenModule &CGM, CharUnits size) { if (size > CharUnits::One()) { SmallVector elts(arSize, bld.getZeroAttr(eltTy)); return bld.getConstArray(mlir::ArrayAttr::get(bld.getContext(), elts), - bld.getArrayType(eltTy, arSize)); + cir::ArrayType::get(eltTy, arSize)); } else { return cir::ZeroAttr::get(bld.getContext(), eltTy); } @@ -1303,8 +1303,7 @@ emitArrayConstant(CIRGenModule &CGM, mlir::Type DesiredType, return builder.getConstArray( mlir::ArrayAttr::get(builder.getContext(), Eles), - cir::ArrayType::get(builder.getContext(), CommonElementType, - ArrayBound)); + cir::ArrayType::get(CommonElementType, ArrayBound)); // TODO(cir): If all the elements had the same type up to the trailing // zeroes, emit a record of two arrays (the nonzero data and the // zeroinitializer). Use DesiredType to get the element type. @@ -1324,8 +1323,7 @@ emitArrayConstant(CIRGenModule &CGM, mlir::Type DesiredType, return builder.getConstArray( mlir::ArrayAttr::get(builder.getContext(), Eles), - cir::ArrayType::get(builder.getContext(), CommonElementType, - ArrayBound)); + cir::ArrayType::get(CommonElementType, ArrayBound)); } SmallVector Eles; @@ -1831,8 +1829,7 @@ mlir::Attribute ConstantEmitter::emitForMemory(CIRGenModule &CGM, assert(innerSize < outerSize && "emitted over-large constant for atomic"); auto &builder = CGM.getBuilder(); auto zeroArray = builder.getZeroInitAttr( - cir::ArrayType::get(builder.getContext(), builder.getUInt8Ty(), - (outerSize - innerSize) / 8)); + cir::ArrayType::get(builder.getUInt8Ty(), (outerSize - innerSize) / 8)); SmallVector anonElts = {C, zeroArray}; auto arrAttr = mlir::ArrayAttr::get(builder.getContext(), anonElts); return builder.getAnonConstRecord(arrAttr, false); diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index f9d3569ab0aa..cb4741fb8df0 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -602,7 +602,7 @@ mlir::Type CIRGenTypes::convertType(QualType T) { case Type::Complex: { const ComplexType *CT = cast(Ty); auto ElementTy = convertType(CT->getElementType()); - ResultType = cir::ComplexType::get(Builder.getContext(), ElementTy); + ResultType = cir::ComplexType::get(ElementTy); break; } case Type::LValueReference: @@ -650,7 +650,7 @@ mlir::Type CIRGenTypes::convertType(QualType T) { SkippedLayout = true; ResultType = Builder.getUInt8Ty(); } - ResultType = Builder.getArrayType(ResultType, 0); + ResultType = cir::ArrayType::get(ResultType, 0); break; } case Type::ConstantArray: { @@ -660,16 +660,14 @@ mlir::Type CIRGenTypes::convertType(QualType T) { // FIXME: In LLVM, "lower arrays of undefined struct type to arrays of // i8 just to have a concrete type". Not sure this makes sense in CIR yet. assert(Builder.isSized(EltTy) && "not implemented"); - ResultType = cir::ArrayType::get(Builder.getContext(), EltTy, - A->getSize().getZExtValue()); + ResultType = cir::ArrayType::get(EltTy, A->getSize().getZExtValue()); break; } case Type::ExtVector: case Type::Vector: { const VectorType *V = cast(Ty); auto ElementType = convertTypeForMem(V->getElementType()); - ResultType = cir::VectorType::get(Builder.getContext(), ElementType, - V->getNumElements()); + ResultType = cir::VectorType::get(ElementType, V->getNumElements()); break; } case Type::ConstantMatrix: { @@ -717,12 +715,10 @@ mlir::Type CIRGenTypes::convertType(QualType T) { auto clsTy = mlir::cast(convertType(QualType(MPT->getClass(), 0))); if (MPT->isMemberDataPointer()) - ResultType = - cir::DataMemberType::get(Builder.getContext(), memberTy, clsTy); + ResultType = cir::DataMemberType::get(memberTy, clsTy); else { auto memberFuncTy = mlir::cast(memberTy); - ResultType = - cir::MethodType::get(Builder.getContext(), memberFuncTy, clsTy); + ResultType = cir::MethodType::get(memberFuncTy, clsTy); } break; } diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp index 6a4e534c09f3..53da50204f0a 100644 --- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp @@ -58,8 +58,7 @@ mlir::Type CIRGenVTables::getVTableType(const VTableLayout &layout) { mlir::MLIRContext *mlirContext = CGM.getBuilder().getContext(); auto componentType = getVTableComponentType(); for (unsigned i = 0, e = layout.getNumVTables(); i != e; ++i) - tys.push_back(cir::ArrayType::get(mlirContext, componentType, - layout.getVTableSize(i))); + tys.push_back(cir::ArrayType::get(componentType, layout.getVTableSize(i))); // FIXME(cir): should VTableLayout be encoded like we do for some // AST nodes? @@ -519,8 +518,7 @@ cir::GlobalOp CIRGenVTables::getAddrOfVTT(const CXXRecordDecl *RD) { VTTBuilder Builder(CGM.getASTContext(), RD, /*GenerateDefinition=*/false); - auto ArrayType = cir::ArrayType::get(CGM.getBuilder().getContext(), - CGM.getBuilder().getUInt8PtrTy(), + auto ArrayType = cir::ArrayType::get(CGM.getBuilder().getUInt8PtrTy(), Builder.getVTTComponents().size()); auto Align = CGM.getDataLayout().getABITypeAlign(CGM.getBuilder().getUInt8PtrTy()); @@ -590,8 +588,7 @@ void CIRGenVTables::emitVTTDefinition(cir::GlobalOp VTT, const CXXRecordDecl *RD) { VTTBuilder Builder(CGM.getASTContext(), RD, /*GenerateDefinition=*/true); - auto ArrayType = cir::ArrayType::get(CGM.getBuilder().getContext(), - CGM.getBuilder().getUInt8PtrTy(), + auto ArrayType = cir::ArrayType::get(CGM.getBuilder().getUInt8PtrTy(), Builder.getVTTComponents().size()); SmallVector VTables; diff --git a/clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp b/clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp index 231051042502..578b22f6b3be 100644 --- a/clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp +++ b/clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp @@ -152,8 +152,7 @@ struct CIRRecordLowering final { mlir::Type type = getCharType(); return numberOfChars == CharUnits::One() ? type - : cir::ArrayType::get(type.getContext(), type, - numberOfChars.getQuantity()); + : cir::ArrayType::get(type, numberOfChars.getQuantity()); } // This is different from LLVM traditional codegen because CIRGen uses arrays @@ -165,8 +164,7 @@ struct CIRRecordLowering final { return builder.getUIntNTy(alignedBits); } else { mlir::Type type = getCharType(); - return cir::ArrayType::get(type.getContext(), type, - alignedBits / astContext.getCharWidth()); + return cir::ArrayType::get(type, alignedBits / astContext.getCharWidth()); } } diff --git a/clang/lib/CIR/CodeGen/ConstantInitBuilder.cpp b/clang/lib/CIR/CodeGen/ConstantInitBuilder.cpp index 1ef96e73d113..8eafe58259f6 100644 --- a/clang/lib/CIR/CodeGen/ConstantInitBuilder.cpp +++ b/clang/lib/CIR/CodeGen/ConstantInitBuilder.cpp @@ -292,9 +292,8 @@ mlir::Attribute ConstantAggregateBuilderBase::finishArray(mlir::Type eltTy) { // eltTy = tAttr.getType(); } - auto constant = getConstArray( - mlir::ArrayAttr::get(eltTy.getContext(), elts), - cir::ArrayType::get(eltTy.getContext(), eltTy, elts.size())); + auto constant = getConstArray(mlir::ArrayAttr::get(eltTy.getContext(), elts), + cir::ArrayType::get(eltTy, elts.size())); buffer.erase(buffer.begin() + Begin, buffer.end()); return constant; } diff --git a/clang/lib/CIR/CodeGen/TargetInfo.cpp b/clang/lib/CIR/CodeGen/TargetInfo.cpp index d8a3d9cc1abe..43d3835ca9f3 100644 --- a/clang/lib/CIR/CodeGen/TargetInfo.cpp +++ b/clang/lib/CIR/CodeGen/TargetInfo.cpp @@ -440,7 +440,7 @@ cir::VectorType ABIInfo::getOptimalVectorMemoryType(cir::VectorType T, const clang::LangOptions &Opt) const { if (T.getSize() == 3 && !Opt.PreserveVec3Type) { - return cir::VectorType::get(&CGT.getMLIRContext(), T.getEltType(), 4); + return cir::VectorType::get(T.getEltType(), 4); } return T; } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 948a955141d6..d157048c504b 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -936,7 +936,7 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) { assert(realAttr.getType() == imagAttr.getType() && "real part and imag part should be of the same type"); - auto complexTy = cir::ComplexType::get(getContext(), realAttr.getType()); + auto complexTy = cir::ComplexType::get(realAttr.getType()); return cir::ComplexAttr::get(complexTy, realAttr, imagAttr); } diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/AArch64.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/AArch64.cpp index cac197fae1bc..6751bd7d99a0 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/AArch64.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/AArch64.cpp @@ -128,8 +128,7 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(mlir::Type RetTy, // For aggregates with 16-byte alignment, we use i128. if (Alignment < 128 && Size == 128) { mlir::Type baseTy = cir::IntType::get(LT.getMLIRContext(), 64, false); - return ABIArgInfo::getDirect( - cir::ArrayType::get(LT.getMLIRContext(), baseTy, Size / 64)); + return ABIArgInfo::getDirect(cir::ArrayType::get(baseTy, Size / 64)); } return ABIArgInfo::getDirect( @@ -180,8 +179,7 @@ AArch64ABIInfo::classifyArgumentType(mlir::Type Ty, bool IsVariadic, cir::IntType::get(LT.getMLIRContext(), Alignment, false); auto argTy = Size == Alignment ? baseTy - : cir::ArrayType::get(LT.getMLIRContext(), baseTy, - Size / Alignment); + : cir::ArrayType::get(baseTy, Size / Alignment); return ABIArgInfo::getDirect(argTy); } From 390cf2516a599652aedefdc7c6d10410393e51c2 Mon Sep 17 00:00:00 2001 From: Andy Kaylor Date: Fri, 18 Apr 2025 10:43:50 -0700 Subject: [PATCH 6/8] [CIR] Replace RecordType data layout calculations (#1569) We have been using RecordLayoutAttr to "cache" data layout information calculated for records. Unfortunately, it wasn't actually caching the information, and because each call was calculating more information than it needed, it was doing extra work. This replaces the previous implementation with a set of functions that compute only the information needed. Ideally, we would like to have a mechanism to properly cache this information, but until such a mechanism is implemented, these new functions should be a small step forward. --- .../include/clang/CIR/Dialect/IR/CIRAttrs.td | 39 ---- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 11 +- clang/lib/CIR/Dialect/IR/CIRAttrs.cpp | 12 - clang/lib/CIR/Dialect/IR/CIRTypes.cpp | 211 +++++++++++------- 4 files changed, 133 insertions(+), 140 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index 0b1316fc4012..e78baa37bccd 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -838,45 +838,6 @@ def VTableAttr : CIR_Attr<"VTable", "vtable", [TypedAttrInterface]> { }]; } -//===----------------------------------------------------------------------===// -// RecordLayoutAttr -//===----------------------------------------------------------------------===// - -// Used to decouple layout information from the record type. RecordType's -// uses this attribute to cache that information. - -def RecordLayoutAttr : CIR_Attr<"RecordLayout", "record_layout"> { - let summary = "ABI specific information about a record layout"; - let description = [{ - Holds layout information often queried by !cir.record users - during lowering passes and optimizations. - }]; - - let parameters = (ins "unsigned":$size, - "unsigned":$alignment, - "bool":$padded, - "mlir::Type":$largest_member, - "mlir::ArrayAttr":$offsets); - - let builders = [ - AttrBuilderWithInferredContext<(ins "unsigned":$size, - "unsigned":$alignment, - "bool":$padded, - "mlir::Type":$largest_member, - "mlir::ArrayAttr":$offsets), [{ - return $_get(largest_member.getContext(), size, alignment, padded, - largest_member, offsets); - }]>, - ]; - - let genVerifyDecl = 1; - let assemblyFormat = [{ - `<` - struct($size, $alignment, $padded, $largest_member, $offsets) - `>` - }]; -} - //===----------------------------------------------------------------------===// // DynamicCastInfoAttr //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index a66d522128a2..8d1dd609e820 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -781,13 +781,12 @@ def CIR_RecordType : CIR_Type<"Record", "record", bool isLayoutIdentical(const RecordType &other); - // Utilities for lazily computing and cacheing data layout info. - // FIXME: currently opaque because there's a cycle if CIRTypes.types include - // from CIRAttrs.h. The implementation operates in terms of RecordLayoutAttr - // instead. + // Utilities for computing data layout info private: - mutable mlir::Attribute layoutInfo; - void computeSizeAndAlignment(const mlir::DataLayout &dataLayout) const; + unsigned computeStructSize(const mlir::DataLayout &dataLayout) const; + unsigned computeUnionSize(const mlir::DataLayout &dataLayout) const; + uint64_t computeStructAlignment(const mlir::DataLayout &dataLayout) const; + uint64_t computeUnionAlignment(const mlir::DataLayout &dataLayout) const; public: }]; diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index 19da8f32519c..88d5bebd8e7e 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -170,18 +170,6 @@ LogicalResult ConstRecordAttr::verify( return success(); } -LogicalResult RecordLayoutAttr::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned size, - unsigned alignment, bool padded, mlir::Type largest_member, - mlir::ArrayAttr offsets) { - if (not std::all_of(offsets.begin(), offsets.end(), [](mlir::Attribute attr) { - return mlir::isa(attr); - })) { - return emitError() << "all index values must be integers"; - } - return success(); -} - //===----------------------------------------------------------------------===// // LangAttr definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 4955edfde4de..125c9fe99f73 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -112,9 +112,34 @@ void CIRDialect::printType(Type type, DialectAsmPrinter &os) const { /// /// Recurses into union members never returning a union as the largest member. Type RecordType::getLargestMember(const ::mlir::DataLayout &dataLayout) const { - if (!layoutInfo) - computeSizeAndAlignment(dataLayout); - return mlir::cast(layoutInfo).getLargestMember(); + assert(isUnion() && "getLargetMember called for non-union record"); + + // This is a similar algorithm to LLVM's StructLayout. + unsigned numElements = getNumElements(); + auto members = getMembers(); + mlir::Type largestMember; + unsigned largestMemberSize = 0; + + // Ignore the last member if this is a padded union. + if (getPadded()) + --numElements; + + for (unsigned i = 0, e = numElements; i != e; ++i) { + auto ty = members[i]; + + // Found a nested union: recurse into it to fetch its largest member. + if (!largestMember || + dataLayout.getTypeABIAlignment(ty) > + dataLayout.getTypeABIAlignment(largestMember) || + (dataLayout.getTypeABIAlignment(ty) == + dataLayout.getTypeABIAlignment(largestMember) && + dataLayout.getTypeSize(ty) > largestMemberSize)) { + largestMember = ty; + largestMemberSize = dataLayout.getTypeSize(largestMember); + } + } + + return largestMember; } Type RecordType::parse(mlir::AsmParser &parser) { @@ -385,117 +410,137 @@ cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout, return llvm::NextPowerOf2(dataLayout.getTypeSizeInBits(*this)); } +// TODO(cir): Implement a way to cache the datalayout info calculated below. + llvm::TypeSize -RecordType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout, - ::mlir::DataLayoutEntryListRef params) const { - if (!layoutInfo) - computeSizeAndAlignment(dataLayout); - return llvm::TypeSize::getFixed( - mlir::cast(layoutInfo).getSize() * 8); +RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout, + mlir::DataLayoutEntryListRef params) const { + if (isUnion()) + return llvm::TypeSize::getFixed(computeUnionSize(dataLayout) * 8); + + return llvm::TypeSize::getFixed(computeStructSize(dataLayout) * 8); } uint64_t RecordType::getABIAlignment(const ::mlir::DataLayout &dataLayout, ::mlir::DataLayoutEntryListRef params) const { - if (!layoutInfo) - computeSizeAndAlignment(dataLayout); - return mlir::cast(layoutInfo).getAlignment(); -} + // Packed structures always have an ABI alignment of 1. + if (getPacked()) + return 1; -uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout, - unsigned idx) const { - assert(idx < getMembers().size() && "access not valid"); - if (!layoutInfo) - computeSizeAndAlignment(dataLayout); - auto offsets = mlir::cast(layoutInfo).getOffsets(); - auto intAttr = mlir::cast(offsets[idx]); - return intAttr.getInt(); + if (isUnion()) + return computeUnionAlignment(dataLayout); + return computeStructAlignment(dataLayout); } -void RecordType::computeSizeAndAlignment( - const ::mlir::DataLayout &dataLayout) const { +unsigned +RecordType::computeUnionSize(const mlir::DataLayout &dataLayout) const { assert(isComplete() && "Cannot get layout of incomplete records"); - // Do not recompute. - if (layoutInfo) - return; + assert(isUnion() && "computeUnionSize called for non-union record"); // This is a similar algorithm to LLVM's StructLayout. unsigned recordSize = 0; llvm::Align recordAlignment{1}; - bool isPadded = false; unsigned numElements = getNumElements(); auto members = getMembers(); - mlir::Type largestMember; unsigned largestMemberSize = 0; - llvm::SmallVector memberOffsets; - bool dontCountLastElt = isUnion() && getPadded(); - if (dontCountLastElt) - numElements--; + auto largestMember = getLargestMember(dataLayout); + recordSize = dataLayout.getTypeSize(largestMember); - // Loop over each of the elements, placing them in memory. - memberOffsets.reserve(numElements); + // If the union is padded, add the padding to the size. + if (getPadded()) { + auto ty = getMembers()[numElements - 1]; + recordSize += dataLayout.getTypeSize(ty); + } - for (unsigned i = 0, e = numElements; i != e; ++i) { - auto ty = members[i]; + return recordSize; +} - // Found a nested union: recurse into it to fetch its largest member. - if (!largestMember || - dataLayout.getTypeABIAlignment(ty) > - dataLayout.getTypeABIAlignment(largestMember) || - (dataLayout.getTypeABIAlignment(ty) == - dataLayout.getTypeABIAlignment(largestMember) && - dataLayout.getTypeSize(ty) > largestMemberSize)) { - largestMember = ty; - largestMemberSize = dataLayout.getTypeSize(largestMember); - } +unsigned +RecordType::computeStructSize(const mlir::DataLayout &dataLayout) const { + assert(isComplete() && "Cannot get layout of incomplete records"); + + // This is a similar algorithm to LLVM's StructLayout. + unsigned recordSize = 0; + uint64_t recordAlignment = 1; + + // We can't use a range-based for loop here because we might be ignoring the + // last element. + for (mlir::Type ty : getMembers()) { + // This assumes that we're calculating size based on the ABI alignment, not + // the preferred alignment for each type. + const uint64_t tyAlign = + (getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty)); + + // Add padding to the struct size to align it to the abi alignment of the + // element type before than adding the size of the element. + recordSize = llvm::alignTo(recordSize, tyAlign); + recordSize += dataLayout.getTypeSize(ty); + + // The alignment requirement of a struct is equal to the strictest alignment + // requirement of its elements. + recordAlignment = std::max(tyAlign, recordAlignment); + } + + // At the end, add padding to the struct to satisfy its own alignment + // requirement. Otherwise structs inside of arrays would be misaligned. + recordSize = llvm::alignTo(recordSize, recordAlignment); + return recordSize; +} + +// We also compute the alignment as part of computeStructSize, but this is more +// efficient. Ideally, we'd like to compute both at once and cache the result, +// but that's implemented yet. +// TODO(CIR): Implement a way to cache the result. +uint64_t +RecordType::computeStructAlignment(const mlir::DataLayout &dataLayout) const { + assert(isComplete() && "Cannot get layout of incomplete records"); + + // This is a similar algorithm to LLVM's StructLayout. + uint64_t recordAlignment = 1; + for (mlir::Type ty : getMembers()) + recordAlignment = + std::max(dataLayout.getTypeABIAlignment(ty), recordAlignment); + + return recordAlignment; +} + +uint64_t +RecordType::computeUnionAlignment(const mlir::DataLayout &dataLayout) const { + auto largestMember = getLargestMember(dataLayout); + return dataLayout.getTypeABIAlignment(largestMember); +} + +uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout, + unsigned idx) const { + assert(idx < getMembers().size() && "access not valid"); + + // All union elements are at offset zero. + if (isUnion() || idx == 0) + return 0; + + assert(isComplete() && "Cannot get layout of incomplete records"); + assert(idx < getNumElements()); + auto members = getMembers(); + + unsigned offset = 0; + + for (unsigned i = 0, e = idx; i != e; ++i) { + auto ty = members[i]; // This matches LLVM since it uses the ABI instead of preferred alignment. const llvm::Align tyAlign = llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty)); // Add padding if necessary to align the data element properly. - if (!llvm::isAligned(tyAlign, recordSize)) { - isPadded = true; - recordSize = llvm::alignTo(recordSize, tyAlign); - } - - // Keep track of maximum alignment constraint. - recordAlignment = std::max(tyAlign, recordAlignment); - - // Record size up to each element is the element offset. - memberOffsets.push_back(mlir::IntegerAttr::get( - mlir::IntegerType::get(getContext(), 32), isUnion() ? 0 : recordSize)); + offset = llvm::alignTo(offset, tyAlign); // Consume space for this data item - recordSize += dataLayout.getTypeSize(ty); - } - - // For unions, the size and aligment is that of the largest element. - if (isUnion()) { - recordSize = largestMemberSize; - if (getPadded()) { - memberOffsets.push_back(mlir::IntegerAttr::get( - mlir::IntegerType::get(getContext(), 32), recordSize)); - auto ty = getMembers()[numElements]; - recordSize += dataLayout.getTypeSize(ty); - isPadded = true; - } else { - isPadded = false; - } - } else { - // Add padding to the end of the record so that it could be put in an array - // and all array elements would be aligned correctly. - if (!llvm::isAligned(recordAlignment, recordSize)) { - isPadded = true; - recordSize = llvm::alignTo(recordSize, recordAlignment); - } + offset += dataLayout.getTypeSize(ty); } - auto offsets = mlir::ArrayAttr::get(getContext(), memberOffsets); - layoutInfo = cir::RecordLayoutAttr::get(getContext(), recordSize, - recordAlignment.value(), isPadded, - largestMember, offsets); + return offset; } //===----------------------------------------------------------------------===// From 6c34681236186670add513e6afefc89e4bc380fc Mon Sep 17 00:00:00 2001 From: Henrich Lauko Date: Fri, 18 Apr 2025 21:40:50 +0200 Subject: [PATCH 7/8] [CIR][NFC] Simplify BoolAttr builders (#1572) --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 9 ++++-- .../include/clang/CIR/Dialect/IR/CIRAttrs.td | 6 ++++ clang/include/clang/CIR/Dialect/IR/CIROps.td | 6 ++++ clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 32 +++++-------------- clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 4 +-- .../TargetLowering/ItaniumCXXABI.cpp | 8 ++--- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 4 +-- 7 files changed, 32 insertions(+), 37 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index 4744183b583e..c1875f091571 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -41,6 +41,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { public: CIRBaseBuilderTy(mlir::MLIRContext &C) : mlir::OpBuilder(&C) {} + CIRBaseBuilderTy(mlir::OpBuilder &B) : mlir::OpBuilder(B) {} mlir::Value getConstAPSInt(mlir::Location loc, const llvm::APSInt &val) { auto ty = @@ -125,9 +126,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { } cir::BoolAttr getCIRBoolAttr(bool state) { - return cir::BoolAttr::get(getContext(), getBoolTy(), state); + return cir::BoolAttr::get(getContext(), state); } + cir::BoolAttr getTrueAttr() { return getCIRBoolAttr(true); } + + cir::BoolAttr getFalseAttr() { return getCIRBoolAttr(false); } + mlir::TypedAttr getZeroAttr(mlir::Type t) { return cir::ZeroAttr::get(getContext(), t); } @@ -148,7 +153,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { if (auto methodTy = mlir::dyn_cast(ty)) return getNullMethodAttr(methodTy); if (mlir::isa(ty)) { - return getCIRBoolAttr(false); + return getFalseAttr(); } llvm_unreachable("Zero initializer for given type is NYI"); } diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index e78baa37bccd..e295f7fc57bc 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -209,6 +209,12 @@ def CIR_BoolAttr : CIR_Attr<"Bool", "bool", [TypedAttrInterface]> { "", "cir::BoolType">:$type, "bool":$value); + let builders = [ + AttrBuilder<(ins "bool":$value), [{ + return $_get($_ctxt, cir::BoolType::get($_ctxt), value); + }]>, + ]; + let assemblyFormat = [{ `<` $value `>` }]; diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index f4c455a1f039..7b34d4e8c29f 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -407,6 +407,12 @@ def ConstantOp : CIR_Op<"const", // The constant operation returns a single value of CIR_AnyType. let results = (outs CIR_AnyType:$res); + let builders = [ + OpBuilder<(ins "cir::BoolAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]> + ]; + let assemblyFormat = "attr-dict $value"; let hasVerifier = 1; diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 7e77bb10e5b2..c31e4bfcceab 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -199,9 +199,7 @@ class ScalarExprEmitter : public StmtVisitor { llvm_unreachable("NYI"); } mlir::Value VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *E) { - mlir::Type Ty = CGF.convertType(E->getType()); - return Builder.create( - CGF.getLoc(E->getExprLoc()), Ty, Builder.getCIRBoolAttr(E->getValue())); + return Builder.getBool(E->getValue(), CGF.getLoc(E->getExprLoc())); } mlir::Value VisitCXXScalarValueInitExpr(const CXXScalarValueInitExpr *E) { @@ -419,9 +417,7 @@ class ScalarExprEmitter : public StmtVisitor { // An interesting aspect of this is that increment is always true. // Decrement does not have this property. if (isInc && type->isBooleanType()) { - value = Builder.create(CGF.getLoc(E->getExprLoc()), - CGF.convertType(type), - Builder.getCIRBoolAttr(true)); + value = Builder.getTrue(CGF.getLoc(E->getExprLoc())); } else if (type->isIntegerType()) { QualType promotedType; bool canPerformLossyDemotionCheck = false; @@ -2669,9 +2665,7 @@ mlir::Value ScalarExprEmitter::VisitBinLAnd(const clang::BinaryOperator *E) { CIRGenFunction::LexicalScope lexScope{CGF, Loc, B.getInsertionBlock()}; CGF.currLexScope->setAsTernary(); - auto res = B.create( - Loc, Builder.getBoolTy(), - Builder.getAttr(Builder.getBoolTy(), true)); + auto res = B.create(Loc, Builder.getTrueAttr()); B.create(Loc, res.getRes()); }, /*falseBuilder*/ @@ -2679,9 +2673,7 @@ mlir::Value ScalarExprEmitter::VisitBinLAnd(const clang::BinaryOperator *E) { CIRGenFunction::LexicalScope lexScope{CGF, Loc, b.getInsertionBlock()}; CGF.currLexScope->setAsTernary(); - auto res = b.create( - Loc, Builder.getBoolTy(), - Builder.getAttr(Builder.getBoolTy(), false)); + auto res = b.create(Loc, Builder.getFalseAttr()); b.create(Loc, res.getRes()); }); B.create(Loc, res.getResult()); @@ -2690,9 +2682,7 @@ mlir::Value ScalarExprEmitter::VisitBinLAnd(const clang::BinaryOperator *E) { [&](mlir::OpBuilder &B, mlir::Location Loc) { CIRGenFunction::LexicalScope lexScope{CGF, Loc, B.getInsertionBlock()}; CGF.currLexScope->setAsTernary(); - auto res = B.create( - Loc, Builder.getBoolTy(), - Builder.getAttr(Builder.getBoolTy(), false)); + auto res = B.create(Loc, Builder.getFalseAttr()); B.create(Loc, res.getRes()); }); return Builder.createZExtOrBitCast(ResOp.getLoc(), ResOp.getResult(), ResTy); @@ -2738,9 +2728,7 @@ mlir::Value ScalarExprEmitter::VisitBinLOr(const clang::BinaryOperator *E) { [&](mlir::OpBuilder &B, mlir::Location Loc) { CIRGenFunction::LexicalScope lexScope{CGF, Loc, B.getInsertionBlock()}; CGF.currLexScope->setAsTernary(); - auto res = B.create( - Loc, Builder.getBoolTy(), - Builder.getAttr(Builder.getBoolTy(), true)); + auto res = B.create(Loc, Builder.getTrueAttr()); B.create(Loc, res.getRes()); }, /*falseBuilder*/ @@ -2763,9 +2751,7 @@ mlir::Value ScalarExprEmitter::VisitBinLOr(const clang::BinaryOperator *E) { CIRGenFunction::LexicalScope lexScope{CGF, Loc, B.getInsertionBlock()}; CGF.currLexScope->setAsTernary(); - auto res = B.create( - Loc, Builder.getBoolTy(), - Builder.getAttr(Builder.getBoolTy(), true)); + auto res = B.create(Loc, Builder.getTrueAttr()); B.create(Loc, res.getRes()); }, /*falseBuilder*/ @@ -2782,9 +2768,7 @@ mlir::Value ScalarExprEmitter::VisitBinLOr(const clang::BinaryOperator *E) { CIRGenFunction::LexicalScope lexScope{CGF, Loc, B.getInsertionBlock()}; CGF.currLexScope->setAsTernary(); - auto res = b.create( - Loc, Builder.getBoolTy(), - Builder.getAttr(Builder.getBoolTy(), false)); + auto res = b.create(Loc, Builder.getFalseAttr()); b.create(Loc, res.getRes()); }); B.create(Loc, res.getResult()); diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 450fd8647468..37b7656e2acc 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -947,9 +947,7 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &S) { // scalar type. condVal = evaluateExprAsBool(S.getCond()); } else { - auto boolTy = cir::BoolType::get(b.getContext()); - condVal = b.create( - loc, boolTy, cir::BoolAttr::get(b.getContext(), boolTy, true)); + condVal = b.create(loc, builder.getTrueAttr()); } builder.createCondition(condVal); }, diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index edf895ef05fb..5d7e0b49aab8 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -564,11 +564,9 @@ mlir::Value ItaniumCXXABI::lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, // - cir.select if %a then true else %b => %a || %b // TODO: Do we need to invent dedicated "cir.logical_or" and "cir.logical_and" // operations for this? - auto boolTy = cir::BoolType::get(op.getContext()); - mlir::Value trueValue = builder.create( - op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, true)); - mlir::Value falseValue = builder.create( - op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, false)); + CIRBaseBuilderTy cirBuilder(builder); + mlir::Value trueValue = cirBuilder.getTrue(op.getLoc()); + mlir::Value falseValue = cirBuilder.getFalse(op.getLoc()); auto create_and = [&](mlir::Value lhs, mlir::Value rhs) { return builder.create(op.getLoc(), lhs, rhs, falseValue); }; diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 4840637171ed..609da6eff24c 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1760,9 +1760,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( // during a pass as long as they don't live past the end of the pass. attr = op.getValue(); } else if (mlir::isa(op.getType())) { - int value = (op.getValue() == - cir::BoolAttr::get(getContext(), - cir::BoolType::get(getContext()), true)); + int value = (op.getValue() == cir::BoolAttr::get(getContext(), true)); attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()), value); } else if (mlir::isa(op.getType())) { From f00ff52afc72580f4a6a4f6a1ac621f9e8d0a164 Mon Sep 17 00:00:00 2001 From: Yue Huang Date: Fri, 11 Apr 2025 17:29:35 +0800 Subject: [PATCH 8/8] [CIR][ThroughMLIR] Lower structs and GetMemberOp Structs are implemented as `memref`. It is not feasible to represent them as tuples, for tuples can't be put in memref (i.e. pointers to structs would break if we did). We use `memref::ViewOp` for this. Unlike `PtrStrideOp`, the reinterpret cast operation doesn't work here, as the result type is potentially different from i8. --- .../Lowering/ThroughMLIR/LowerCIRToMLIR.cpp | 69 ++++++++++++++++--- .../test/CIR/Lowering/ThroughMLIR/struct.cir | 25 +++++++ 2 files changed, 83 insertions(+), 11 deletions(-) create mode 100644 clang/test/CIR/Lowering/ThroughMLIR/struct.cir diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 326152783980..f1a100ab85f6 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "LowerToMLIRHelpers.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -32,8 +33,10 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -163,17 +166,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern { matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type mlirType = - convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType()); + mlir::Type allocaType = adaptor.getAllocaType(); + mlir::Type mlirType = convertTypeForMemory(*getTypeConverter(), allocaType); // FIXME: Some types can not be converted yet (e.g. struct) if (!mlirType) return mlir::LogicalResult::failure(); auto memreftype = mlir::dyn_cast(mlirType); - if (memreftype && mlir::isa(adaptor.getAllocaType())) { - // if the type is an array, - // we don't need to wrap with memref. + if (memreftype && (mlir::isa(allocaType) || + mlir::isa(allocaType))) { + // Arrays and structs are already memref. No need to wrap another one. } else { memreftype = mlir::MemRefType::get({}, mlirType); } @@ -1240,6 +1243,36 @@ class CIRPtrStrideOpLowering } }; +class CIRGetMemberOpLowering + : public mlir::OpConversionPattern { +public: + CIRGetMemberOpLowering(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, + const mlir::DataLayout &layout) + : OpConversionPattern(converter, ctx), layout(layout) {} + + mlir::LogicalResult + matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto baseAddr = op.getAddr(); + auto structType = + mlir::cast(baseAddr.getType().getPointee()); + uint64_t byteOffset = structType.getElementOffset(layout, op.getIndex()); + + auto fieldType = op.getResult().getType(); + auto resultType = mlir::cast( + getTypeConverter()->convertType(fieldType)); + + mlir::Value offsetValue = + rewriter.create(op.getLoc(), byteOffset); + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getAddr(), offsetValue, mlir::ValueRange{}); + return mlir::success(); + } + +private: + const mlir::DataLayout &layout; +}; + class CIRUnreachableOpLowering : public mlir::OpConversionPattern { public: @@ -1271,7 +1304,8 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern { }; void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, - mlir::TypeConverter &converter) { + mlir::TypeConverter &converter, + mlir::DataLayout layout) { patterns.add(patterns.getContext()); patterns @@ -1292,16 +1326,20 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext()); + + patterns.add(converter, patterns.getContext(), + layout); } -static mlir::TypeConverter prepareTypeConverter() { +static mlir::TypeConverter prepareTypeConverter(mlir::DataLayout layout) { mlir::TypeConverter converter; converter.addConversion([&](cir::PointerType type) -> mlir::Type { - auto ty = convertTypeForMemory(converter, type.getPointee()); + auto pointee = type.getPointee(); + auto ty = convertTypeForMemory(converter, pointee); // FIXME: The pointee type might not be converted (e.g. struct) if (!ty) return nullptr; - if (isa(type.getPointee())) + if (isa(pointee) || isa(pointee)) return ty; return mlir::MemRefType::get({}, ty); }); @@ -1353,6 +1391,13 @@ static mlir::TypeConverter prepareTypeConverter() { return nullptr; return mlir::MemRefType::get(shape, elementType); }); + converter.addConversion([&](cir::RecordType type) -> mlir::Type { + // Reinterpret structs as raw bytes. Don't use tuples as they can't be put + // in memref. + auto size = type.getTypeSize(layout, {}); + auto i8 = mlir::IntegerType::get(type.getContext(), /*width=*/8); + return mlir::MemRefType::get(size.getFixedValue(), i8); + }); converter.addConversion([&](cir::VectorType type) -> mlir::Type { auto ty = converter.convertType(type.getEltType()); return mlir::VectorType::get(type.getSize(), ty); @@ -1363,13 +1408,15 @@ static mlir::TypeConverter prepareTypeConverter() { void ConvertCIRToMLIRPass::runOnOperation() { auto module = getOperation(); + mlir::DataLayoutAnalysis layoutAnalysis(module); + const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove(module); - auto converter = prepareTypeConverter(); + auto converter = prepareTypeConverter(layout); mlir::RewritePatternSet patterns(&getContext()); populateCIRLoopToSCFConversionPatterns(patterns, converter); - populateCIRToMLIRConversionPatterns(patterns, converter); + populateCIRToMLIRConversionPatterns(patterns, converter, layout); mlir::ConversionTarget target(getContext()); target.addLegalOp(); diff --git a/clang/test/CIR/Lowering/ThroughMLIR/struct.cir b/clang/test/CIR/Lowering/ThroughMLIR/struct.cir new file mode 100644 index 000000000000..dd599753cfc6 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/struct.cir @@ -0,0 +1,25 @@ +// RUN: cir-opt %s -cir-to-mlir -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +!s32i = !cir.int +!u8i = !cir.int +!u32i = !cir.int +!ty_S = !cir.record + +module { + cir.func @test() { + %1 = cir.alloca !ty_S, !cir.ptr, ["x"] {alignment = 4 : i64} + %3 = cir.get_member %1[0] {name = "c"} : !cir.ptr -> !cir.ptr + %5 = cir.get_member %1[1] {name = "i"} : !cir.ptr -> !cir.ptr + cir.return + } + + // CHECK: func.func @test() { + // CHECK: %[[alloca:[a-z0-9]+]] = memref.alloca() {alignment = 4 : i64} : memref<8xi8> + // CHECK: %[[zero:[a-z0-9]+]] = arith.constant 0 : index + // CHECK: memref.view %[[alloca]][%[[zero]]][] : memref<8xi8> to memref + // CHECK: %[[four:[a-z0-9]+]] = arith.constant 4 : index + // CHECK: %view_0 = memref.view %[[alloca]][%[[four]]][] : memref<8xi8> to memref + // CHECK: return + // CHECK: } +}