Skip to content

Commit 36af98a

Browse files
refactor, and fix subtle bug
1 parent 7ec5b1e commit 36af98a

File tree

2 files changed

+140
-141
lines changed

2 files changed

+140
-141
lines changed

src/sync/barrier.rs

Lines changed: 103 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,83 @@
1-
use crate::sync::Mutex;
21
use broadcaster::BroadcastChannel;
32

4-
/// A barrier enables multiple threads to synchronize the beginning of some computation.
3+
use crate::sync::Mutex;
4+
5+
/// A barrier enables multiple tasks to synchronize the beginning
6+
/// of some computation.
57
///
68
/// ```
79
/// # fn main() { async_std::task::block_on(async {
810
/// #
911
/// use std::sync::Arc;
1012
/// use async_std::sync::Barrier;
11-
/// use futures_util::future::join_all;
13+
/// use async_std::task;
1214
///
1315
/// let mut handles = Vec::with_capacity(10);
1416
/// let barrier = Arc::new(Barrier::new(10));
1517
/// for _ in 0..10 {
1618
/// let c = barrier.clone();
1719
/// // The same messages will be printed together.
1820
/// // You will NOT see any interleaving.
19-
/// handles.push(async move {
21+
/// handles.push(task::spawn(async move {
2022
/// println!("before wait");
2123
/// let wr = c.wait().await;
2224
/// println!("after wait");
2325
/// wr
24-
/// });
26+
/// }));
27+
/// }
28+
/// // Wait for the other futures to finish.
29+
/// for handle in handles {
30+
/// handle.await;
2531
/// }
26-
/// // Will not resolve until all "before wait" messages have been printed
27-
/// let wrs = join_all(handles).await;
28-
/// // Exactly one barrier will resolve as the "leader"
29-
/// assert_eq!(wrs.into_iter().filter(|wr| wr.is_leader()).count(), 1);
3032
/// # });
3133
/// # }
3234
/// ```
33-
// #[derive(Debug)]
35+
#[derive(Debug)]
3436
pub struct Barrier {
3537
state: Mutex<BarrierState>,
36-
wait: BroadcastChannel<usize>,
38+
wait: BroadcastChannel<(usize, usize)>,
3739
n: usize,
3840
}
3941

40-
// #[derive(Debug)]
42+
// The inner state of a double barrier
43+
#[derive(Debug)]
4144
struct BarrierState {
42-
waker: BroadcastChannel<usize>,
43-
arrived: usize,
44-
generation: usize,
45+
waker: BroadcastChannel<(usize, usize)>,
46+
count: usize,
47+
generation_id: usize,
4548
}
4649

50+
/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
51+
///
52+
/// [`wait`]: struct.Barrier.html#method.wait
53+
/// [`Barrier`]: struct.Barrier.html
54+
///
55+
/// # Examples
56+
///
57+
/// ```
58+
/// use async_std::sync::Barrier;
59+
///
60+
/// let barrier = Barrier::new(1);
61+
/// let barrier_wait_result = barrier.wait();
62+
/// ```
63+
#[derive(Debug, Clone)]
64+
pub struct BarrierWaitResult(bool);
65+
4766
impl Barrier {
48-
/// Creates a new barrier that can block a given number of threads.
67+
/// Creates a new barrier that can block a given number of tasks.
68+
///
69+
/// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
70+
/// all tasks at once when the `n`th task calls [`wait`].
71+
///
72+
/// [`wait`]: #method.wait
73+
///
74+
/// # Examples
75+
///
76+
/// ```
77+
/// use std::sync::Barrier;
4978
///
50-
/// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all
51-
/// threads at once when the `n`th thread calls `wait`.
79+
/// let barrier = Barrier::new(10);
80+
/// ```
5281
pub fn new(mut n: usize) -> Barrier {
5382
let waker = BroadcastChannel::new();
5483
let wait = waker.clone();
@@ -63,67 +92,82 @@ impl Barrier {
6392
Barrier {
6493
state: Mutex::new(BarrierState {
6594
waker,
66-
arrived: 0,
67-
generation: 1,
95+
count: 0,
96+
generation_id: 1,
6897
}),
6998
n,
7099
wait,
71100
}
72101
}
73102

74-
/// Does not resolve until all tasks have rendezvoused here.
103+
/// Blocks the current task until all tasks have rendezvoused here.
75104
///
76-
/// Barriers are re-usable after all threads have rendezvoused once, and can
105+
/// Barriers are re-usable after all tasks have rendezvoused once, and can
77106
/// be used continuously.
78107
///
79-
/// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
80-
/// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads
81-
/// will receive a result that will return `false` from `is_leader`.
108+
/// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
109+
/// returns `true` from [`is_leader`] when returning from this function, and
110+
/// all other tasks will receive a result that will return `false` from
111+
/// [`is_leader`].
112+
///
113+
/// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
114+
/// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
82115
pub async fn wait(&self) -> BarrierWaitResult {
83-
// NOTE: The implementation in tokio use a sync mutex, decide if we agree with their reasoning or
84-
// not.
85-
let mut state = self.state.lock().await;
86-
let generation = state.generation;
87-
state.arrived += 1;
88-
if state.arrived == self.n {
89-
// we are the leader for this generation
90-
// wake everyone, increment the generation, and return
91-
state
92-
.waker
93-
.send(&state.generation)
94-
.await
95-
.expect("there is at least one receiver");
96-
state.arrived = 0;
97-
state.generation += 1;
98-
return BarrierWaitResult(true);
99-
}
116+
let mut lock = self.state.lock().await;
117+
let local_gen = lock.generation_id;
118+
119+
lock.count += 1;
120+
121+
if lock.count < self.n {
122+
let mut wait = self.wait.clone();
100123

101-
drop(state);
124+
let mut generation_id = lock.generation_id;
125+
let mut count = lock.count;
102126

103-
// we're going to have to wait for the last of the generation to arrive
104-
let mut wait = self.wait.clone();
127+
drop(lock);
105128

106-
loop {
107-
// note that the first time through the loop, this _will_ yield a generation
108-
// immediately, since we cloned a receiver that has never seen any values.
109-
if wait.recv().await.expect("sender hasn't been closed") >= generation {
110-
break;
129+
while local_gen == generation_id && count < self.n {
130+
let (g, c) = wait.recv().await.expect("sender hasn not been closed");
131+
generation_id = g;
132+
count = c;
111133
}
112-
}
113134

114-
BarrierWaitResult(false)
135+
BarrierWaitResult(false)
136+
} else {
137+
lock.count = 0;
138+
lock.generation_id = lock.generation_id.wrapping_add(1);
139+
140+
lock.waker
141+
.send(&(lock.generation_id, lock.count))
142+
.await
143+
.expect("there should be at least one receiver");
144+
145+
BarrierWaitResult(true)
146+
}
115147
}
116148
}
117149

118-
/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
119-
#[derive(Debug, Clone)]
120-
pub struct BarrierWaitResult(bool);
121-
122150
impl BarrierWaitResult {
123-
/// Returns true if this thread from wait is the "leader thread".
151+
/// Returns `true` if this task from [`wait`] is the "leader task".
152+
///
153+
/// Only one task will have `true` returned from their result, all other
154+
/// tasks will have `false` returned.
155+
///
156+
/// [`wait`]: struct.Barrier.html#method.wait
157+
///
158+
/// # Examples
159+
///
160+
/// ```
161+
/// # fn main() { async_std::task::block_on(async {
162+
/// #
163+
/// use async_std::sync::Barrier;
124164
///
125-
/// Only one thread will have `true` returned from their result, all other threads will have
126-
/// `false` returned.
165+
/// let barrier = Barrier::new(1);
166+
/// let barrier_wait_result = barrier.wait().await;
167+
/// println!("{:?}", barrier_wait_result.is_leader());
168+
/// # });
169+
/// # }
170+
/// ```
127171
pub fn is_leader(&self) -> bool {
128172
self.0
129173
}

tests/barrier.rs

Lines changed: 37 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,52 @@
11
use std::sync::Arc;
22

3+
use futures_channel::mpsc::unbounded;
4+
use futures_util::sink::SinkExt;
5+
use futures_util::stream::StreamExt;
6+
37
use async_std::sync::Barrier;
48
use async_std::task;
59

610
#[test]
7-
fn barrier_zero_does_not_block() {
8-
let b = Arc::new(Barrier::new(0));
9-
10-
let c = b.clone();
11-
task::block_on(async move {
12-
let w = task::spawn(async move { c.wait().await });
13-
let wr = w.await;
14-
assert!(wr.is_leader());
15-
});
16-
17-
let c = b.clone();
18-
task::block_on(async move {
19-
let w = task::spawn(async move { c.wait().await });
20-
let wr = w.await;
21-
assert!(wr.is_leader());
22-
});
23-
}
24-
25-
#[test]
26-
fn barrier_single() {
27-
let b = Arc::new(Barrier::new(1));
11+
fn test_barrier() {
12+
// Based on the test in std, I was seeing some race conditions, so running it in a loop to make sure
13+
// things are solid.
2814

29-
let c = b.clone();
30-
task::block_on(async move {
31-
let w = task::spawn(async move { c.wait().await });
32-
let wr = w.await;
33-
assert!(wr.is_leader());
34-
});
35-
let c = b.clone();
36-
task::block_on(async move {
37-
let w = task::spawn(async move { c.wait().await });
38-
let wr = w.await;
39-
assert!(wr.is_leader());
40-
});
41-
let c = b.clone();
42-
task::block_on(async move {
43-
let w = task::spawn(async move { c.wait().await });
44-
let wr = w.await;
45-
assert!(wr.is_leader());
46-
});
47-
}
15+
for _ in 0..1_000 {
16+
task::block_on(async move {
17+
const N: usize = 10;
4818

49-
#[test]
50-
fn barrier_tango() {
51-
let b = Arc::new(Barrier::new(2));
19+
let barrier = Arc::new(Barrier::new(N));
20+
let (tx, mut rx) = unbounded();
5221

53-
task::block_on(async move {
54-
let c = b.clone();
55-
let w1 = task::spawn(async move { c.wait().await });
22+
for _ in 0..N - 1 {
23+
let c = barrier.clone();
24+
let mut tx = tx.clone();
25+
task::spawn(async move {
26+
let res = c.wait().await;
5627

57-
let c = b.clone();
58-
let w2 = task::spawn(async move { c.wait().await });
59-
60-
let ws = futures_util::future::join_all(vec![w1, w2]).await;
61-
let wr1 = &ws[0];
62-
let wr2 = &ws[1];
63-
64-
assert!(wr1.is_leader() || wr2.is_leader());
65-
assert!(!(wr1.is_leader() && wr2.is_leader()));
66-
});
67-
}
68-
69-
#[test]
70-
fn barrier_lots() {
71-
let b = Arc::new(Barrier::new(100));
72-
73-
task::block_on(async move {
74-
for _ in 0..10 {
75-
let mut wait = Vec::new();
76-
for _ in 0..99 {
77-
let c = b.clone();
78-
let w = task::spawn(async move { c.wait().await });
79-
wait.push(w);
28+
tx.send(res.is_leader()).await.unwrap();
29+
});
8030
}
8131

82-
// pass the barrier
83-
let c = b.clone();
84-
let w = task::spawn(async move { c.wait().await });
85-
86-
let mut found_leader = w.await.is_leader();
87-
for w in wait {
88-
let wr = w.await;
89-
if wr.is_leader() {
90-
assert!(!found_leader);
91-
found_leader = true;
32+
// At this point, all spawned threads should be blocked,
33+
// so we shouldn't get anything from the port
34+
let res = rx.try_next();
35+
assert!(match res {
36+
Err(_err) => true,
37+
_ => false,
38+
});
39+
40+
let mut leader_found = barrier.wait().await.is_leader();
41+
42+
// Now, the barrier is cleared and we should get data.
43+
for _ in 0..N - 1 {
44+
if rx.next().await.unwrap() {
45+
assert!(!leader_found);
46+
leader_found = true;
9247
}
9348
}
94-
assert!(found_leader);
95-
}
96-
});
49+
assert!(leader_found);
50+
});
51+
}
9752
}

0 commit comments

Comments
 (0)