async_lock/
barrier.rs

1use event_listener::{Event, EventListener};
2use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
3
4use core::fmt;
5use core::pin::Pin;
6use core::task::Poll;
7
8use crate::futures::Lock;
9use crate::Mutex;
10
11/// A counter to synchronize multiple tasks at the same time.
12#[derive(Debug)]
13pub struct Barrier {
14    n: usize,
15    state: Mutex<State>,
16    event: Event,
17}
18
// Bookkeeping shared by every waiter; it lives inside the `Barrier::state` mutex.
#[derive(Debug)]
struct State {
    // How many tasks have arrived in the current generation.
    count: usize,
    // Incremented (wrapping) each time the barrier releases its waiters.
    generation_id: u64,
}
24
25impl Barrier {
26    const_fn! {
27        const_if: #[cfg(not(loom))];
28        /// Creates a barrier that can block the given number of tasks.
29        ///
30        /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
31        /// at once when the `n`th task calls [`wait()`].
32        ///
33        /// [`wait()`]: `Barrier::wait()`
34        ///
35        /// # Examples
36        ///
37        /// ```
38        /// use async_lock::Barrier;
39        ///
40        /// let barrier = Barrier::new(5);
41        /// ```
42        pub const fn new(n: usize) -> Barrier {
43            Barrier {
44                n,
45                state: Mutex::new(State {
46                    count: 0,
47                    generation_id: 0,
48                }),
49                event: Event::new(),
50            }
51        }
52    }
53
54    /// Blocks the current task until all tasks reach this point.
55    ///
56    /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
57    ///
58    /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
59    /// last task to call this method.
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// use async_lock::Barrier;
65    /// use futures_lite::future;
66    /// use std::sync::Arc;
67    /// use std::thread;
68    ///
69    /// let barrier = Arc::new(Barrier::new(5));
70    ///
71    /// for _ in 0..5 {
72    ///     let b = barrier.clone();
73    ///     thread::spawn(move || {
74    ///         future::block_on(async {
75    ///             // The same messages will be printed together.
76    ///             // There will NOT be interleaving of "before" and "after".
77    ///             println!("before wait");
78    ///             b.wait().await;
79    ///             println!("after wait");
80    ///         });
81    ///     });
82    /// }
83    /// ```
84    pub fn wait(&self) -> BarrierWait<'_> {
85        BarrierWait::_new(BarrierWaitInner {
86            barrier: self,
87            lock: Some(self.state.lock()),
88            evl: None,
89            state: WaitState::Initial,
90        })
91    }
92
93    /// Blocks the current thread until all tasks reach this point.
94    ///
95    /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
96    ///
97    /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
98    /// last task to call this method.
99    ///
100    /// # Blocking
101    ///
102    /// Rather than using asynchronous waiting, like the [`wait`][`Barrier::wait`] method,
103    /// this method will block the current thread until the wait is complete.
104    ///
105    /// This method should not be used in an asynchronous context. It is intended to be
106    /// used in a way that a barrier can be used in both asynchronous and synchronous contexts.
107    /// Calling this method in an asynchronous context may result in a deadlock.
108    ///
109    /// # Examples
110    ///
111    /// ```
112    /// use async_lock::Barrier;
113    /// use futures_lite::future;
114    /// use std::sync::Arc;
115    /// use std::thread;
116    ///
117    /// let barrier = Arc::new(Barrier::new(5));
118    ///
119    /// for _ in 0..5 {
120    ///     let b = barrier.clone();
121    ///     thread::spawn(move || {
122    ///         // The same messages will be printed together.
123    ///         // There will NOT be interleaving of "before" and "after".
124    ///         println!("before wait");
125    ///         b.wait_blocking();
126    ///         println!("after wait");
127    ///     });
128    /// }
129    /// # // Wait for threads to stop.
130    /// # std::thread::sleep(std::time::Duration::from_secs(1));
131    /// ```
132    #[cfg(all(feature = "std", not(target_family = "wasm")))]
133    pub fn wait_blocking(&self) -> BarrierWaitResult {
134        self.wait().wait()
135    }
136}
137
138easy_wrapper! {
139    /// The future returned by [`Barrier::wait()`].
140    pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult);
141    #[cfg(all(feature = "std", not(target_family = "wasm")))]
142    pub(crate) wait();
143}
144
145pin_project_lite::pin_project! {
146    /// The future returned by [`Barrier::wait()`].
147    struct BarrierWaitInner<'a> {
148        // The barrier to wait on.
149        barrier: &'a Barrier,
150
151        // The ongoing mutex lock operation we are blocking on.
152        #[pin]
153        lock: Option<Lock<'a, State>>,
154
155        // An event listener for the `barrier.event` event.
156        evl: Option<EventListener>,
157
158        // The current state of the future.
159        state: WaitState,
160    }
161}
162
163impl fmt::Debug for BarrierWait<'_> {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.write_str("BarrierWait { .. }")
166    }
167}
168
/// The step a pending barrier wait is currently in.
enum WaitState {
    /// First acquisition of the state mutex, used to register this task's arrival.
    Initial,

    /// Parked on the barrier's event; `local_gen` is the generation observed on arrival.
    Waiting { local_gen: u64 },

    /// Woken up and re-acquiring the state mutex to check whether the generation
    /// recorded in `local_gen` has been released.
    Reacquiring { local_gen: u64 },
}
179
180impl EventListenerFuture for BarrierWaitInner<'_> {
181    type Output = BarrierWaitResult;
182
183    fn poll_with_strategy<'a, S: Strategy<'a>>(
184        self: Pin<&mut Self>,
185        strategy: &mut S,
186        cx: &mut S::Context,
187    ) -> Poll<Self::Output> {
188        let mut this = self.project();
189
190        loop {
191            match this.state {
192                WaitState::Initial => {
193                    // See if the lock is ready yet.
194                    let mut state = ready!(this
195                        .lock
196                        .as_mut()
197                        .as_pin_mut()
198                        .unwrap()
199                        .poll_with_strategy(strategy, cx));
200                    this.lock.as_mut().set(None);
201
202                    let local_gen = state.generation_id;
203                    state.count += 1;
204
205                    if state.count < this.barrier.n {
206                        // We need to wait for the event.
207                        *this.evl = Some(this.barrier.event.listen());
208                        *this.state = WaitState::Waiting { local_gen };
209                    } else {
210                        // We are the last one.
211                        state.count = 0;
212                        state.generation_id = state.generation_id.wrapping_add(1);
213                        this.barrier.event.notify(core::usize::MAX);
214                        return Poll::Ready(BarrierWaitResult { is_leader: true });
215                    }
216                }
217
218                WaitState::Waiting { local_gen } => {
219                    ready!(strategy.poll(this.evl, cx));
220
221                    // We are now re-acquiring the mutex.
222                    this.lock.as_mut().set(Some(this.barrier.state.lock()));
223                    *this.state = WaitState::Reacquiring {
224                        local_gen: *local_gen,
225                    };
226                }
227
228                WaitState::Reacquiring { local_gen } => {
229                    // Acquire the local state again.
230                    let state = ready!(this
231                        .lock
232                        .as_mut()
233                        .as_pin_mut()
234                        .unwrap()
235                        .poll_with_strategy(strategy, cx));
236                    this.lock.set(None);
237
238                    if *local_gen == state.generation_id && state.count < this.barrier.n {
239                        // We need to wait for the event again.
240                        *this.evl = Some(this.barrier.event.listen());
241                        *this.state = WaitState::Waiting {
242                            local_gen: *local_gen,
243                        };
244                    } else {
245                        // We are ready, but not the leader.
246                        return Poll::Ready(BarrierWaitResult { is_leader: false });
247                    }
248                }
249            }
250        }
251    }
252}
253
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// The result records whether the task that received it was the "leader" of the
/// rendezvous; query it with [`BarrierWaitResult::is_leader()`].
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    // `true` only for the task whose arrival released the barrier.
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Returns `true` if this task was the last to call [`Barrier::wait()`].
    ///
    /// # Examples
    ///
    /// ```
    /// # futures_lite::future::block_on(async {
    /// use async_lock::Barrier;
    /// use futures_lite::future;
    ///
    /// let barrier = Barrier::new(2);
    /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
    /// assert!(!a.is_leader());
    /// assert!(b.is_leader());
    /// # });
    /// ```
    #[must_use]
    pub fn is_leader(&self) -> bool {
        self.is_leader
    }
}