1
1
use std:: pin:: Pin ;
2
- use std:: sync:: atomic:: { AtomicBool , Ordering } ;
2
+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
3
+ use std:: sync:: Arc ;
3
4
use std:: time:: Duration ;
4
5
5
6
use futures_timer:: Delay ;
@@ -60,8 +61,19 @@ impl WaitTimeoutResult {
60
61
/// ```
61
62
#[ derive( Debug ) ]
62
63
pub struct Condvar {
63
- has_blocked : AtomicBool ,
64
- blocked : std:: sync:: Mutex < Slab < Option < Waker > > > ,
64
+ blocked : std:: sync:: Mutex < Slab < WaitEntry > > ,
65
+ }
66
+
67
+ /// Flag to mark if the task was notified
68
+ const NOTIFIED : usize = 1 ;
69
+ /// State if the task was notified with `notify_once`
70
+ /// so it should notify another task if the future is dropped without waking.
71
+ const NOTIFIED_ONCE : usize = 0b11 ;
72
+
73
+ #[ derive( Debug ) ]
74
+ struct WaitEntry {
75
+ state : Arc < AtomicUsize > ,
76
+ waker : Option < Waker > ,
65
77
}
66
78
67
79
impl Condvar {
@@ -76,7 +88,6 @@ impl Condvar {
76
88
/// ```
77
89
pub fn new ( ) -> Self {
78
90
Condvar {
79
- has_blocked : AtomicBool :: new ( false ) ,
80
91
blocked : std:: sync:: Mutex :: new ( Slab :: new ( ) ) ,
81
92
}
82
93
}
@@ -126,6 +137,7 @@ impl Condvar {
126
137
AwaitNotify {
127
138
cond : self ,
128
139
guard : Some ( guard) ,
140
+ state : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
129
141
key : None ,
130
142
}
131
143
}
@@ -261,14 +273,8 @@ impl Condvar {
261
273
/// # }) }
262
274
/// ```
263
275
pub fn notify_one ( & self ) {
264
- if self . has_blocked . load ( Ordering :: Acquire ) {
265
- let mut blocked = self . blocked . lock ( ) . unwrap ( ) ;
266
- if let Some ( ( _, opt_waker) ) = blocked. iter_mut ( ) . next ( ) {
267
- if let Some ( w) = opt_waker. take ( ) {
268
- w. wake ( ) ;
269
- }
270
- }
271
- }
276
+ let blocked = self . blocked . lock ( ) . unwrap ( ) ;
277
+ notify ( blocked, false ) ;
272
278
}
273
279
274
280
/// Wakes up all blocked tasks on this condvar.
@@ -304,12 +310,20 @@ impl Condvar {
304
310
/// # }) }
305
311
/// ```
306
312
pub fn notify_all ( & self ) {
307
- if self . has_blocked . load ( Ordering :: Acquire ) {
308
- let mut blocked = self . blocked . lock ( ) . unwrap ( ) ;
309
- for ( _, opt_waker) in blocked. iter_mut ( ) {
310
- if let Some ( w) = opt_waker. take ( ) {
311
- w. wake ( ) ;
312
- }
313
+ let blocked = self . blocked . lock ( ) . unwrap ( ) ;
314
+ notify ( blocked, true ) ;
315
+ }
316
+ }
317
+
318
+ #[ inline]
319
+ fn notify ( mut blocked : std:: sync:: MutexGuard < ' _ , Slab < WaitEntry > > , all : bool ) {
320
+ let state = if all { NOTIFIED } else { NOTIFIED_ONCE } ;
321
+ for ( _, entry) in blocked. iter_mut ( ) {
322
+ if let Some ( w) = entry. waker . take ( ) {
323
+ entry. state . store ( state, Ordering :: Release ) ;
324
+ w. wake ( ) ;
325
+ if !all {
326
+ return ;
313
327
}
314
328
}
315
329
}
@@ -318,6 +332,7 @@ impl Condvar {
318
332
struct AwaitNotify < ' a , ' b , T > {
319
333
cond : & ' a Condvar ,
320
334
guard : Option < MutexGuard < ' b , T > > ,
335
+ state : Arc < AtomicUsize > ,
321
336
key : Option < usize > ,
322
337
}
323
338
@@ -329,15 +344,21 @@ impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
329
344
Some ( _) => {
330
345
let mut blocked = self . cond . blocked . lock ( ) . unwrap ( ) ;
331
346
let w = cx. waker ( ) . clone ( ) ;
332
- self . key = Some ( blocked. insert ( Some ( w) ) ) ;
347
+ self . key = Some ( blocked. insert ( WaitEntry {
348
+ state : self . state . clone ( ) ,
349
+ waker : Some ( w) ,
350
+ } ) ) ;
333
351
334
- if blocked. len ( ) == 1 {
335
- self . cond . has_blocked . store ( true , Ordering :: Relaxed ) ;
336
- }
337
352
// the guard is dropped when we return, which frees the lock
338
353
Poll :: Pending
339
354
}
340
- None => Poll :: Ready ( ( ) ) ,
355
+ None => {
356
+ if self . state . fetch_and ( !NOTIFIED , Ordering :: AcqRel ) & NOTIFIED != 0 {
357
+ Poll :: Ready ( ( ) )
358
+ } else {
359
+ Poll :: Pending
360
+ }
361
+ }
341
362
}
342
363
}
343
364
}
@@ -348,8 +369,10 @@ impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> {
348
369
let mut blocked = self . cond . blocked . lock ( ) . unwrap ( ) ;
349
370
blocked. remove ( key) ;
350
371
351
- if blocked. is_empty ( ) {
352
- self . cond . has_blocked . store ( false , Ordering :: Relaxed ) ;
372
+ if !blocked. is_empty ( ) && self . state . load ( Ordering :: Acquire ) == NOTIFIED_ONCE {
373
+ // we got a notification form notify_once but didn't handle it,
374
+ // so send it to a different task
375
+ notify ( blocked, false ) ;
353
376
}
354
377
}
355
378
}
@@ -369,12 +392,12 @@ impl<'a, 'b, T> Future for TimeoutWaitFuture<'a, 'b, T> {
369
392
type Output = WaitTimeoutResult ;
370
393
371
394
fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
372
- match self . as_mut ( ) . await_notify ( ) . poll ( cx) {
373
- Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( false ) ) ,
374
- Poll :: Pending => match self . delay ( ) . poll ( cx) {
375
- Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( true ) ) ,
395
+ match self . as_mut ( ) . delay ( ) . poll ( cx) {
396
+ Poll :: Pending => match self . await_notify ( ) . poll ( cx) {
397
+ Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( false ) ) ,
376
398
Poll :: Pending => Poll :: Pending ,
377
399
} ,
400
+ Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( true ) ) ,
378
401
}
379
402
}
380
403
}
0 commit comments