Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1395,11 +1395,20 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,

OpFoldResult SelectOp::fold(FoldAdaptor adaptor) {
auto condition = adaptor.getCondition();
if (!condition)
return nullptr;
if (condition) {
auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
return conditionValue ? getTrueValue() : getFalseValue();
}

auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
return conditionValue ? getTrueValue() : getFalseValue();
// cir.select if %0 then x else x -> x
auto trueValue = adaptor.getTrueValue();
auto falseValue = adaptor.getFalseValue();
if (trueValue && trueValue == falseValue)
return trueValue;
if (getTrueValue() == getFalseValue())
return getTrueValue();

return nullptr;
}

//===----------------------------------------------------------------------===//
Expand Down
42 changes: 41 additions & 1 deletion clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,45 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
}
};

struct SimplifySelect : public OpRewritePattern<SelectOp> {
using OpRewritePattern<SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SelectOp op,
PatternRewriter &rewriter) const final {
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
auto trueValueConstOp =
mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(trueValueOp);
auto falseValueConstOp =
mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(falseValueOp);
if (!trueValueConstOp || !falseValueConstOp)
return mlir::failure();

auto trueValue =
mlir::dyn_cast<mlir::cir::BoolAttr>(trueValueConstOp.getValue());
auto falseValue =
mlir::dyn_cast<mlir::cir::BoolAttr>(falseValueConstOp.getValue());
if (!trueValue || !falseValue)
return mlir::failure();

// cir.select if %0 then #true else #false -> %0
if (trueValue.getValue() && !falseValue.getValue()) {
rewriter.replaceAllUsesWith(op, op.getCondition());
rewriter.eraseOp(op);
return mlir::success();
}

// cir.seleft if %0 then #false else #true -> cir.unary not %0
if (!trueValue.getValue() && falseValue.getValue()) {
rewriter.replaceOpWithNewOp<mlir::cir::UnaryOp>(
op, mlir::cir::UnaryOpKind::Not, op.getCondition());
return mlir::success();
}

return mlir::failure();
}
};

//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
Expand All @@ -131,7 +170,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
RemoveRedundantBranches,
RemoveEmptyScope,
RemoveEmptySwitch,
RemoveTrivialTry
RemoveTrivialTry,
SimplifySelect
>(patterns.getContext());
// clang-format on
}
Expand Down
36 changes: 35 additions & 1 deletion clang/test/CIR/Transforms/select.cir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cir-opt --canonicalize -o %t.cir %s
// RUN: cir-opt -cir-simplify -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s

!s32i = !cir.int<s, 32>
Expand All @@ -23,4 +23,38 @@ module {
// CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: cir.return %[[ARG1]] : !s32i
// CHECK-NEXT: }

cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
%0 = cir.const #cir.int<42> : !s32i
%1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
// CHECK-NEXT: cir.return %[[#A]] : !s32i
// CHECK-NEXT: }

cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
%0 = cir.const #cir.bool<true> : !cir.bool
%1 = cir.const #cir.bool<false> : !cir.bool
%2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %2 : !cir.bool
}

// CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
// CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
// CHECK-NEXT: }

cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool {
%0 = cir.const #cir.bool<false> : !cir.bool
%1 = cir.const #cir.bool<true> : !cir.bool
%2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %2 : !cir.bool
}

// CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
// CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool
// CHECK-NEXT: cir.return %[[#A]] : !cir.bool
// CHECK-NEXT: }
}