// broker_tokio/sync/barrier.rs

use crate::sync::watch;

use std::sync::Mutex;
4
5/// A barrier enables multiple threads to synchronize the beginning of some computation.
6///
7/// ```
8/// # #[tokio::main]
9/// # async fn main() {
10/// use tokio::sync::Barrier;
11///
12/// use futures::future::join_all;
13/// use std::sync::Arc;
14///
15/// let mut handles = Vec::with_capacity(10);
16/// let barrier = Arc::new(Barrier::new(10));
17/// for _ in 0..10 {
18///     let c = barrier.clone();
19///     // The same messages will be printed together.
20///     // You will NOT see any interleaving.
21///     handles.push(async move {
22///         println!("before wait");
23///         let wr = c.wait().await;
24///         println!("after wait");
25///         wr
26///     });
27/// }
28/// // Will not resolve until all "before wait" messages have been printed
29/// let wrs = join_all(handles).await;
30/// // Exactly one barrier will resolve as the "leader"
31/// assert_eq!(wrs.into_iter().filter(|wr| wr.is_leader()).count(), 1);
32/// # }
33/// ```
34#[derive(Debug)]
35pub struct Barrier {
36    state: Mutex<BarrierState>,
37    wait: watch::Receiver<usize>,
38    n: usize,
39}
40
41#[derive(Debug)]
42struct BarrierState {
43    waker: watch::Sender<usize>,
44    arrived: usize,
45    generation: usize,
46}
47
48impl Barrier {
49    /// Creates a new barrier that can block a given number of threads.
50    ///
51    /// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all
52    /// threads at once when the `n`th thread calls `wait`.
53    pub fn new(mut n: usize) -> Barrier {
54        let (waker, wait) = crate::sync::watch::channel(0);
55
56        if n == 0 {
57            // if n is 0, it's not clear what behavior the user wants.
58            // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
59            // .wait() immediately unblocks, so we adopt that here as well.
60            n = 1;
61        }
62
63        Barrier {
64            state: Mutex::new(BarrierState {
65                waker,
66                arrived: 0,
67                generation: 1,
68            }),
69            n,
70            wait,
71        }
72    }
73
74    /// Does not resolve until all tasks have rendezvoused here.
75    ///
76    /// Barriers are re-usable after all threads have rendezvoused once, and can
77    /// be used continuously.
78    ///
79    /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
80    /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads
81    /// will receive a result that will return `false` from `is_leader`.
82    pub async fn wait(&self) -> BarrierWaitResult {
83        // NOTE: we are taking a _synchronous_ lock here.
84        // It is okay to do so because the critical section is fast and never yields, so it cannot
85        // deadlock even if another future is concurrently holding the lock.
86        // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than
87        // the asynchronous counter-parts, so we should use them where possible [citation needed].
88        // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
89        // a yield point, and thus marks the returned future as !Send.
90        let generation = {
91            let mut state = self.state.lock().unwrap();
92            let generation = state.generation;
93            state.arrived += 1;
94            if state.arrived == self.n {
95                // we are the leader for this generation
96                // wake everyone, increment the generation, and return
97                state
98                    .waker
99                    .broadcast(state.generation)
100                    .expect("there is at least one receiver");
101                state.arrived = 0;
102                state.generation += 1;
103                return BarrierWaitResult(true);
104            }
105
106            generation
107        };
108
109        // we're going to have to wait for the last of the generation to arrive
110        let mut wait = self.wait.clone();
111
112        loop {
113            // note that the first time through the loop, this _will_ yield a generation
114            // immediately, since we cloned a receiver that has never seen any values.
115            if wait.recv().await.expect("sender hasn't been closed") >= generation {
116                break;
117            }
118        }
119
120        BarrierWaitResult(false)
121    }
122}
123
/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this task is the "leader" of the generation it waited in.
    ///
    /// Only one task will have `true` returned from their result, all other tasks will have
    /// `false` returned.
    pub fn is_leader(&self) -> bool {
        self.0
    }
}