Skip to content

Commit c8d9389

Browse files
committed
100% safe implementation of RepeatN
1 parent fc6bfe0 commit c8d9389

File tree

2 files changed

+45
-119
lines changed

2 files changed

+45
-119
lines changed
Lines changed: 43 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use crate::fmt;
22
use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
3-
use crate::mem::MaybeUninit;
43
use crate::num::NonZero;
5-
use crate::ops::{NeverShortCircuit, Try};
4+
use crate::ops::Try;
65

76
/// Creates a new iterator that repeats a single element a given number of times.
87
///
@@ -58,78 +57,48 @@ use crate::ops::{NeverShortCircuit, Try};
5857
#[inline]
5958
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
6059
pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
61-
let element = if count == 0 {
62-
// `element` gets dropped eagerly.
63-
MaybeUninit::uninit()
64-
} else {
65-
MaybeUninit::new(element)
66-
};
60+
RepeatN { inner: RepeatNInner::new(element, count) }
61+
}
6762

68-
RepeatN { element, count }
63+
#[derive(Clone, Copy)]
64+
struct RepeatNInner<T> {
65+
count: NonZero<usize>,
66+
element: T,
67+
}
68+
69+
impl<T> RepeatNInner<T> {
70+
fn new(element: T, count: usize) -> Option<Self> {
71+
let count = NonZero::<usize>::new(count)?;
72+
Some(Self { element, count })
73+
}
6974
}
7075

7176
/// An iterator that repeats an element an exact number of times.
7277
///
7378
/// This `struct` is created by the [`repeat_n()`] function.
7479
/// See its documentation for more.
7580
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
81+
#[derive(Clone)]
7682
pub struct RepeatN<A> {
77-
count: usize,
78-
// Invariant: uninit iff count == 0.
79-
element: MaybeUninit<A>,
83+
inner: Option<RepeatNInner<A>>,
8084
}
8185

8286
impl<A> RepeatN<A> {
83-
/// Returns the element if it hasn't been dropped already.
84-
fn element_ref(&self) -> Option<&A> {
85-
if self.count > 0 {
86-
// SAFETY: The count is non-zero, so it must be initialized.
87-
Some(unsafe { self.element.assume_init_ref() })
88-
} else {
89-
None
90-
}
91-
}
9287
/// If we haven't already dropped the element, return it in an option.
93-
///
94-
/// Clears the count so it won't be dropped again later.
9588
#[inline]
9689
fn take_element(&mut self) -> Option<A> {
97-
if self.count > 0 {
98-
self.count = 0;
99-
// SAFETY: We just set count to zero so it won't be dropped again,
100-
// and it used to be non-zero so it hasn't already been dropped.
101-
let element = unsafe { self.element.assume_init_read() };
102-
Some(element)
103-
} else {
104-
None
105-
}
106-
}
107-
}
108-
109-
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
110-
impl<A: Clone> Clone for RepeatN<A> {
111-
fn clone(&self) -> RepeatN<A> {
112-
RepeatN {
113-
count: self.count,
114-
element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
115-
}
90+
self.inner.take().map(|inner| inner.element)
11691
}
11792
}
11893

11994
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
12095
impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
12196
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122-
f.debug_struct("RepeatN")
123-
.field("count", &self.count)
124-
.field("element", &self.element_ref())
125-
.finish()
126-
}
127-
}
128-
129-
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
130-
impl<A> Drop for RepeatN<A> {
131-
fn drop(&mut self) {
132-
self.take_element();
97+
let (count, element) = match self.inner.as_ref() {
98+
Some(inner) => (inner.count.get(), Some(&inner.element)),
99+
None => (0, None),
100+
};
101+
f.debug_struct("RepeatN").field("count", &count).field("element", &element).finish()
133102
}
134103
}
135104

@@ -139,12 +108,17 @@ impl<A: Clone> Iterator for RepeatN<A> {
139108

140109
#[inline]
141110
fn next(&mut self) -> Option<A> {
142-
if self.count > 0 {
143-
// SAFETY: Just checked it's not empty
144-
unsafe { Some(self.next_unchecked()) }
145-
} else {
146-
None
111+
let inner = self.inner.as_mut()?;
112+
let count = inner.count.get();
113+
114+
if let Some(decremented) = NonZero::<usize>::new(count - 1) {
115+
// Order of these is important for optimization
116+
let tmp = inner.element.clone();
117+
inner.count = decremented;
118+
return Some(tmp);
147119
}
120+
121+
return self.take_element();
148122
}
149123

150124
#[inline]
@@ -155,52 +129,19 @@ impl<A: Clone> Iterator for RepeatN<A> {
155129

156130
#[inline]
157131
fn advance_by(&mut self, skip: usize) -> Result<(), NonZero<usize>> {
158-
let len = self.count;
132+
let Some(inner) = self.inner.as_mut() else {
133+
return NonZero::<usize>::new(skip).map(Err).unwrap_or(Ok(()));
134+
};
159135

160-
if skip >= len {
161-
self.take_element();
162-
}
136+
let len = inner.count.get();
163137

164-
if skip > len {
165-
// SAFETY: we just checked that the difference is positive
166-
Err(unsafe { NonZero::new_unchecked(skip - len) })
167-
} else {
168-
self.count = len - skip;
169-
Ok(())
138+
if let Some(new_len) = len.checked_sub(skip).and_then(NonZero::<usize>::new) {
139+
inner.count = new_len;
140+
return Ok(());
170141
}
171-
}
172142

173-
fn try_fold<B, F, R>(&mut self, mut acc: B, mut f: F) -> R
174-
where
175-
F: FnMut(B, A) -> R,
176-
R: Try<Output = B>,
177-
{
178-
if self.count > 0 {
179-
while self.count > 1 {
180-
self.count -= 1;
181-
// SAFETY: the count was larger than 1, so the element is
182-
// initialized and hasn't been dropped.
183-
acc = f(acc, unsafe { self.element.assume_init_ref().clone() })?;
184-
}
185-
186-
// We could just set the count to zero directly, but doing it this
187-
// way should make it easier for the optimizer to fold this tail
188-
// into the loop when `clone()` is equivalent to copying.
189-
self.count -= 1;
190-
// SAFETY: we just set the count to zero from one, so the element
191-
// is still initialized, has not been dropped yet and will not be
192-
// accessed by future calls.
193-
f(acc, unsafe { self.element.assume_init_read() })
194-
} else {
195-
try { acc }
196-
}
197-
}
198-
199-
fn fold<B, F>(mut self, init: B, f: F) -> B
200-
where
201-
F: FnMut(B, A) -> B,
202-
{
203-
self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
143+
self.inner = None;
144+
return NonZero::<usize>::new(skip - len).map(Err).unwrap_or(Ok(()));
204145
}
205146

206147
#[inline]
@@ -217,7 +158,7 @@ impl<A: Clone> Iterator for RepeatN<A> {
217158
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
218159
impl<A: Clone> ExactSizeIterator for RepeatN<A> {
219160
fn len(&self) -> usize {
220-
self.count
161+
self.inner.as_ref().map(|inner| inner.count.get()).unwrap_or(0)
221162
}
222163
}
223164

@@ -262,20 +203,4 @@ impl<A: Clone> FusedIterator for RepeatN<A> {}
262203
#[unstable(feature = "trusted_len", issue = "37572")]
263204
unsafe impl<A: Clone> TrustedLen for RepeatN<A> {}
264205
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
265-
impl<A: Clone> UncheckedIterator for RepeatN<A> {
266-
#[inline]
267-
unsafe fn next_unchecked(&mut self) -> Self::Item {
268-
// SAFETY: The caller promised the iterator isn't empty
269-
self.count = unsafe { self.count.unchecked_sub(1) };
270-
if self.count == 0 {
271-
// SAFETY: the check above ensured that the count used to be non-zero,
272-
// so element hasn't been dropped yet, and we just lowered the count to
273-
// zero so it won't be dropped later, and thus it's okay to take it here.
274-
unsafe { self.element.assume_init_read() }
275-
} else {
276-
// SAFETY: the count is non-zero, so it must have not been dropped yet.
277-
let element = unsafe { self.element.assume_init_ref() };
278-
A::clone(element)
279-
}
280-
}
281-
}
206+
impl<A: Clone> UncheckedIterator for RepeatN<A> {}

tests/codegen/iter-repeat-n-trivial-drop.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//@ compile-flags: -C opt-level=3
22
//@ only-x86_64
3+
//@ needs-deterministic-layouts
34

45
#![crate_type = "lib"]
56
#![feature(iter_repeat_n)]
@@ -25,7 +26,7 @@ pub fn iter_repeat_n_next(it: &mut std::iter::RepeatN<NotCopy>) -> Option<NotCop
2526
// CHECK-NEXT: br i1 %[[COUNT_ZERO]], label %[[EMPTY:.+]], label %[[NOT_EMPTY:.+]]
2627

2728
// CHECK: [[NOT_EMPTY]]:
28-
// CHECK-NEXT: %[[DEC:.+]] = add i64 %[[COUNT]], -1
29+
// CHECK: %[[DEC:.+]] = add i64 %[[COUNT]], -1
2930
// CHECK-NEXT: store i64 %[[DEC]]
3031
// CHECK-NOT: br
3132
// CHECK: %[[VAL:.+]] = load i16

0 commit comments

Comments
 (0)