Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,62 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
};

class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
void rewriteContinueInIf(cir::IfOp ifOp, cir::ContinueOp continueOp,
mlir::scf::WhileOp whileOp,
mlir::ConversionPatternRewriter &rewriter) const {
auto loc = ifOp->getLoc();

rewriter.setInsertionPointToStart(whileOp.getAfterBody());
auto boolTy = rewriter.getType<BoolType>();
auto boolPtrTy = rewriter.getType<PointerType>(boolTy);
auto alignment = rewriter.getI64IntegerAttr(4);
auto condAlloca = rewriter.create<AllocaOp>(loc, boolPtrTy, boolTy,
"condition", alignment);

rewriter.setInsertionPoint(ifOp);
auto negated = rewriter.create<UnaryOp>(loc, boolTy, UnaryOpKind::Not,
ifOp.getCondition());
rewriter.create<StoreOp>(loc, negated, condAlloca);

// On each layer, surround everything after runner in its parent with a
// guard: `if (!condAlloca)`.
for (mlir::Operation *runner = ifOp; runner != whileOp;
runner = runner->getParentOp()) {
rewriter.setInsertionPointAfter(runner);
auto cond = rewriter.create<LoadOp>(
loc, boolTy, condAlloca, /*isDeref=*/false,
/*volatile=*/false, /*nontemporal=*/false, alignment,
/*memorder=*/cir::MemOrderAttr{}, /*tbaa=*/cir::TBAAAttr{});
auto ifnot =
rewriter.create<IfOp>(loc, cond, /*withElseRegion=*/false,
[&](mlir::OpBuilder &, mlir::Location) {
/* Intentionally left empty */
});

auto &region = ifnot.getThenRegion();
rewriter.setInsertionPointToEnd(&region.back());
auto terminator = rewriter.create<YieldOp>(loc);

bool inserted = false;
for (mlir::Operation *op = ifnot->getNextNode(); op;) {
// Don't move terminators in.
if (isa<YieldOp>(op) || isa<ReturnOp>(op))
break;

mlir::Operation *next = op->getNextNode();
op->moveBefore(terminator);
op = next;
inserted = true;
}
// Don't retain `if (!condAlloca)` when it's empty.
if (!inserted)
rewriter.eraseOp(ifnot);
}
rewriter.setInsertionPoint(continueOp);
rewriter.create<mlir::scf::YieldOp>(continueOp->getLoc());
rewriter.eraseOp(continueOp);
}

void rewriteContinue(mlir::scf::WhileOp whileOp,
mlir::ConversionPatternRewriter &rewriter) const {
// Collect all ContinueOp inside this while.
Expand All @@ -427,23 +483,29 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
return;

for (auto continueOp : continues) {
// When the break is under an IfOp, a direct replacement of `scf.yield`
// won't work: the yield would jump out of that IfOp instead. We might
// need to change the whileOp itself to achieve the same effect.
// When the ContinueOp is under an IfOp, a direct replacement of
// `scf.yield` won't work: the yield would jump out of that IfOp instead.
// We might need to change the WhileOp itself to achieve the same effect.
bool rewritten = false;
for (mlir::Operation *parent = continueOp->getParentOp();
parent != whileOp; parent = parent->getParentOp()) {
if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
llvm_unreachable("NYI");
if (auto ifOp = dyn_cast<cir::IfOp>(parent)) {
rewriteContinueInIf(ifOp, continueOp, whileOp, rewriter);
rewritten = true;
break;
}
}
if (rewritten)
continue;

// Operations after this break has to be removed.
// Operations after this ContinueOp has to be removed.
for (mlir::Operation *runner = continueOp->getNextNode(); runner;) {
mlir::Operation *next = runner->getNextNode();
runner->erase();
runner = next;
}

// Blocks after this break also has to be removed.
// Blocks after this ContinueOp also has to be removed.
for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) {
mlir::Block *next = block->getNextNode();
block->erase();
Expand Down
45 changes: 44 additions & 1 deletion clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

void for_with_break() {
void while_continue() {
int i = 0;
while (i < 100) {
i++;
Expand All @@ -25,3 +25,46 @@ void for_with_break() {
// CHECK: scf.yield
// CHECK: }
}

void while_continue_2() {
int i = 0;
while (i < 10) {
if (i == 5) {
i += 3;
continue;
}

i++;
}
// The final i++ will have a `if (!(i == 5))` guarded against it.

// CHECK: do {
// CHECK: %[[NOTALLOCA:.+]] = memref.alloca
// CHECK: memref.alloca_scope {
// CHECK: memref.alloca_scope {
// CHECK: %[[IV:.+]] = memref.load %[[IVADDR:.+]][]
// CHECK: %[[FIVE:.+]] = arith.constant 5
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[IV]], %[[FIVE]]
// CHECK: %true = arith.constant true
// CHECK: %[[NOT:.+]] = arith.xori %true, %[[COND]]
// CHECK: %[[EXT:.+]] = arith.extui %[[NOT]] : i1 to i8
// CHECK: memref.store %[[EXT]], %[[NOTALLOCA]]
// CHECK: scf.if %[[COND]] {
// CHECK: %[[THREE:.+]] = arith.constant 3
// CHECK: %[[IV2:.+]] = memref.load %[[IVADDR]]
// CHECK: %[[TMP:.+]] = arith.addi %[[IV2]], %[[THREE]]
// CHECK: memref.store %[[TMP]], %[[IVADDR]]
// CHECK: }
// CHECK: }
// CHECK: %[[NOTCOND:.+]] = memref.load %[[NOTALLOCA]]
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[NOTCOND]] : i8 to i1
// CHECK: scf.if %[[TRUNC]] {
// CHECK: %[[IV3:.+]] = memref.load %[[IVADDR]]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[TMP2:.+]] = arith.addi %[[IV3]], %[[ONE]]
// CHECK: memref.store %[[TMP2]], %[[IVADDR]]
// CHECK: }
// CHECK: }
// CHECK: scf.yield
// CHECK: }
}
Loading