| use rustc_abi::Integer; |
| use rustc_const_eval::const_eval::mk_eval_cx_for_const_val; |
| use rustc_middle::mir::*; |
| use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; |
| use rustc_middle::ty::util::Discr; |
| use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; |
| |
| use super::simplify::simplify_cfg; |
| use crate::patch::MirPatch; |
| use crate::unreachable_prop::remove_successors_from_switch; |
| |
| /// Unifies all targets into one basic block if each statement can have the same statement. |
| pub(super) struct MatchBranchSimplification; |
| |
| impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { |
| fn is_enabled(&self, sess: &rustc_session::Session) -> bool { |
| // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable. |
| sess.mir_opt_level() >= 2 |
| } |
| |
| fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| let typing_env = body.typing_env(tcx); |
| let mut changed = false; |
| for bb in body.basic_blocks.indices() { |
| if !candidate_match(body, bb) { |
| continue; |
| }; |
| changed |= simplify_match(tcx, typing_env, body, bb) |
| } |
| |
| if changed { |
| simplify_cfg(tcx, body); |
| } |
| } |
| |
| fn is_required(&self) -> bool { |
| false |
| } |
| } |
| |
| struct SimplifyMatch<'tcx, 'a> { |
| tcx: TyCtxt<'tcx>, |
| typing_env: ty::TypingEnv<'tcx>, |
| patch: MirPatch<'tcx>, |
| body: &'a Body<'tcx>, |
| switch_bb: BasicBlock, |
| discr: &'a Operand<'tcx>, |
| discr_local: Option<Local>, |
| discr_ty: Ty<'tcx>, |
| } |
| |
| impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { |
| fn discr_local(&mut self) -> Local { |
| *self.discr_local.get_or_insert_with(|| { |
| // Introduce a temporary for the discriminant value. |
| let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info; |
| self.patch.new_temp(self.discr_ty, source_info.span) |
| }) |
| } |
| |
| /// Unifies the assignments if all rvalues are constants and equal. |
| fn unify_if_equal_const( |
| &self, |
| dest: Place<'tcx>, |
| consts: &[(u128, &ConstOperand<'tcx>)], |
| otherwise: Option<&ConstOperand<'tcx>>, |
| ) -> Option<StatementKind<'tcx>> { |
| let (_, first_const, mut others) = split_first_case(consts, otherwise); |
| let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?; |
| if others.all(|const_| { |
| const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int) |
| }) { |
| Some(StatementKind::Assign(Box::new(( |
| dest, |
| Rvalue::Use(Operand::Constant(Box::new(first_const.clone()))), |
| )))) |
| } else { |
| None |
| } |
| } |
| |
| /// If a source block is found that switches between two blocks that are exactly |
| /// the same modulo const bool assignments (e.g., one assigns true another false |
| /// to the same place), unify a target block statements into the source block, |
| /// using Eq / Ne comparison with switch value where const bools value differ. |
| /// |
| /// For example: |
| /// |
| /// ```ignore (MIR) |
| /// bb0: { |
| /// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; |
| /// } |
| /// |
| /// bb1: { |
| /// _2 = const true; |
| /// goto -> bb3; |
| /// } |
| /// |
| /// bb2: { |
| /// _2 = const false; |
| /// goto -> bb3; |
| /// } |
| /// ``` |
| /// |
| /// into: |
| /// |
| /// ```ignore (MIR) |
| /// bb0: { |
| /// _2 = Eq(move _3, const 42_isize); |
| /// goto -> bb3; |
| /// } |
| /// ``` |
| fn unify_by_eq_op( |
| &mut self, |
| dest: Place<'tcx>, |
| consts: &[(u128, &ConstOperand<'tcx>)], |
| otherwise: Option<&ConstOperand<'tcx>>, |
| ) -> Option<StatementKind<'tcx>> { |
| // FIXME: extend to any case. |
| let (first_case, first_const, mut others) = split_first_case(consts, otherwise); |
| if !first_const.ty().is_bool() { |
| return None; |
| } |
| let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?; |
| if others.all(|const_| { |
| const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool) |
| }) { |
| // Make value conditional on switch condition. |
| let size = |
| self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size; |
| let const_cmp = Operand::const_from_scalar( |
| self.tcx, |
| self.discr_ty, |
| rustc_const_eval::interpret::Scalar::from_uint(first_case, size), |
| rustc_span::DUMMY_SP, |
| ); |
| let op = if first_bool { BinOp::Eq } else { BinOp::Ne }; |
| let rval = Rvalue::BinaryOp( |
| op, |
| Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)), |
| ); |
| Some(StatementKind::Assign(Box::new((dest, rval)))) |
| } else { |
| None |
| } |
| } |
| |
| /// Unifies the assignments if all rvalues can be cast from the discriminant value by IntToInt. |
| /// |
| /// For example: |
| /// |
| /// ```ignore (MIR) |
| /// bb0: { |
| /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; |
| /// } |
| /// |
| /// bb1: { |
| /// unreachable; |
| /// } |
| /// |
| /// bb2: { |
| /// _0 = const 1_i16; |
| /// goto -> bb5; |
| /// } |
| /// |
| /// bb3: { |
| /// _0 = const 2_i16; |
| /// goto -> bb5; |
| /// } |
| /// |
| /// bb4: { |
| /// _0 = const 3_i16; |
| /// goto -> bb5; |
| /// } |
| /// ``` |
| /// |
| /// into: |
| /// |
| /// ```ignore (MIR) |
| /// bb0: { |
| /// _0 = _1 as i16 (IntToInt); |
| /// goto -> bb5; |
| /// } |
| /// ``` |
| fn unify_by_int_to_int( |
| &mut self, |
| dest: Place<'tcx>, |
| consts: &[(u128, &ConstOperand<'tcx>)], |
| ) -> Option<StatementKind<'tcx>> { |
| let (_, first_const) = consts[0]; |
| if !first_const.ty().is_integral() { |
| return None; |
| } |
| let discr_layout = |
| self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap(); |
| if consts.iter().all(|&(case, const_)| { |
| let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) |
| else { |
| return false; |
| }; |
| can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int) |
| }) { |
| let operand = Operand::Copy(Place::from(self.discr_local())); |
| let rval = if first_const.ty() == self.discr_ty { |
| Rvalue::Use(operand) |
| } else { |
| Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty()) |
| }; |
| Some(StatementKind::Assign(Box::new((dest, rval)))) |
| } else { |
| None |
| } |
| } |
| |
| /// This is primarily used to unify these copy statements that simplified the canonical enum clone method by GVN. |
| /// The GVN simplified |
| /// ```ignore (syntax-highlighting-only) |
| /// match a { |
| /// Foo::A(x) => Foo::A(*x), |
| /// Foo::B => Foo::B |
| /// } |
| /// ``` |
| /// to |
| /// ```ignore (syntax-highlighting-only) |
| /// match a { |
| /// Foo::A(_x) => a, // copy a |
| /// Foo::B => Foo::B |
| /// } |
| /// ``` |
| /// This will simplify into a copy statement. |
| fn unify_by_copy( |
| &self, |
| dest: Place<'tcx>, |
| rvals: &[(u128, &Rvalue<'tcx>)], |
| ) -> Option<StatementKind<'tcx>> { |
| let bbs = &self.body.basic_blocks; |
| // Check if the copy source matches the following pattern. |
| // _2 = discriminant(*_1); // "*_1" is the expected the copy source. |
| // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; |
| let &Statement { |
| kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))), |
| .. |
| } = bbs[self.switch_bb].statements.last()? |
| else { |
| return None; |
| }; |
| if self.discr.place() != Some(discr_place) { |
| return None; |
| } |
| let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx); |
| if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() { |
| return None; |
| } |
| let dest_ty = dest.ty(self.body.local_decls(), self.tcx); |
| if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() { |
| return None; |
| } |
| let ty::Adt(def, _) = dest_ty.ty.kind() else { |
| return None; |
| }; |
| |
| for &(case, rvalue) in rvals.iter() { |
| match rvalue { |
| // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`. |
| Rvalue::Use(Operand::Constant(box constant)) |
| if let Const::Val(const_, ty) = constant.const_ => |
| { |
| let (ecx, op) = mk_eval_cx_for_const_val( |
| self.tcx.at(constant.span), |
| self.typing_env, |
| const_, |
| ty, |
| )?; |
| let variant = ecx.read_discriminant(&op).discard_err()?; |
| if !def.variants()[variant].fields.is_empty() { |
| return None; |
| } |
| let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?; |
| if val != case { |
| return None; |
| } |
| } |
| Rvalue::Use(Operand::Copy(src_place)) if *src_place == copy_src_place => {} |
| // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`. |
| Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields) |
| if fields.is_empty() |
| && let Some(Discr { val, .. }) = |
| src_ty.ty.discriminant_for_variant(self.tcx, *variant_index) |
| && val == case => {} |
| _ => return None, |
| } |
| } |
| Some(StatementKind::Assign(Box::new((dest, Rvalue::Use(Operand::Copy(copy_src_place)))))) |
| } |
| |
| /// Returns a new statement if we can use the statement replace all statements. |
| fn try_unify_stmts( |
| &mut self, |
| index: usize, |
| stmts: &[(u128, &StatementKind<'tcx>)], |
| otherwise: Option<&StatementKind<'tcx>>, |
| ) -> Option<StatementKind<'tcx>> { |
| if let Some(new_stmt) = identical_stmts(stmts, otherwise) { |
| return Some(new_stmt); |
| } |
| |
| let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?; |
| if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) { |
| if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) { |
| return Some(new_stmt); |
| } |
| if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) { |
| return Some(new_stmt); |
| } |
| // Requires the otherwise is unreachable. |
| if otherwise.is_none() |
| && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts) |
| { |
| return Some(new_stmt); |
| } |
| } |
| |
| // We only know the first statement is safe to introduce new dereferences. |
| if index == 0 |
| // We cannot create overlapping assignments. |
| && dest.is_stable_offset() |
| // Requires the otherwise is unreachable. |
| && otherwise.is_none() |
| && let Some(new_stmt) = self.unify_by_copy(dest, &rvals) |
| { |
| return Some(new_stmt); |
| } |
| None |
| } |
| } |
| |
| /// Returns the first case target if all targets have an equal number of statements and identical destination. |
| fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool { |
| use itertools::Itertools; |
| let targets = match &body.basic_blocks[switch_bb].terminator().kind { |
| TerminatorKind::SwitchInt { |
| discr: Operand::Copy(_) | Operand::Move(_), targets, .. |
| } => targets, |
| // Only optimize switch int statements |
| _ => return false, |
| }; |
| // We require that the possible target blocks don't contain this block. |
| if targets.all_targets().contains(&switch_bb) { |
| return false; |
| } |
| // We require that the possible target blocks all be distinct. |
| if !targets.is_distinct() { |
| return false; |
| } |
| // Check that destinations are identical, and if not, then don't optimize this block |
| targets |
| .all_targets() |
| .iter() |
| .map(|&bb| &body.basic_blocks[bb]) |
| .filter(|bb| !bb.is_empty_unreachable()) |
| .map(|bb| (bb.statements.len(), &bb.terminator().kind)) |
| .all_equal() |
| } |
| |
| fn simplify_match<'tcx>( |
| tcx: TyCtxt<'tcx>, |
| typing_env: ty::TypingEnv<'tcx>, |
| body: &mut Body<'tcx>, |
| switch_bb: BasicBlock, |
| ) -> bool { |
| let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind { |
| TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets), |
| _ => unreachable!(), |
| }; |
| let mut simplify_match = SimplifyMatch { |
| tcx, |
| typing_env, |
| patch: MirPatch::new(body), |
| body, |
| switch_bb, |
| discr, |
| discr_local: None, |
| discr_ty: discr.ty(body.local_decls(), tcx), |
| }; |
| let reachable_cases: Vec<_> = |
| targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect(); |
| let mut new_stmts = Vec::new(); |
| let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() { |
| None |
| } else { |
| Some(targets.otherwise()) |
| }; |
| // We can patch the terminator to goto because there is a single target. |
| match (reachable_cases.len(), otherwise.is_none()) { |
| (1, true) | (0, false) => { |
| let mut patch = simplify_match.patch; |
| remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| { |
| body.basic_blocks[bb].is_empty_unreachable() |
| }); |
| patch.apply(body); |
| return true; |
| } |
| _ => {} |
| } |
| let Some(&(_, first_case_bb)) = reachable_cases.first() else { |
| return false; |
| }; |
| let stmt_len = body.basic_blocks[first_case_bb].statements.len(); |
| let mut cases = Vec::with_capacity(stmt_len); |
| // Check at each position in the basic blocks whether these statements can be unified. |
| for index in 0..stmt_len { |
| cases.clear(); |
| let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind); |
| for &(case, bb) in &reachable_cases { |
| cases.push((case, &body.basic_blocks[bb].statements[index].kind)); |
| } |
| let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else { |
| return false; |
| }; |
| new_stmts.push(new_stmt); |
| } |
| // Take ownership of items now that we know we can optimize. |
| let discr = discr.clone(); |
| |
| let statement_index = body.basic_blocks[switch_bb].statements.len(); |
| let parent_end = Location { block: switch_bb, statement_index }; |
| let mut patch = simplify_match.patch; |
| if let Some(discr_local) = simplify_match.discr_local { |
| patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); |
| patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); |
| } |
| for new_stmt in new_stmts { |
| patch.add_statement(parent_end, new_stmt); |
| } |
| if let Some(discr_local) = simplify_match.discr_local { |
| patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); |
| } |
| patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone()); |
| patch.apply(body); |
| true |
| } |
| |
| /// Check if the cast constant using `IntToInt` is equal to the target constant. |
| fn can_cast( |
| tcx: TyCtxt<'_>, |
| src_val: impl Into<u128>, |
| src_layout: TyAndLayout<'_>, |
| cast_ty: Ty<'_>, |
| target_scalar: ScalarInt, |
| ) -> bool { |
| let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap(); |
| let v = match src_layout.ty.kind() { |
| ty::Uint(_) => from_scalar.to_uint(src_layout.size), |
| ty::Int(_) => from_scalar.to_int(src_layout.size) as u128, |
| // We can also transform the values of other integer representations (such as char), |
| // although this may not be practical in real-world scenarios. |
| _ => return false, |
| }; |
| let size = match *cast_ty.kind() { |
| ty::Int(t) => Integer::from_int_ty(&tcx, t).size(), |
| ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(), |
| _ => return false, |
| }; |
| let v = size.truncate(v); |
| let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap(); |
| cast_scalar == target_scalar |
| } |
| |
| fn candidate_assign<'tcx, 'a>( |
| stmts: &'a [(u128, &'a StatementKind<'tcx>)], |
| otherwise: Option<&'a StatementKind<'tcx>>, |
| ) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> { |
| let (_, first_stmt) = stmts[0]; |
| let (dest, _) = first_stmt.as_assign()?; |
| let otherwise = if let Some(otherwise) = otherwise { |
| let Some((otherwise_dest, rval)) = otherwise.as_assign() else { |
| return None; |
| }; |
| if otherwise_dest != dest { |
| return None; |
| } |
| Some(rval) |
| } else { |
| None |
| }; |
| let rvals = stmts |
| .into_iter() |
| .map(|&(case, stmt)| { |
| let (other_dest, rval) = stmt.as_assign()?; |
| if other_dest != dest { |
| return None; |
| } |
| Some((case, rval)) |
| }) |
| .try_collect()?; |
| Some((*dest, rvals, otherwise)) |
| } |
| |
| // Returns all ConstOperands if all Rvalues are ConstOperands. |
| fn candidate_const<'tcx, 'a>( |
| rvals: &'a [(u128, &'a Rvalue<'tcx>)], |
| otherwise: Option<&'a Rvalue<'tcx>>, |
| ) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> { |
| let otherwise = if let Some(otherwise) = otherwise { |
| let Rvalue::Use(Operand::Constant(box const_)) = otherwise else { |
| return None; |
| }; |
| Some(const_) |
| } else { |
| None |
| }; |
| let consts = rvals |
| .into_iter() |
| .map(|&(case, rval)| { |
| let Rvalue::Use(Operand::Constant(box const_)) = rval else { return None }; |
| Some((case, const_)) |
| }) |
| .try_collect()?; |
| Some((consts, otherwise)) |
| } |
| |
// Splits `stmts` into the first case value, the first item, and an iterator
// over the remaining items followed by `otherwise` (when it is present).
// Panics if `stmts` is empty.
fn split_first_case<'a, T>(
    stmts: &'a [(u128, &'a T)],
    otherwise: Option<&'a T>,
) -> (u128, &'a T, impl Iterator<Item = &'a T>) {
    let (first_case, first) = stmts[0];
    let rest = stmts.iter().skip(1).map(|&(_, item)| item).chain(otherwise);
    (first_case, first, rest)
}
| |
| // If all statements are identical, we can optimize. |
| fn identical_stmts<'tcx>( |
| stmts: &[(u128, &StatementKind<'tcx>)], |
| otherwise: Option<&StatementKind<'tcx>>, |
| ) -> Option<StatementKind<'tcx>> { |
| use itertools::Itertools; |
| let (_, first_stmt, others) = split_first_case(stmts, otherwise); |
| if std::iter::once(first_stmt).chain(others).all_equal() { |
| return Some(first_stmt.clone()); |
| } |
| None |
| } |