Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
mir-opt: ignore dead store stmts in MatchBranchSimplification
  • Loading branch information
dianqk committed Apr 9, 2025
commit 42118389d387e65ab6dbac780c3a48383dc05e45
46 changes: 40 additions & 6 deletions compiler/rustc_mir_transform/src/dead_store_elimination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//! will still not cause any further changes.
//!

use rustc_index::IndexVec;
use rustc_middle::bug;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
Expand All @@ -24,17 +25,49 @@ use rustc_mir_dataflow::impls::{

use crate::util::is_within_packed;

pub(super) enum ModifyBasicBlocks<'tcx, 'a> {
Direct(&'a mut Body<'tcx>),
BasicBlocks(&'a Body<'tcx>, &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>),
}

impl<'tcx, 'a> ModifyBasicBlocks<'tcx, 'a> {
pub(super) fn body(&self) -> &Body<'tcx> {
match self {
ModifyBasicBlocks::Direct(body) => body,
ModifyBasicBlocks::BasicBlocks(body, _) => body,
}
}

pub(super) fn bbs(&mut self) -> &mut IndexVec<BasicBlock, BasicBlockData<'tcx>> {
match self {
ModifyBasicBlocks::Direct(body) => body.basic_blocks.as_mut_preserves_cfg(),
ModifyBasicBlocks::BasicBlocks(_, bbs) => bbs,
}
}
}

/// Performs the optimization on the body
///
/// The `borrowed` set must be a `DenseBitSet` of all the locals that are ever borrowed in this
/// body. It can be generated via the [`borrowed_locals`] function.
fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
pub(super) fn eliminate<'tcx>(
tcx: TyCtxt<'tcx>,
mut modify_basic_blocks: ModifyBasicBlocks<'tcx, '_>,
ignore_debuginfo: bool,
arg_copy_to_move: bool,
) {
let body = modify_basic_blocks.body();
let borrowed_locals = borrowed_locals(body);

// If the user requests complete debuginfo, mark the locals that appear in it as live, so
// we don't remove assignments to them.
let mut always_live = debuginfo_locals(body);
always_live.union(&borrowed_locals);
let always_live = if ignore_debuginfo {
borrowed_locals.clone()
} else {
let mut always_live = debuginfo_locals(body);
always_live.union(&borrowed_locals);
always_live
};

let mut live = MaybeTransitiveLiveLocals::new(&always_live)
.iterate_to_fixpoint(tcx, body, None)
Expand All @@ -46,7 +79,8 @@ fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let mut patch = Vec::new();

for (bb, bb_data) in traversal::preorder(body) {
if let TerminatorKind::Call { ref args, .. } = bb_data.terminator().kind {
if arg_copy_to_move && let TerminatorKind::Call { ref args, .. } = bb_data.terminator().kind
{
let loc = Location { block: bb, statement_index: bb_data.statements.len() };

// Position ourselves between the evaluation of `args` and the write to `destination`.
Expand Down Expand Up @@ -113,7 +147,7 @@ fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
return;
}

let bbs = body.basic_blocks.as_mut_preserves_cfg();
let bbs = modify_basic_blocks.bbs();
for Location { block, statement_index } in patch {
bbs[block].statements[statement_index].make_nop();
}
Expand Down Expand Up @@ -145,7 +179,7 @@ impl<'tcx> crate::MirPass<'tcx> for DeadStoreElimination {
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
eliminate(tcx, body);
eliminate(tcx, ModifyBasicBlocks::Direct(body), false, true);
}

fn is_required(&self) -> bool {
Expand Down
70 changes: 65 additions & 5 deletions compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::iter;

use rustc_abi::Integer;
use rustc_index::IndexSlice;
use rustc_index::{IndexSlice, IndexVec};
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use tracing::instrument;

use super::simplify::simplify_cfg;
use crate::dead_store_elimination::{self, ModifyBasicBlocks};
use crate::patch::MirPatch;
use crate::simplify::strip_nops;

pub(super) struct MatchBranchSimplification;

Expand All @@ -20,6 +23,17 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let typing_env = body.typing_env(tcx);
let mut apply_patch = false;
let mut bbs = body.basic_blocks.clone();
let bbs = bbs.as_mut_preserves_cfg();
// We can ignore the dead store statements when merging branches.
dead_store_elimination::eliminate(
tcx,
ModifyBasicBlocks::BasicBlocks(body, bbs),
true,
false,
);
eliminate_unused_storage_mark(body, bbs);
strip_nops(bbs.as_mut_slice());
let mut patch = MirPatch::new(body);
for (bb, data) in body.basic_blocks.iter_enumerated() {
match data.terminator().kind {
Expand All @@ -33,11 +47,17 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
_ => continue,
};

if SimplifyToIf.simplify(tcx, body, &mut patch, bb, typing_env).is_some() {
if SimplifyToIf
.simplify(tcx, body, bbs.as_slice(), &mut patch, bb, typing_env)
.is_some()
{
apply_patch = true;
continue;
}
if SimplifyToExp::default().simplify(tcx, body, &mut patch, bb, typing_env).is_some() {
if SimplifyToExp::default()
.simplify(tcx, body, bbs.as_slice(), &mut patch, bb, typing_env)
.is_some()
{
apply_patch = true;
continue;
}
Expand All @@ -62,11 +82,11 @@ trait SimplifyMatch<'tcx> {
&mut self,
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
patch: &mut MirPatch<'tcx>,
switch_bb_idx: BasicBlock,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<()> {
let bbs = &body.basic_blocks;
let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
_ => unreachable!(),
Expand All @@ -83,7 +103,7 @@ trait SimplifyMatch<'tcx> {
let discr_local = patch.new_temp(discr_ty, source_info.span);

let (_, first) = targets.iter().next().unwrap();
let statement_index = bbs[switch_bb_idx].statements.len();
let statement_index = body.basic_blocks[switch_bb_idx].statements.len();
let parent_end = Location { block: switch_bb_idx, statement_index };
patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
Expand Down Expand Up @@ -526,3 +546,43 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
}
}
}

struct EliminateUnusedStorageMark {
storage_live_locals: IndexVec<Local, Option<usize>>,
}

impl<'tcx> Visitor<'tcx> for EliminateUnusedStorageMark {
fn visit_local(&mut self, local: Local, _: visit::PlaceContext, _: Location) {
self.storage_live_locals[local] = None;
}
}

fn eliminate_unused_storage_mark<'tcx>(
body: &Body<'tcx>,
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
) {
for (bb, data) in basic_blocks.iter_enumerated_mut() {
let mut unused_storage_mark = EliminateUnusedStorageMark {
storage_live_locals: IndexVec::from_elem_n(None, body.local_decls.len()),
};
for stmt_index in 0..data.statements.len() {
let loc = Location { block: bb, statement_index: stmt_index };
match data.statements[stmt_index].kind {
StatementKind::StorageLive(local) => {
unused_storage_mark.storage_live_locals[local] = Some(stmt_index);
}
StatementKind::StorageDead(local)
if let Some(live_stmt_index) =
unused_storage_mark.storage_live_locals[local] =>
{
data.statements[live_stmt_index].make_nop();
data.statements[stmt_index].make_nop();
unused_storage_mark.storage_live_locals[local] = None;
}
_ => {
unused_storage_mark.visit_statement(&data.statements[stmt_index], loc);
}
}
}
}
}
10 changes: 5 additions & 5 deletions compiler/rustc_mir_transform/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
}

fn simplify(mut self) {
self.strip_nops();
strip_nops(self.basic_blocks);

// Vec of the blocks that should be merged. We store the indices here, instead of the
// statements itself to avoid moving the (relatively) large statements twice.
Expand Down Expand Up @@ -276,11 +276,11 @@ impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
terminator.kind = TerminatorKind::Goto { target: first_succ };
true
}
}

fn strip_nops(&mut self) {
for blk in self.basic_blocks.iter_mut() {
blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
}
pub(super) fn strip_nops(basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'_>>) {
for blk in basic_blocks.iter_mut() {
blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
+ _3 = Eq(copy _11, const 7_i32);
_4 = const false;
_5 = const true;
_6 = ();
- _6 = ();
- goto -> bb3;
- }
-
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@
- bb2: {
- _2 = const -1_i8;
- _3 = const -1_i8;
+ StorageLive(_8);
+ _8 = move _5;
+ _2 = copy _8 as i8 (IntToInt);
+ _3 = copy _8 as i8 (IntToInt);
_4 = ();
- _4 = ();
- goto -> bb6;
- }
-
Expand All @@ -63,6 +59,10 @@
- }
-
- bb6: {
+ StorageLive(_8);
+ _8 = move _5;
+ _2 = copy _8 as i8 (IntToInt);
+ _3 = copy _8 as i8 (IntToInt);
+ StorageDead(_8);
StorageDead(_4);
StorageLive(_6);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
- // MIR for `match_u8_i8_dead_store` before MatchBranchSimplification
+ // MIR for `match_u8_i8_dead_store` after MatchBranchSimplification

fn match_u8_i8_dead_store(_1: EnumAu8) -> i8 {
let mut _0: i8;
let mut _2: u8;
+ let mut _3: u8;

bb0: {
_2 = discriminant(_1);
- switchInt(copy _2) -> [0: bb1, 127: bb2, 128: bb3, 255: bb4, otherwise: bb5];
- }
-
- bb1: {
- _0 = const 0_i8;
- goto -> bb6;
- }
-
- bb2: {
- _0 = const 1_i8;
- _0 = const i8::MAX;
- goto -> bb6;
- }
-
- bb3: {
- _0 = const i8::MIN;
- goto -> bb6;
- }
-
- bb4: {
- _0 = const -1_i8;
- goto -> bb6;
- }
-
- bb5: {
- unreachable;
- }
-
- bb6: {
+ StorageLive(_3);
+ _3 = copy _2;
+ _0 = copy _3 as i8 (IntToInt);
+ StorageDead(_3);
return;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

bb2: {
_0 = const i8::MAX;
_0 = const i8::MAX;
_0 = Add(copy _0, const 0_i8);
goto -> bb6;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

bb4: {
_0 = const -1_i8;
_0 = const -1_i8;
_0 = Add(copy _0, const 0_i8);
goto -> bb6;
}

Expand Down
46 changes: 45 additions & 1 deletion tests/mir-opt/matches_reduce_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ fn match_u8_i8_failed_len_1(i: EnumAu8) -> i8 {
}
bb2 = {
RET = 127;
RET = 127;
RET = RET + 0;
Goto(ret)
}
bb3 = {
Expand Down Expand Up @@ -289,6 +289,50 @@ fn match_u8_i8_failed_len_2(i: EnumAu8) -> i8 {
}
bb4 = {
RET = -1;
RET = RET + 0;
Goto(ret)
}
unreachable_bb = {
Unreachable()
}
ret = {
Return()
}
}
}

// EMIT_MIR matches_reduce_branches.match_u8_i8_dead_store.MatchBranchSimplification.diff
#[custom_mir(dialect = "built")]
fn match_u8_i8_dead_store(i: EnumAu8) -> i8 {
// CHECK-LABEL: fn match_u8_i8_dead_store(
// CHECK-NOT: switchInt
// CHECK: IntToInt
// CHECK: return
mir! {
{
let a = Discriminant(i);
match a {
0 => bb1,
127 => bb2,
128 => bb3,
255 => bb4,
_ => unreachable_bb,
}
}
bb1 = {
RET = 0;
Goto(ret)
}
bb2 = {
RET = 1; // This a dead store statement.
RET = 127;
Goto(ret)
}
bb3 = {
RET = -128;
Goto(ret)
}
bb4 = {
RET = -1;
Goto(ret)
}
Expand Down
Loading