| use clippy_utils::diagnostics::span_lint_and_then; |
| use clippy_utils::res::MaybeResPath; |
| use clippy_utils::source::{IntoSpan, SpanRangeExt, first_line_of_span, indent_of, reindent_multiline, snippet}; |
| use clippy_utils::ty::needs_ordered_drop; |
| use clippy_utils::visitors::for_each_expr_without_closures; |
| use clippy_utils::{ |
| ContainsName, HirEqInterExpr, SpanlessEq, capture_local_usage, get_enclosing_block, hash_expr, hash_stmt, |
| }; |
| use core::iter; |
| use core::ops::ControlFlow; |
| use rustc_errors::Applicability; |
| use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, LetStmt, Node, Stmt, StmtKind, intravisit}; |
| use rustc_lint::LateContext; |
| use rustc_span::hygiene::walk_chain; |
| use rustc_span::source_map::SourceMap; |
| use rustc_span::{Span, Symbol}; |
| |
| use super::BRANCHES_SHARING_CODE; |
| |
| pub(super) fn check<'tcx>( |
| cx: &LateContext<'tcx>, |
| conds: &[&'tcx Expr<'_>], |
| blocks: &[&'tcx Block<'_>], |
| expr: &'tcx Expr<'_>, |
| ) { |
| // We only lint ifs with multiple blocks |
| let &[first_block, ref blocks @ ..] = blocks else { |
| return; |
| }; |
| let &[.., last_block] = blocks else { |
| return; |
| }; |
| |
| let res = scan_block_for_eq(cx, conds, first_block, blocks); |
| let sm = cx.tcx.sess.source_map(); |
| let start_suggestion = res.start_span(first_block, sm).map(|span| { |
| let first_line_span = first_line_of_span(cx, expr.span); |
| let replace_span = first_line_span.with_hi(span.hi()); |
| let cond_span = first_line_span.until(first_block.span); |
| let cond_snippet = reindent_multiline(&snippet(cx, cond_span, "_"), false, None); |
| let cond_indent = indent_of(cx, cond_span); |
| let moved_snippet = reindent_multiline(&snippet(cx, span, "_"), true, None); |
| let suggestion = moved_snippet + "\n" + &cond_snippet + "{"; |
| let suggestion = reindent_multiline(&suggestion, true, cond_indent); |
| (replace_span, suggestion) |
| }); |
| let end_suggestion = res.end_span(last_block, sm).map(|span| { |
| let moved_snipped = reindent_multiline(&snippet(cx, span, "_"), true, None); |
| let indent = indent_of(cx, expr.span.shrink_to_hi()); |
| let suggestion = "}\n".to_string() + &moved_snipped; |
| let suggestion = reindent_multiline(&suggestion, true, indent); |
| |
| let span = span.with_hi(last_block.span.hi()); |
| // Improve formatting if the inner block has indentation (i.e. normal Rust formatting) |
| let span = span |
| .map_range(cx, |_, src, range| { |
| (range.start > 4 && src.get(range.start - 4..range.start)? == " ") |
| .then_some(range.start - 4..range.end) |
| }) |
| .map_or(span, |range| range.with_ctxt(span.ctxt())); |
| (span, suggestion.clone()) |
| }); |
| |
| let (span, msg, end_span) = match (&start_suggestion, &end_suggestion) { |
| (&Some((span, _)), &Some((end_span, _))) => ( |
| span, |
| "all if blocks contain the same code at both the start and the end", |
| Some(end_span), |
| ), |
| (&Some((span, _)), None) => (span, "all if blocks contain the same code at the start", None), |
| (None, &Some((span, _))) => (span, "all if blocks contain the same code at the end", None), |
| (None, None) => return, |
| }; |
| span_lint_and_then(cx, BRANCHES_SHARING_CODE, span, msg, |diag| { |
| if let Some(span) = end_span { |
| diag.span_note(span, "this code is shared at the end"); |
| } |
| if let Some((span, sugg)) = start_suggestion { |
| diag.span_suggestion( |
| span, |
| "consider moving these statements before the if", |
| sugg, |
| Applicability::Unspecified, |
| ); |
| } |
| if let Some((span, sugg)) = end_suggestion { |
| diag.span_suggestion( |
| span, |
| "consider moving these statements after the if", |
| sugg, |
| Applicability::Unspecified, |
| ); |
| if is_expr_parent_assignment(cx, expr) || !cx.typeck_results().expr_ty(expr).is_unit() { |
| diag.note("the end suggestion probably needs some adjustments to use the expression result correctly"); |
| } |
| } |
| if check_for_warn_of_moved_symbol(cx, &res.moved_locals, expr) { |
| diag.warn("some moved values might need to be renamed to avoid wrong references"); |
| } |
| }); |
| } |
| |
| struct BlockEq { |
| /// The end of the range of equal stmts at the start. |
| start_end_eq: usize, |
| /// The start of the range of equal stmts at the end. |
| end_begin_eq: Option<usize>, |
| /// The name and id of every local which can be moved at the beginning and the end. |
| moved_locals: Vec<(HirId, Symbol)>, |
| } |
| impl BlockEq { |
| fn start_span(&self, b: &Block<'_>, sm: &SourceMap) -> Option<Span> { |
| match &b.stmts[..self.start_end_eq] { |
| [first, .., last] => Some(sm.stmt_span(first.span, b.span).to(sm.stmt_span(last.span, b.span))), |
| [s] => Some(sm.stmt_span(s.span, b.span)), |
| [] => None, |
| } |
| } |
| |
| fn end_span(&self, b: &Block<'_>, sm: &SourceMap) -> Option<Span> { |
| match (&b.stmts[b.stmts.len() - self.end_begin_eq?..], b.expr) { |
| ([first, .., last], None) => Some(sm.stmt_span(first.span, b.span).to(sm.stmt_span(last.span, b.span))), |
| ([first, ..], Some(last)) => Some(sm.stmt_span(first.span, b.span).to(sm.stmt_span(last.span, b.span))), |
| ([s], None) => Some(sm.stmt_span(s.span, b.span)), |
| ([], Some(e)) => Some(walk_chain(e.span, b.span.ctxt())), |
| ([], None) => None, |
| } |
| } |
| } |
| |
| /// If the statement is a local, checks if the bound names match the expected list of names. |
| fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool { |
| if let StmtKind::Let(l) = s.kind { |
| let mut i = 0usize; |
| let mut res = true; |
| l.pat.each_binding_or_first(&mut |_, _, _, name| { |
| if names.get(i).is_some_and(|&(_, n)| n == name.name) { |
| i += 1; |
| } else { |
| res = false; |
| } |
| }); |
| res && i == names.len() |
| } else { |
| false |
| } |
| } |
| |
| /// Checks if the statement modifies or moves any of the given locals. |
| fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: &HirIdSet) -> bool { |
| for_each_expr_without_closures(s, |e| { |
| if let Some(id) = e.res_local_id() |
| && locals.contains(&id) |
| && !capture_local_usage(cx, e).is_imm_ref() |
| { |
| ControlFlow::Break(()) |
| } else { |
| ControlFlow::Continue(()) |
| } |
| }) |
| .is_some() |
| } |
| |
| /// Checks if the given statement should be considered equal to the statement in the same |
| /// position for each block. |
| fn eq_stmts( |
| stmt: &Stmt<'_>, |
| blocks: &[&Block<'_>], |
| get_stmt: impl for<'a> Fn(&'a Block<'a>) -> Option<&'a Stmt<'a>>, |
| eq: &mut HirEqInterExpr<'_, '_, '_>, |
| moved_bindings: &mut Vec<(HirId, Symbol)>, |
| ) -> bool { |
| (if let StmtKind::Let(l) = stmt.kind { |
| let old_count = moved_bindings.len(); |
| l.pat.each_binding_or_first(&mut |_, id, _, name| { |
| moved_bindings.push((id, name.name)); |
| }); |
| let new_bindings = &moved_bindings[old_count..]; |
| blocks |
| .iter() |
| .all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(s, new_bindings))) |
| } else { |
| true |
| }) && blocks.iter().all(|b| get_stmt(b).is_some_and(|s| eq.eq_stmt(s, stmt))) |
| } |
| |
/// Finds the statements that `block` (the first branch) shares with every block in `blocks`
/// (the remaining branches), both at the start and at the end.
///
/// `conds` holds the branch conditions. A leading statement that modifies or moves a local read by
/// any condition ends the shared prefix, because the start suggestion places the shared code in
/// front of the conditions.
///
/// The returned `BlockEq` records:
/// * the length of the shared prefix;
/// * the length of the shared suffix (`None` when no end suggestion should be made);
/// * the ids and names of the locals bound by the shared statements.
#[expect(clippy::too_many_lines)]
fn scan_block_for_eq<'tcx>(
    cx: &LateContext<'tcx>,
    conds: &[&'tcx Expr<'_>],
    block: &'tcx Block<'_>,
    blocks: &[&'tcx Block<'_>],
) -> BlockEq {
    let mut eq = SpanlessEq::new(cx);
    let mut eq = eq.inter_expr();
    let mut moved_locals = Vec::new();

    // Collect every local referenced by any of the conditions.
    let mut cond_locals = HirIdSet::default();
    for &cond in conds {
        let _: Option<!> = for_each_expr_without_closures(cond, |e| {
            if let Some(id) = e.res_local_id() {
                cond_locals.insert(id);
            }
            ControlFlow::Continue(())
        });
    }

    // Length of the shared prefix: the index of the first statement that is not shared by every
    // block, or the whole block if all statements are shared. Any `let` whose type needs ordered
    // drop also stops the prefix here and disables the end search below.
    let mut local_needs_ordered_drop = false;
    let start_end_eq = block
        .stmts
        .iter()
        .enumerate()
        .find(|&(i, stmt)| {
            if let StmtKind::Let(l) = stmt.kind
                && needs_ordered_drop(cx, cx.typeck_results().node_type(l.hir_id))
            {
                local_needs_ordered_drop = true;
                return true;
            }
            modifies_any_local(cx, stmt, &cond_locals)
                || !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
        })
        // Never cut a macro expansion in half: move the cut back to the start of the run of
        // statements that share the differing statement's macro call site.
        .map_or(block.stmts.len(), |(i, stmt)| {
            adjust_by_closest_callsite(i, stmt, block.stmts[..i].iter().enumerate().rev())
        });

    if local_needs_ordered_drop {
        return BlockEq {
            start_end_eq,
            end_begin_eq: None,
            moved_locals,
        };
    }

    // Walk backwards through the final expression/statements so long as their hashes are equal. Note
    // `SpanlessHash` treats all local references as equal allowing locals declared earlier in the block
    // to match those in other blocks. e.g. If each block ends with the following the hash value will be
    // the same even though each `x` binding will have a different `HirId`:
    // let x = foo();
    // x + 50
    let expr_hash_eq = if let Some(e) = block.expr {
        let hash = hash_expr(cx, e);
        blocks.iter().all(|b| b.expr.is_some_and(|e| hash_expr(cx, e) == hash))
    } else {
        blocks.iter().all(|b| b.expr.is_none())
    };
    if !expr_hash_eq {
        return BlockEq {
            start_end_eq,
            end_begin_eq: None,
            moved_locals,
        };
    }
    // Number of trailing statements (only considering those after the shared prefix) whose hash
    // matches the statement at the same distance from the end in every other block. Here `offset`
    // is 0-based from the last statement.
    let end_search_start = block.stmts[start_end_eq..]
        .iter()
        .rev()
        .enumerate()
        .find(|&(offset, stmt)| {
            let hash = hash_stmt(cx, stmt);
            blocks.iter().any(|b| {
                b.stmts
                    // the bounds check will catch the underflow
                    .get(b.stmts.len().wrapping_sub(offset + 1))
                    .is_none_or(|s| hash != hash_stmt(cx, s))
            })
        })
        .map_or(block.stmts.len() - start_end_eq, |(i, stmt)| {
            adjust_by_closest_callsite(i, stmt, (0..i).rev().zip(block.stmts[(block.stmts.len() - i)..].iter()))
        });

    // Compare the hash-matched trailing statements for real (tracking locals in `eq`), front to
    // back. Each statement is paired with its 1-based distance from the end:
    // `end_search_start, end_search_start - 1, ..., 1`. `zip` polls the statements first, so the
    // counter is never asked for a value below 1. On a mismatch the shared suffix shrinks to the
    // statements after the mismatching one (`offset - 1`).
    let moved_locals_at_start = moved_locals.len();
    let mut i = end_search_start;
    let end_begin_eq = block.stmts[block.stmts.len() - end_search_start..]
        .iter()
        .zip(iter::repeat_with(move || {
            let x = i;
            i -= 1;
            x
        }))
        .fold(end_search_start, |init, (stmt, offset)| {
            if eq_stmts(
                stmt,
                blocks,
                |b| b.stmts.get(b.stmts.len() - offset),
                &mut eq,
                &mut moved_locals,
            ) {
                init
            } else {
                // Clear out all locals seen at the end so far. None of them can be moved.
                let stmts = &blocks[0].stmts;
                for stmt in &stmts[stmts.len() - init..=stmts.len() - offset] {
                    if let StmtKind::Let(l) = stmt.kind {
                        l.pat.each_binding_or_first(&mut |_, id, _, _| {
                            // FIXME(rust/#120456) - is `swap_remove` correct?
                            eq.locals.swap_remove(&id);
                        });
                    }
                }
                moved_locals.truncate(moved_locals_at_start);
                offset - 1
            }
        });
    // The hash check above treated all locals as equal. Now confirm that the final expressions are
    // really equal, using the local mapping built up in `eq` while comparing statements. If they
    // differ, drop the end suggestion and the locals recorded for it.
    if let Some(e) = block.expr {
        for block in blocks {
            if block.expr.is_some_and(|expr| !eq.eq_expr(expr, e)) {
                moved_locals.truncate(moved_locals_at_start);
                return BlockEq {
                    start_end_eq,
                    end_begin_eq: None,
                    moved_locals,
                };
            }
        }
    }

    BlockEq {
        start_end_eq,
        end_begin_eq: Some(end_begin_eq),
        moved_locals,
    }
}
| |
| /// Adjusts the index for which the statements begin to differ to the closest macro callsite. |
| /// This avoids giving suggestions that requires splitting a macro call in half, when only a |
| /// part of the macro expansion is equal. |
| /// |
| /// For example, for the following macro: |
| /// ```rust,ignore |
| /// macro_rules! foo { |
| /// ($x:expr) => { |
| /// let y = 42; |
| /// $x; |
| /// }; |
| /// } |
| /// ``` |
| /// If the macro is called like this: |
| /// ```rust,ignore |
| /// if false { |
| /// let z = 42; |
| /// foo!(println!("Hello")); |
| /// } else { |
| /// let z = 42; |
| /// foo!(println!("World")); |
| /// } |
| /// ``` |
| /// Although the expanded `let y = 42;` is equal, the macro call should not be included in the |
| /// suggestion. |
| fn adjust_by_closest_callsite<'tcx>( |
| i: usize, |
| stmt: &'tcx Stmt<'tcx>, |
| mut iter: impl Iterator<Item = (usize, &'tcx Stmt<'tcx>)>, |
| ) -> usize { |
| let Some((_, first)) = iter.next() else { |
| return 0; |
| }; |
| |
| // If it is already at the boundary of a macro call, then just return. |
| if first.span.source_callsite() != stmt.span.source_callsite() { |
| return i; |
| } |
| |
| iter.find(|(_, stmt)| stmt.span.source_callsite() != first.span.source_callsite()) |
| .map_or(0, |(i, _)| i + 1) |
| } |
| |
| fn check_for_warn_of_moved_symbol(cx: &LateContext<'_>, symbols: &[(HirId, Symbol)], if_expr: &Expr<'_>) -> bool { |
| get_enclosing_block(cx, if_expr.hir_id).is_some_and(|block| { |
| let ignore_span = block.span.shrink_to_lo().to(if_expr.span); |
| |
| symbols |
| .iter() |
| .filter(|&&(_, name)| !name.as_str().starts_with('_')) |
| .any(|&(_, name)| { |
| let mut walker = ContainsName { name, cx }; |
| |
| // Scan block |
| let mut res = block |
| .stmts |
| .iter() |
| .filter(|stmt| !ignore_span.overlaps(stmt.span)) |
| .try_for_each(|stmt| intravisit::walk_stmt(&mut walker, stmt)); |
| |
| if let Some(expr) = block.expr |
| && res.is_continue() |
| { |
| res = intravisit::walk_expr(&mut walker, expr); |
| } |
| |
| res.is_break() |
| }) |
| }) |
| } |
| |
| fn is_expr_parent_assignment(cx: &LateContext<'_>, expr: &Expr<'_>) -> bool { |
| let parent = cx.tcx.parent_hir_node(expr.hir_id); |
| if let Node::LetStmt(LetStmt { init: Some(e), .. }) |
| | Node::Expr(Expr { |
| kind: ExprKind::Assign(_, e, _), |
| .. |
| }) = parent |
| { |
| return e.hir_id == expr.hir_id; |
| } |
| |
| false |
| } |