Skip to content

Safer implementation of RepeatN #130887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 43 additions & 118 deletions library/core/src/iter/sources/repeat_n.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::fmt;
use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
use crate::mem::MaybeUninit;
use crate::num::NonZero;
use crate::ops::{NeverShortCircuit, Try};
use crate::ops::Try;

/// Creates a new iterator that repeats a single element a given number of times.
///
Expand Down Expand Up @@ -58,78 +57,48 @@ use crate::ops::{NeverShortCircuit, Try};
#[inline]
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
let element = if count == 0 {
// `element` gets dropped eagerly.
MaybeUninit::uninit()
} else {
MaybeUninit::new(element)
};
RepeatN { inner: RepeatNInner::new(element, count) }
}

RepeatN { element, count }
#[derive(Clone, Copy)]
struct RepeatNInner<T> {
count: NonZero<usize>,
element: T,
}

impl<T> RepeatNInner<T> {
fn new(element: T, count: usize) -> Option<Self> {
let count = NonZero::<usize>::new(count)?;
Some(Self { element, count })
}
}

/// An iterator that repeats an element an exact number of times.
///
/// This `struct` is created by the [`repeat_n()`] function.
/// See its documentation for more.
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
#[derive(Clone)]
pub struct RepeatN<A> {
count: usize,
// Invariant: uninit iff count == 0.
element: MaybeUninit<A>,
inner: Option<RepeatNInner<A>>,
}

impl<A> RepeatN<A> {
/// Returns the element if it hasn't been dropped already.
fn element_ref(&self) -> Option<&A> {
if self.count > 0 {
// SAFETY: The count is non-zero, so it must be initialized.
Some(unsafe { self.element.assume_init_ref() })
} else {
None
}
}
/// If we haven't already dropped the element, return it in an option.
///
/// Clears the count so it won't be dropped again later.
#[inline]
fn take_element(&mut self) -> Option<A> {
if self.count > 0 {
self.count = 0;
// SAFETY: We just set count to zero so it won't be dropped again,
// and it used to be non-zero so it hasn't already been dropped.
let element = unsafe { self.element.assume_init_read() };
Some(element)
} else {
None
}
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> Clone for RepeatN<A> {
fn clone(&self) -> RepeatN<A> {
RepeatN {
count: self.count,
element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
}
self.inner.take().map(|inner| inner.element)
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RepeatN")
.field("count", &self.count)
.field("element", &self.element_ref())
.finish()
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A> Drop for RepeatN<A> {
fn drop(&mut self) {
self.take_element();
let (count, element) = match self.inner.as_ref() {
Some(inner) => (inner.count.get(), Some(&inner.element)),
None => (0, None),
};
f.debug_struct("RepeatN").field("count", &count).field("element", &element).finish()
}
}

Expand All @@ -139,12 +108,17 @@ impl<A: Clone> Iterator for RepeatN<A> {

#[inline]
fn next(&mut self) -> Option<A> {
if self.count > 0 {
// SAFETY: Just checked it's not empty
unsafe { Some(self.next_unchecked()) }
} else {
None
let inner = self.inner.as_mut()?;
let count = inner.count.get();

if let Some(decremented) = NonZero::<usize>::new(count - 1) {
// Order of these is important for optimization
let tmp = inner.element.clone();
inner.count = decremented;
return Some(tmp);
}

return self.take_element();
}

#[inline]
Expand All @@ -155,52 +129,19 @@ impl<A: Clone> Iterator for RepeatN<A> {

#[inline]
fn advance_by(&mut self, skip: usize) -> Result<(), NonZero<usize>> {
let len = self.count;
let Some(inner) = self.inner.as_mut() else {
return NonZero::<usize>::new(skip).map(Err).unwrap_or(Ok(()));
};

if skip >= len {
self.take_element();
}
let len = inner.count.get();

if skip > len {
// SAFETY: we just checked that the difference is positive
Err(unsafe { NonZero::new_unchecked(skip - len) })
} else {
self.count = len - skip;
Ok(())
if let Some(new_len) = len.checked_sub(skip).and_then(NonZero::<usize>::new) {
inner.count = new_len;
return Ok(());
}
}

fn try_fold<B, F, R>(&mut self, mut acc: B, mut f: F) -> R
where
F: FnMut(B, A) -> R,
R: Try<Output = B>,
{
if self.count > 0 {
while self.count > 1 {
self.count -= 1;
// SAFETY: the count was larger than 1, so the element is
// initialized and hasn't been dropped.
acc = f(acc, unsafe { self.element.assume_init_ref().clone() })?;
}

// We could just set the count to zero directly, but doing it this
// way should make it easier for the optimizer to fold this tail
// into the loop when `clone()` is equivalent to copying.
self.count -= 1;
// SAFETY: we just set the count to zero from one, so the element
// is still initialized, has not been dropped yet and will not be
// accessed by future calls.
f(acc, unsafe { self.element.assume_init_read() })
} else {
try { acc }
}
}

fn fold<B, F>(mut self, init: B, f: F) -> B
where
F: FnMut(B, A) -> B,
{
self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
self.inner = None;
return NonZero::<usize>::new(skip - len).map(Err).unwrap_or(Ok(()));
}

#[inline]
Expand All @@ -217,7 +158,7 @@ impl<A: Clone> Iterator for RepeatN<A> {
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> ExactSizeIterator for RepeatN<A> {
fn len(&self) -> usize {
self.count
self.inner.as_ref().map(|inner| inner.count.get()).unwrap_or(0)
}
}

Expand Down Expand Up @@ -262,20 +203,4 @@ impl<A: Clone> FusedIterator for RepeatN<A> {}
#[unstable(feature = "trusted_len", issue = "37572")]
unsafe impl<A: Clone> TrustedLen for RepeatN<A> {}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> UncheckedIterator for RepeatN<A> {
#[inline]
unsafe fn next_unchecked(&mut self) -> Self::Item {
// SAFETY: The caller promised the iterator isn't empty
self.count = unsafe { self.count.unchecked_sub(1) };
if self.count == 0 {
// SAFETY: the check above ensured that the count used to be non-zero,
// so element hasn't been dropped yet, and we just lowered the count to
// zero so it won't be dropped later, and thus it's okay to take it here.
unsafe { self.element.assume_init_read() }
} else {
// SAFETY: the count is non-zero, so it must have not been dropped yet.
let element = unsafe { self.element.assume_init_ref() };
A::clone(element)
}
}
}
impl<A: Clone> UncheckedIterator for RepeatN<A> {}
3 changes: 2 additions & 1 deletion tests/codegen/iter-repeat-n-trivial-drop.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//@ compile-flags: -C opt-level=3
//@ only-x86_64
//@ needs-deterministic-layouts

#![crate_type = "lib"]
#![feature(iter_repeat_n)]
Expand All @@ -25,7 +26,7 @@ pub fn iter_repeat_n_next(it: &mut std::iter::RepeatN<NotCopy>) -> Option<NotCop
// CHECK-NEXT: br i1 %[[COUNT_ZERO]], label %[[EMPTY:.+]], label %[[NOT_EMPTY:.+]]

// CHECK: [[NOT_EMPTY]]:
// CHECK-NEXT: %[[DEC:.+]] = add i64 %[[COUNT]], -1
// CHECK: %[[DEC:.+]] = add i64 %[[COUNT]], -1
// CHECK-NEXT: store i64 %[[DEC]]
// CHECK-NOT: br
// CHECK: %[[VAL:.+]] = load i16
Expand Down
Loading