Assume the returned value in `.filter(…).count()`
Similar to how this helps in `slice::Iter::position`, LLVM sometimes loses track of how high this can get, so for `TrustedLen` iterators tell it what the upper bound is.
diff --git a/library/core/src/iter/adapters/filter.rs b/library/core/src/iter/adapters/filter.rs
index dd08cd6..b22419c 100644
--- a/library/core/src/iter/adapters/filter.rs
+++ b/library/core/src/iter/adapters/filter.rs
@@ -4,7 +4,7 @@
use crate::fmt;
use crate::iter::adapters::SourceIter;
-use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused};
+use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
use crate::num::NonZero;
use crate::ops::Try;
@@ -138,7 +138,13 @@ fn to_usize<T>(mut predicate: impl FnMut(&T) -> bool) -> impl FnMut(T) -> usize
move |x| predicate(&x) as usize
}
- self.iter.map(to_usize(self.predicate)).sum()
+ let before = self.iter.size_hint().1.unwrap_or(usize::MAX);
+ let total = self.iter.map(to_usize(self.predicate)).sum();
+ // SAFETY: `total` and `before` came from the same iterator of type `I`
+ unsafe {
+ <I as SpecAssumeCount>::assume_count_le_upper_bound(total, before);
+ }
+ total
}
#[inline]
@@ -214,3 +220,34 @@ unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
}
+
+trait SpecAssumeCount {
+ /// # Safety
+ ///
+ /// `count` must be an number of items actually read from the iterator.
+ ///
+ /// `upper` must either:
+ /// - have come from `size_hint().1` on the iterator, or
+ /// - be `usize::MAX` which will vacuously do nothing.
+ unsafe fn assume_count_le_upper_bound(count: usize, upper: usize);
+}
+
+impl<I: Iterator> SpecAssumeCount for I {
+ #[inline]
+ #[rustc_inherit_overflow_checks]
+ default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
+ // In the default we can't trust the `upper` for soundness
+ // because it came from an untrusted `size_hint`.
+
+ // In debug mode we might as well check that the size_hint wasn't too small
+ let _ = upper - count;
+ }
+}
+
+impl<I: TrustedLen> SpecAssumeCount for I {
+ #[inline]
+ unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
+ // SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator.
+ unsafe { crate::hint::assert_unchecked(count <= upper) }
+ }
+}
diff --git a/tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs b/tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs
new file mode 100644
index 0000000..7588b23
--- /dev/null
+++ b/tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs
@@ -0,0 +1,34 @@
+//@ compile-flags: -Copt-level=3
+//@ edition: 2024
+
+#![crate_type = "lib"]
+
+// Similar to how we `assume` that `slice::Iter::position` is within the length,
+// check that `count` also does that for `TrustedLen` iterators.
+// See https://rust-lang.zulipchat.com/#narrow/channel/122651-general/topic/Overflow-chk.20removed.20for.20array.20of.2059.2C.20but.20not.2060.2C.20elems/with/561070780
+
+// CHECK-LABEL: @filter_count_untrusted
+#[unsafe(no_mangle)]
+pub fn filter_count_untrusted(bar: &[u8; 1234]) -> u16 {
+ // CHECK-NOT: llvm.assume
+ // CHECK: call void @{{.+}}unwrap_failed
+ // CHECK-NOT: llvm.assume
+ let mut iter = bar.iter();
+ let iter = std::iter::from_fn(|| iter.next()); // Make it not TrustedLen
+ u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
+}
+
+// CHECK-LABEL: @filter_count_trusted
+#[unsafe(no_mangle)]
+pub fn filter_count_trusted(bar: &[u8; 1234]) -> u16 {
+ // CHECK-NOT: unwrap_failed
+ // CHECK: %[[ASSUME:.+]] = icmp ult {{i64|i32|i16}} %{{.+}}, 1235
+ // CHECK-NEXT: tail call void @llvm.assume(i1 %[[ASSUME]])
+ // CHECK-NOT: unwrap_failed
+ let iter = bar.iter();
+ u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
+}
+
+// CHECK: ; core::result::unwrap_failed
+// CHECK-NEXT: Function Attrs
+// CHECK-NEXT: declare{{.+}}void @{{.+}}unwrap_failed
diff --git a/tests/ui/iterators/iter-filter-count-debug-check.rs b/tests/ui/iterators/iter-filter-count-debug-check.rs
new file mode 100644
index 0000000..6e3a3f7
--- /dev/null
+++ b/tests/ui/iterators/iter-filter-count-debug-check.rs
@@ -0,0 +1,34 @@
+//@ run-pass
+//@ needs-unwind
+//@ ignore-backends: gcc
+//@ compile-flags: -C overflow-checks
+
+use std::panic;
+
+struct Lies(usize);
+
+impl Iterator for Lies {
+ type Item = usize;
+
+ fn next(&mut self) -> Option<usize> {
+ if self.0 == 0 {
+ None
+ } else {
+ self.0 -= 1;
+ Some(self.0)
+ }
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ (0, Some(2))
+ }
+}
+
+fn main() {
+ let r = panic::catch_unwind(|| {
+ // This returns more items than its `size_hint` said was possible,
+ // which `Filter::count` detects via `overflow-checks`.
+ let _ = Lies(10).filter(|&x| x > 3).count();
+ });
+ assert!(r.is_err());
+}