Skip to content

Commit 4ddd7e2

Browse files
[CIR][ThoughMLIR] Support ContinueOp in nested whiles (#1694)
As we need to preserve the ContinueOp for inner loops when we convert for outer while-loops, we must not mark cir dialect as illegal. Otherwise, MLIR rejects this kind of preservation and considers it as a pass failure. It seems we need another way to check whether the CIR is fully lowered. Co-authored-by: Yue Huang <[email protected]>
1 parent f74a380 commit 4ddd7e2

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "clang/CIR/Dialect/IR/CIRTypes.h"
2626
#include "clang/CIR/LowerToMLIR.h"
2727
#include "llvm/ADT/TypeSwitch.h"
28-
#include "llvm/IR/Module.h"
2928

3029
using namespace cir;
3130
using namespace llvm;
@@ -483,6 +482,19 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
483482
return;
484483

485484
for (auto continueOp : continues) {
485+
bool nested = false;
486+
// When there is another loop between this WhileOp and the ContinueOp,
487+
// we shouldn't change that loop instead.
488+
for (mlir::Operation *parent = continueOp->getParentOp();
489+
parent != whileOp; parent = parent->getParentOp()) {
490+
if (isa<WhileOp>(parent)) {
491+
nested = true;
492+
break;
493+
}
494+
}
495+
if (nested)
496+
continue;
497+
486498
// When the ContinueOp is under an IfOp, a direct replacement of
487499
// `scf.yield` won't work: the yield would jump out of that IfOp instead.
488500
// We might need to change the WhileOp itself to achieve the same effect.

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,7 +1552,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
15521552
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
15531553
mlir::math::MathDialect, mlir::vector::VectorDialect,
15541554
mlir::LLVM::LLVMDialect>();
1555-
target.addIllegalDialect<cir::CIRDialect>();
1555+
// We cannot mark cir dialect as illegal before conversion.
1556+
// The conversion of WhileOp relies on partially preserving operations from
1557+
// cir dialect, for example the `cir.continue`. If we marked cir as illegal
1558+
// here, then MLIR would think any remaining `cir.continue` indicates a
1559+
// failure, which is not what we want.
15561560

15571561
if (failed(applyPartialConversion(module, target, std::move(patterns))))
15581562
signalPassFailure();
@@ -1616,8 +1620,9 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
16161620

16171621
auto result = !mlir::failed(pm.run(theModule));
16181622
if (!result)
1619-
report_fatal_error(
1620-
"The pass manager failed to lower CIR to MLIR standard dialects!");
1623+
theModule.dump(),
1624+
report_fatal_error(
1625+
"The pass manager failed to lower CIR to MLIR standard dialects!");
16211626
// Now that we ran all the lowering passes, verify the final output.
16221627
if (theModule.verify().failed())
16231628
report_fatal_error(

clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,38 @@ void while_continue_2() {
6868
// CHECK: scf.yield
6969
// CHECK: }
7070
}
71+
72+
void while_continue_nested() {
73+
int i = 0;
74+
while (i < 10) {
75+
while (true) {
76+
continue;
77+
i--;
78+
}
79+
i++;
80+
}
81+
// The continue will only work on the inner while.
82+
83+
// CHECK: scf.while : () -> () {
84+
// CHECK: %[[IV:.+]] = memref.load %alloca[]
85+
// CHECK: %[[TEN:.+]] = arith.constant 10
86+
// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[IV]], %[[TEN]]
87+
// CHECK: scf.condition(%[[LT]])
88+
// CHECK: } do {
89+
// CHECK: memref.alloca_scope {
90+
// CHECK: memref.alloca_scope {
91+
// CHECK: scf.while : () -> () {
92+
// CHECK: %[[TRUE:.+]] = arith.constant true
93+
// CHECK: scf.condition(%[[TRUE]])
94+
// CHECK: } do {
95+
// CHECK: scf.yield
96+
// CHECK: }
97+
// CHECK: }
98+
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
99+
// CHECK: %[[ONE:.+]] = arith.constant 1
100+
// CHECK: %[[ADD:.+]] = arith.addi %[[IV2]], %[[ONE]]
101+
// CHECK: memref.store %[[ADD]], %alloca[]
102+
// CHECK: }
103+
// CHECK: scf.yield
104+
// CHECK: }
105+
}

0 commit comments

Comments
 (0)