1
- use crate :: sync:: Mutex ;
2
1
use broadcaster:: BroadcastChannel ;
3
2
4
- /// A barrier enables multiple threads to synchronize the beginning of some computation.
3
+ use crate :: sync:: Mutex ;
4
+
5
+ /// A barrier enables multiple tasks to synchronize the beginning
6
+ /// of some computation.
5
7
///
6
8
/// ```
7
9
/// # fn main() { async_std::task::block_on(async {
8
10
/// #
9
11
/// use std::sync::Arc;
10
12
/// use async_std::sync::Barrier;
11
- /// use futures_util::future::join_all ;
13
+ /// use async_std::task ;
12
14
///
13
15
/// let mut handles = Vec::with_capacity(10);
14
16
/// let barrier = Arc::new(Barrier::new(10));
15
17
/// for _ in 0..10 {
16
18
/// let c = barrier.clone();
17
19
/// // The same messages will be printed together.
18
20
/// // You will NOT see any interleaving.
19
- /// handles.push(async move {
21
+ /// handles.push(task::spawn( async move {
20
22
/// println!("before wait");
21
23
/// let wr = c.wait().await;
22
24
/// println!("after wait");
23
25
/// wr
24
- /// });
26
+ /// }));
27
+ /// }
28
+ /// // Wait for the other futures to finish.
29
+ /// for handle in handles {
30
+ /// handle.await;
25
31
/// }
26
- /// // Will not resolve until all "before wait" messages have been printed
27
- /// let wrs = join_all(handles).await;
28
- /// // Exactly one barrier will resolve as the "leader"
29
- /// assert_eq!(wrs.into_iter().filter(|wr| wr.is_leader()).count(), 1);
30
32
/// # });
31
33
/// # }
32
34
/// ```
33
- // #[derive(Debug)]
35
+ #[ derive( Debug ) ]
34
36
pub struct Barrier {
35
37
state : Mutex < BarrierState > ,
36
- wait : BroadcastChannel < usize > ,
38
+ wait : BroadcastChannel < ( usize , usize ) > ,
37
39
n : usize ,
38
40
}
39
41
40
- // #[derive(Debug)]
42
+ // The inner state of a double barrier
43
+ #[ derive( Debug ) ]
41
44
struct BarrierState {
42
- waker : BroadcastChannel < usize > ,
43
- arrived : usize ,
44
- generation : usize ,
45
+ waker : BroadcastChannel < ( usize , usize ) > ,
46
+ count : usize ,
47
+ generation_id : usize ,
45
48
}
46
49
50
+ /// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
51
+ ///
52
+ /// [`wait`]: struct.Barrier.html#method.wait
53
+ /// [`Barrier`]: struct.Barrier.html
54
+ ///
55
+ /// # Examples
56
+ ///
57
+ /// ```
58
+ /// use async_std::sync::Barrier;
59
+ ///
60
+ /// let barrier = Barrier::new(1);
61
+ /// let barrier_wait_result = barrier.wait();
62
+ /// ```
63
+ #[ derive( Debug , Clone ) ]
64
+ pub struct BarrierWaitResult ( bool ) ;
65
+
47
66
impl Barrier {
48
- /// Creates a new barrier that can block a given number of threads.
67
+ /// Creates a new barrier that can block a given number of tasks.
68
+ ///
69
+ /// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
70
+ /// all tasks at once when the `n`th task calls [`wait`].
71
+ ///
72
+ /// [`wait`]: #method.wait
73
+ ///
74
+ /// # Examples
75
+ ///
76
+ /// ```
77
+ /// use std::sync::Barrier;
49
78
///
50
- /// A barrier will block `n`-1 threads which call [` Barrier::wait`] and then wake up all
51
- /// threads at once when the `n`th thread calls `wait`.
79
+ /// let barrier = Barrier::new(10);
80
+ /// ```
52
81
pub fn new ( mut n : usize ) -> Barrier {
53
82
let waker = BroadcastChannel :: new ( ) ;
54
83
let wait = waker. clone ( ) ;
@@ -63,67 +92,82 @@ impl Barrier {
63
92
Barrier {
64
93
state : Mutex :: new ( BarrierState {
65
94
waker,
66
- arrived : 0 ,
67
- generation : 1 ,
95
+ count : 0 ,
96
+ generation_id : 1 ,
68
97
} ) ,
69
98
n,
70
99
wait,
71
100
}
72
101
}
73
102
74
- /// Does not resolve until all tasks have rendezvoused here.
103
+ /// Blocks the current task until all tasks have rendezvoused here.
75
104
///
76
- /// Barriers are re-usable after all threads have rendezvoused once, and can
105
+ /// Barriers are re-usable after all tasks have rendezvoused once, and can
77
106
/// be used continuously.
78
107
///
79
- /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
80
- /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads
81
- /// will receive a result that will return `false` from `is_leader`.
108
+ /// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
109
+ /// returns `true` from [`is_leader`] when returning from this function, and
110
+ /// all other tasks will receive a result that will return `false` from
111
+ /// [`is_leader`].
112
+ ///
113
+ /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
114
+ /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
82
115
pub async fn wait ( & self ) -> BarrierWaitResult {
83
- // NOTE: The implementation in tokio use a sync mutex, decide if we agree with their reasoning or
84
- // not.
85
- let mut state = self . state . lock ( ) . await ;
86
- let generation = state. generation ;
87
- state. arrived += 1 ;
88
- if state. arrived == self . n {
89
- // we are the leader for this generation
90
- // wake everyone, increment the generation, and return
91
- state
92
- . waker
93
- . send ( & state. generation )
94
- . await
95
- . expect ( "there is at least one receiver" ) ;
96
- state. arrived = 0 ;
97
- state. generation += 1 ;
98
- return BarrierWaitResult ( true ) ;
99
- }
116
+ let mut lock = self . state . lock ( ) . await ;
117
+ let local_gen = lock. generation_id ;
118
+
119
+ lock. count += 1 ;
120
+
121
+ if lock. count < self . n {
122
+ let mut wait = self . wait . clone ( ) ;
100
123
101
- drop ( state) ;
124
+ let mut generation_id = lock. generation_id ;
125
+ let mut count = lock. count ;
102
126
103
- // we're going to have to wait for the last of the generation to arrive
104
- let mut wait = self . wait . clone ( ) ;
127
+ drop ( lock) ;
105
128
106
- loop {
107
- // note that the first time through the loop, this _will_ yield a generation
108
- // immediately, since we cloned a receiver that has never seen any values.
109
- if wait. recv ( ) . await . expect ( "sender hasn't been closed" ) >= generation {
110
- break ;
129
+ while local_gen == generation_id && count < self . n {
130
+ let ( g, c) = wait. recv ( ) . await . expect ( "sender hasn not been closed" ) ;
131
+ generation_id = g;
132
+ count = c;
111
133
}
112
- }
113
134
114
- BarrierWaitResult ( false )
135
+ BarrierWaitResult ( false )
136
+ } else {
137
+ lock. count = 0 ;
138
+ lock. generation_id = lock. generation_id . wrapping_add ( 1 ) ;
139
+
140
+ lock. waker
141
+ . send ( & ( lock. generation_id , lock. count ) )
142
+ . await
143
+ . expect ( "there should be at least one receiver" ) ;
144
+
145
+ BarrierWaitResult ( true )
146
+ }
115
147
}
116
148
}
117
149
118
- /// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
119
- #[ derive( Debug , Clone ) ]
120
- pub struct BarrierWaitResult ( bool ) ;
121
-
122
150
impl BarrierWaitResult {
123
- /// Returns true if this thread from wait is the "leader thread".
151
+ /// Returns `true` if this task from [`wait`] is the "leader task".
152
+ ///
153
+ /// Only one task will have `true` returned from their result, all other
154
+ /// tasks will have `false` returned.
155
+ ///
156
+ /// [`wait`]: struct.Barrier.html#method.wait
157
+ ///
158
+ /// # Examples
159
+ ///
160
+ /// ```
161
+ /// # fn main() { async_std::task::block_on(async {
162
+ /// #
163
+ /// use async_std::sync::Barrier;
124
164
///
125
- /// Only one thread will have `true` returned from their result, all other threads will have
126
- /// `false` returned.
165
+ /// let barrier = Barrier::new(1);
166
+ /// let barrier_wait_result = barrier.wait().await;
167
+ /// println!("{:?}", barrier_wait_result.is_leader());
168
+ /// # });
169
+ /// # }
170
+ /// ```
127
171
pub fn is_leader ( & self ) -> bool {
128
172
self . 0
129
173
}
0 commit comments