Skip to content

Commit cf6bb54

Browse files
committed
More rigourous detection of notification for condvar
1 parent 83688e2 commit cf6bb54

File tree

2 files changed

+69
-29
lines changed

2 files changed

+69
-29
lines changed

src/sync/condvar.rs

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::pin::Pin;
2-
use std::sync::atomic::{AtomicBool, Ordering};
2+
use std::sync::atomic::{AtomicUsize, Ordering};
3+
use std::sync::Arc;
34
use std::time::Duration;
45

56
use futures_timer::Delay;
@@ -60,8 +61,19 @@ impl WaitTimeoutResult {
6061
/// ```
6162
#[derive(Debug)]
6263
pub struct Condvar {
63-
has_blocked: AtomicBool,
64-
blocked: std::sync::Mutex<Slab<Option<Waker>>>,
64+
blocked: std::sync::Mutex<Slab<WaitEntry>>,
65+
}
66+
67+
/// Flag to mark if the task was notified
68+
const NOTIFIED: usize = 1;
69+
/// State if the task was notified with `notify_once`
70+
/// so it should notify another task if the future is dropped without waking.
71+
const NOTIFIED_ONCE: usize = 0b11;
72+
73+
#[derive(Debug)]
74+
struct WaitEntry {
75+
state: Arc<AtomicUsize>,
76+
waker: Option<Waker>,
6577
}
6678

6779
impl Condvar {
@@ -76,7 +88,6 @@ impl Condvar {
7688
/// ```
7789
pub fn new() -> Self {
7890
Condvar {
79-
has_blocked: AtomicBool::new(false),
8091
blocked: std::sync::Mutex::new(Slab::new()),
8192
}
8293
}
@@ -126,6 +137,7 @@ impl Condvar {
126137
AwaitNotify {
127138
cond: self,
128139
guard: Some(guard),
140+
state: Arc::new(AtomicUsize::new(0)),
129141
key: None,
130142
}
131143
}
@@ -261,14 +273,8 @@ impl Condvar {
261273
/// # }) }
262274
/// ```
263275
pub fn notify_one(&self) {
264-
if self.has_blocked.load(Ordering::Acquire) {
265-
let mut blocked = self.blocked.lock().unwrap();
266-
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
267-
if let Some(w) = opt_waker.take() {
268-
w.wake();
269-
}
270-
}
271-
}
276+
let blocked = self.blocked.lock().unwrap();
277+
notify(blocked, false);
272278
}
273279

274280
/// Wakes up all blocked tasks on this condvar.
@@ -304,12 +310,20 @@ impl Condvar {
304310
/// # }) }
305311
/// ```
306312
pub fn notify_all(&self) {
307-
if self.has_blocked.load(Ordering::Acquire) {
308-
let mut blocked = self.blocked.lock().unwrap();
309-
for (_, opt_waker) in blocked.iter_mut() {
310-
if let Some(w) = opt_waker.take() {
311-
w.wake();
312-
}
313+
let blocked = self.blocked.lock().unwrap();
314+
notify(blocked, true);
315+
}
316+
}
317+
318+
#[inline]
319+
fn notify(mut blocked: std::sync::MutexGuard<'_, Slab<WaitEntry>>, all: bool) {
320+
let state = if all { NOTIFIED } else { NOTIFIED_ONCE };
321+
for (_, entry) in blocked.iter_mut() {
322+
if let Some(w) = entry.waker.take() {
323+
entry.state.store(state, Ordering::Release);
324+
w.wake();
325+
if !all {
326+
return;
313327
}
314328
}
315329
}
@@ -318,6 +332,7 @@ impl Condvar {
318332
struct AwaitNotify<'a, 'b, T> {
319333
cond: &'a Condvar,
320334
guard: Option<MutexGuard<'b, T>>,
335+
state: Arc<AtomicUsize>,
321336
key: Option<usize>,
322337
}
323338

@@ -329,15 +344,21 @@ impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
329344
Some(_) => {
330345
let mut blocked = self.cond.blocked.lock().unwrap();
331346
let w = cx.waker().clone();
332-
self.key = Some(blocked.insert(Some(w)));
347+
self.key = Some(blocked.insert(WaitEntry {
348+
state: self.state.clone(),
349+
waker: Some(w),
350+
}));
333351

334-
if blocked.len() == 1 {
335-
self.cond.has_blocked.store(true, Ordering::Relaxed);
336-
}
337352
// the guard is dropped when we return, which frees the lock
338353
Poll::Pending
339354
}
340-
None => Poll::Ready(()),
355+
None => {
356+
if self.state.fetch_and(!NOTIFIED, Ordering::AcqRel) & NOTIFIED != 0 {
357+
Poll::Ready(())
358+
} else {
359+
Poll::Pending
360+
}
361+
}
341362
}
342363
}
343364
}
@@ -348,8 +369,10 @@ impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> {
348369
let mut blocked = self.cond.blocked.lock().unwrap();
349370
blocked.remove(key);
350371

351-
if blocked.is_empty() {
352-
self.cond.has_blocked.store(false, Ordering::Relaxed);
372+
if !blocked.is_empty() && self.state.load(Ordering::Acquire) == NOTIFIED_ONCE {
373+
// we got a notification form notify_once but didn't handle it,
374+
// so send it to a different task
375+
notify(blocked, false);
353376
}
354377
}
355378
}
@@ -369,12 +392,12 @@ impl<'a, 'b, T> Future for TimeoutWaitFuture<'a, 'b, T> {
369392
type Output = WaitTimeoutResult;
370393

371394
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372-
match self.as_mut().await_notify().poll(cx) {
373-
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(false)),
374-
Poll::Pending => match self.delay().poll(cx) {
375-
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(true)),
395+
match self.as_mut().delay().poll(cx) {
396+
Poll::Pending => match self.await_notify().poll(cx) {
397+
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(false)),
376398
Poll::Pending => Poll::Pending,
377399
},
400+
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(true)),
378401
}
379402
}
380403
}

tests/condvar.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use std::sync::Arc;
2+
use std::time::Duration;
3+
4+
use async_std::sync::{Condvar, Mutex};
5+
use async_std::task;
6+
7+
#[test]
8+
fn wait_timeout() {
9+
task::block_on(async {
10+
let m = Mutex::new(());
11+
let c = Condvar::new();
12+
let (_, wait_result) = c
13+
.wait_timeout(m.lock().await, Duration::from_millis(10))
14+
.await;
15+
assert!(wait_result.timed_out());
16+
})
17+
}

0 commit comments

Comments
 (0)