Skip to content

Commit b43809a

Browse files
committed
Use state of Waker instead of AtomicUsize to keep track of if task was
notified.
1 parent cf6bb54 commit b43809a

File tree

2 files changed

+26
-36
lines changed

2 files changed

+26
-36
lines changed

src/sync/condvar.rs

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
use std::pin::Pin;
2-
use std::sync::atomic::{AtomicUsize, Ordering};
3-
use std::sync::Arc;
42
use std::time::Duration;
53

64
use futures_timer::Delay;
@@ -61,19 +59,7 @@ impl WaitTimeoutResult {
6159
/// ```
6260
#[derive(Debug)]
6361
pub struct Condvar {
64-
blocked: std::sync::Mutex<Slab<WaitEntry>>,
65-
}
66-
67-
/// Flag to mark if the task was notified
68-
const NOTIFIED: usize = 1;
69-
/// State if the task was notified with `notify_once`
70-
/// so it should notify another task if the future is dropped without waking.
71-
const NOTIFIED_ONCE: usize = 0b11;
72-
73-
#[derive(Debug)]
74-
struct WaitEntry {
75-
state: Arc<AtomicUsize>,
76-
waker: Option<Waker>,
62+
blocked: std::sync::Mutex<Slab<Option<Waker>>>,
7763
}
7864

7965
impl Condvar {
@@ -137,8 +123,8 @@ impl Condvar {
137123
AwaitNotify {
138124
cond: self,
139125
guard: Some(guard),
140-
state: Arc::new(AtomicUsize::new(0)),
141126
key: None,
127+
notified: false,
142128
}
143129
}
144130

@@ -316,11 +302,9 @@ impl Condvar {
316302
}
317303

318304
#[inline]
319-
fn notify(mut blocked: std::sync::MutexGuard<'_, Slab<WaitEntry>>, all: bool) {
320-
let state = if all { NOTIFIED } else { NOTIFIED_ONCE };
305+
fn notify(mut blocked: std::sync::MutexGuard<'_, Slab<Option<Waker>>>, all: bool) {
321306
for (_, entry) in blocked.iter_mut() {
322-
if let Some(w) = entry.waker.take() {
323-
entry.state.store(state, Ordering::Release);
307+
if let Some(w) = entry.take() {
324308
w.wake();
325309
if !all {
326310
return;
@@ -332,8 +316,8 @@ fn notify(mut blocked: std::sync::MutexGuard<'_, Slab<WaitEntry>>, all: bool) {
332316
struct AwaitNotify<'a, 'b, T> {
333317
cond: &'a Condvar,
334318
guard: Option<MutexGuard<'b, T>>,
335-
state: Arc<AtomicUsize>,
336319
key: Option<usize>,
320+
notified: bool,
337321
}
338322

339323
impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
@@ -344,20 +328,14 @@ impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
344328
Some(_) => {
345329
let mut blocked = self.cond.blocked.lock().unwrap();
346330
let w = cx.waker().clone();
347-
self.key = Some(blocked.insert(WaitEntry {
348-
state: self.state.clone(),
349-
waker: Some(w),
350-
}));
331+
self.key = Some(blocked.insert(Some(w)));
351332

352333
// the guard is dropped when we return, which frees the lock
353334
Poll::Pending
354335
}
355336
None => {
356-
if self.state.fetch_and(!NOTIFIED, Ordering::AcqRel) & NOTIFIED != 0 {
357-
Poll::Ready(())
358-
} else {
359-
Poll::Pending
360-
}
337+
self.notified = true;
338+
Poll::Ready(())
361339
}
362340
}
363341
}
@@ -367,12 +345,14 @@ impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> {
367345
fn drop(&mut self) {
368346
if let Some(key) = self.key {
369347
let mut blocked = self.cond.blocked.lock().unwrap();
370-
blocked.remove(key);
348+
let opt_waker = blocked.remove(key);
371349

372-
if !blocked.is_empty() && self.state.load(Ordering::Acquire) == NOTIFIED_ONCE {
373-
// we got a notification form notify_once but didn't handle it,
374-
// so send it to a different task
350+
if opt_waker.is_none() && !self.notified {
351+
// wake up the next task, because this task was notified, but
352+
// we are dropping it before it can finished.
353+
// This may result in a spurious wake-up, but that's ok.
375354
notify(blocked, false);
355+
376356
}
377357
}
378358
}

tests/condvar.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,21 @@ use async_std::task;
77
#[test]
88
fn wait_timeout() {
99
task::block_on(async {
10-
let m = Mutex::new(());
11-
let c = Condvar::new();
10+
let pair = Arc::new((Mutex::new(false), Condvar::new()));
11+
let pair2 = pair.clone();
12+
13+
task::spawn(async move {
14+
let (m, c) = &*pair2;
15+
let _g = m.lock().await;
16+
task::sleep(Duration::from_millis(20)).await;
17+
c.notify_one();
18+
});
19+
20+
let (m, c) = &*pair;
1221
let (_, wait_result) = c
1322
.wait_timeout(m.lock().await, Duration::from_millis(10))
1423
.await;
1524
assert!(wait_result.timed_out());
1625
})
1726
}
27+

0 commit comments

Comments
 (0)