Skip to content

Commit 7ec5b1e

Browse files
feat: implement sync::Barrier
Based on the implementation in tokio-rs/tokio#1571
1 parent 785371c commit 7ec5b1e

File tree

4 files changed

+230
-0
lines changed

4 files changed

+230
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ num_cpus = "1.10.1"
4242
pin-utils = "0.1.0-alpha.4"
4343
slab = "0.4.2"
4444
kv-log-macro = "1.0.4"
45+
broadcaster = "0.2.4"
4546

4647
[dev-dependencies]
4748
femme = "1.2.0"

src/sync/barrier.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use crate::sync::Mutex;
2+
use broadcaster::BroadcastChannel;
3+
4+
/// A barrier enables multiple threads to synchronize the beginning of some computation.
5+
///
6+
/// ```
7+
/// # fn main() { async_std::task::block_on(async {
8+
/// #
9+
/// use std::sync::Arc;
10+
/// use async_std::sync::Barrier;
11+
/// use futures_util::future::join_all;
12+
///
13+
/// let mut handles = Vec::with_capacity(10);
14+
/// let barrier = Arc::new(Barrier::new(10));
15+
/// for _ in 0..10 {
16+
/// let c = barrier.clone();
17+
/// // The same messages will be printed together.
18+
/// // You will NOT see any interleaving.
19+
/// handles.push(async move {
20+
/// println!("before wait");
21+
/// let wr = c.wait().await;
22+
/// println!("after wait");
23+
/// wr
24+
/// });
25+
/// }
26+
/// // Will not resolve until all "before wait" messages have been printed
27+
/// let wrs = join_all(handles).await;
28+
/// // Exactly one barrier will resolve as the "leader"
29+
/// assert_eq!(wrs.into_iter().filter(|wr| wr.is_leader()).count(), 1);
30+
/// # });
31+
/// # }
32+
/// ```
33+
// #[derive(Debug)]
34+
pub struct Barrier {
35+
state: Mutex<BarrierState>,
36+
wait: BroadcastChannel<usize>,
37+
n: usize,
38+
}
39+
40+
// #[derive(Debug)]
41+
struct BarrierState {
42+
waker: BroadcastChannel<usize>,
43+
arrived: usize,
44+
generation: usize,
45+
}
46+
47+
impl Barrier {
48+
/// Creates a new barrier that can block a given number of threads.
49+
///
50+
/// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all
51+
/// threads at once when the `n`th thread calls `wait`.
52+
pub fn new(mut n: usize) -> Barrier {
53+
let waker = BroadcastChannel::new();
54+
let wait = waker.clone();
55+
56+
if n == 0 {
57+
// if n is 0, it's not clear what behavior the user wants.
58+
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
59+
// .wait() immediately unblocks, so we adopt that here as well.
60+
n = 1;
61+
}
62+
63+
Barrier {
64+
state: Mutex::new(BarrierState {
65+
waker,
66+
arrived: 0,
67+
generation: 1,
68+
}),
69+
n,
70+
wait,
71+
}
72+
}
73+
74+
/// Does not resolve until all tasks have rendezvoused here.
75+
///
76+
/// Barriers are re-usable after all threads have rendezvoused once, and can
77+
/// be used continuously.
78+
///
79+
/// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
80+
/// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads
81+
/// will receive a result that will return `false` from `is_leader`.
82+
pub async fn wait(&self) -> BarrierWaitResult {
83+
// NOTE: The implementation in tokio use a sync mutex, decide if we agree with their reasoning or
84+
// not.
85+
let mut state = self.state.lock().await;
86+
let generation = state.generation;
87+
state.arrived += 1;
88+
if state.arrived == self.n {
89+
// we are the leader for this generation
90+
// wake everyone, increment the generation, and return
91+
state
92+
.waker
93+
.send(&state.generation)
94+
.await
95+
.expect("there is at least one receiver");
96+
state.arrived = 0;
97+
state.generation += 1;
98+
return BarrierWaitResult(true);
99+
}
100+
101+
drop(state);
102+
103+
// we're going to have to wait for the last of the generation to arrive
104+
let mut wait = self.wait.clone();
105+
106+
loop {
107+
// note that the first time through the loop, this _will_ yield a generation
108+
// immediately, since we cloned a receiver that has never seen any values.
109+
if wait.recv().await.expect("sender hasn't been closed") >= generation {
110+
break;
111+
}
112+
}
113+
114+
BarrierWaitResult(false)
115+
}
116+
}
117+
118+
/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
119+
#[derive(Debug, Clone)]
120+
pub struct BarrierWaitResult(bool);
121+
122+
impl BarrierWaitResult {
123+
/// Returns true if this thread from wait is the "leader thread".
124+
///
125+
/// Only one thread will have `true` returned from their result, all other threads will have
126+
/// `false` returned.
127+
pub fn is_leader(&self) -> bool {
128+
self.0
129+
}
130+
}

src/sync/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232
#[doc(inline)]
3333
pub use std::sync::{Arc, Weak};
3434

35+
pub use barrier::{Barrier, BarrierWaitResult};
3536
pub use mutex::{Mutex, MutexGuard};
3637
pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
3738

39+
mod barrier;
3840
mod mutex;
3941
mod rwlock;

tests/barrier.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
use std::sync::Arc;
2+
3+
use async_std::sync::Barrier;
4+
use async_std::task;
5+
6+
#[test]
7+
fn barrier_zero_does_not_block() {
8+
let b = Arc::new(Barrier::new(0));
9+
10+
let c = b.clone();
11+
task::block_on(async move {
12+
let w = task::spawn(async move { c.wait().await });
13+
let wr = w.await;
14+
assert!(wr.is_leader());
15+
});
16+
17+
let c = b.clone();
18+
task::block_on(async move {
19+
let w = task::spawn(async move { c.wait().await });
20+
let wr = w.await;
21+
assert!(wr.is_leader());
22+
});
23+
}
24+
25+
#[test]
26+
fn barrier_single() {
27+
let b = Arc::new(Barrier::new(1));
28+
29+
let c = b.clone();
30+
task::block_on(async move {
31+
let w = task::spawn(async move { c.wait().await });
32+
let wr = w.await;
33+
assert!(wr.is_leader());
34+
});
35+
let c = b.clone();
36+
task::block_on(async move {
37+
let w = task::spawn(async move { c.wait().await });
38+
let wr = w.await;
39+
assert!(wr.is_leader());
40+
});
41+
let c = b.clone();
42+
task::block_on(async move {
43+
let w = task::spawn(async move { c.wait().await });
44+
let wr = w.await;
45+
assert!(wr.is_leader());
46+
});
47+
}
48+
49+
#[test]
50+
fn barrier_tango() {
51+
let b = Arc::new(Barrier::new(2));
52+
53+
task::block_on(async move {
54+
let c = b.clone();
55+
let w1 = task::spawn(async move { c.wait().await });
56+
57+
let c = b.clone();
58+
let w2 = task::spawn(async move { c.wait().await });
59+
60+
let ws = futures_util::future::join_all(vec![w1, w2]).await;
61+
let wr1 = &ws[0];
62+
let wr2 = &ws[1];
63+
64+
assert!(wr1.is_leader() || wr2.is_leader());
65+
assert!(!(wr1.is_leader() && wr2.is_leader()));
66+
});
67+
}
68+
69+
#[test]
70+
fn barrier_lots() {
71+
let b = Arc::new(Barrier::new(100));
72+
73+
task::block_on(async move {
74+
for _ in 0..10 {
75+
let mut wait = Vec::new();
76+
for _ in 0..99 {
77+
let c = b.clone();
78+
let w = task::spawn(async move { c.wait().await });
79+
wait.push(w);
80+
}
81+
82+
// pass the barrier
83+
let c = b.clone();
84+
let w = task::spawn(async move { c.wait().await });
85+
86+
let mut found_leader = w.await.is_leader();
87+
for w in wait {
88+
let wr = w.await;
89+
if wr.is_leader() {
90+
assert!(!found_leader);
91+
found_leader = true;
92+
}
93+
}
94+
assert!(found_leader);
95+
}
96+
});
97+
}

0 commit comments

Comments
 (0)