async_lock/barrier.rs
1use event_listener::{Event, EventListener};
2use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
3
4use core::fmt;
5use core::pin::Pin;
6use core::task::Poll;
7
8use crate::futures::Lock;
9use crate::Mutex;
10
/// A counter to synchronize multiple tasks at the same time.
#[derive(Debug)]
pub struct Barrier {
    /// The number of tasks that must call `wait()` before they are all released.
    n: usize,
    /// The arrival count and generation of the current round, guarded by the crate's `Mutex`.
    state: Mutex<State>,
    /// Notified by the last arriving task to wake every task waiting on the barrier.
    event: Event,
}
18
/// Mutable barrier bookkeeping, stored behind `Barrier::state`.
#[derive(Debug)]
struct State {
    /// Number of tasks that have arrived in the current round; reset to 0 when the last task
    /// arrives and releases the barrier.
    count: usize,
    /// Advanced (with wrapping) each time the barrier releases, so woken waiters can tell that
    /// the round they joined has completed.
    generation_id: u64,
}
24
25impl Barrier {
26 const_fn! {
27 const_if: #[cfg(not(loom))];
28 /// Creates a barrier that can block the given number of tasks.
29 ///
30 /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
31 /// at once when the `n`th task calls [`wait()`].
32 ///
33 /// [`wait()`]: `Barrier::wait()`
34 ///
35 /// # Examples
36 ///
37 /// ```
38 /// use async_lock::Barrier;
39 ///
40 /// let barrier = Barrier::new(5);
41 /// ```
42 pub const fn new(n: usize) -> Barrier {
43 Barrier {
44 n,
45 state: Mutex::new(State {
46 count: 0,
47 generation_id: 0,
48 }),
49 event: Event::new(),
50 }
51 }
52 }
53
54 /// Blocks the current task until all tasks reach this point.
55 ///
56 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
57 ///
58 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
59 /// last task to call this method.
60 ///
61 /// # Examples
62 ///
63 /// ```
64 /// use async_lock::Barrier;
65 /// use futures_lite::future;
66 /// use std::sync::Arc;
67 /// use std::thread;
68 ///
69 /// let barrier = Arc::new(Barrier::new(5));
70 ///
71 /// for _ in 0..5 {
72 /// let b = barrier.clone();
73 /// thread::spawn(move || {
74 /// future::block_on(async {
75 /// // The same messages will be printed together.
76 /// // There will NOT be interleaving of "before" and "after".
77 /// println!("before wait");
78 /// b.wait().await;
79 /// println!("after wait");
80 /// });
81 /// });
82 /// }
83 /// ```
84 pub fn wait(&self) -> BarrierWait<'_> {
85 BarrierWait::_new(BarrierWaitInner {
86 barrier: self,
87 lock: Some(self.state.lock()),
88 evl: None,
89 state: WaitState::Initial,
90 })
91 }
92
93 /// Blocks the current thread until all tasks reach this point.
94 ///
95 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
96 ///
97 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
98 /// last task to call this method.
99 ///
100 /// # Blocking
101 ///
102 /// Rather than using asynchronous waiting, like the [`wait`][`Barrier::wait`] method,
103 /// this method will block the current thread until the wait is complete.
104 ///
105 /// This method should not be used in an asynchronous context. It is intended to be
106 /// used in a way that a barrier can be used in both asynchronous and synchronous contexts.
107 /// Calling this method in an asynchronous context may result in a deadlock.
108 ///
109 /// # Examples
110 ///
111 /// ```
112 /// use async_lock::Barrier;
113 /// use futures_lite::future;
114 /// use std::sync::Arc;
115 /// use std::thread;
116 ///
117 /// let barrier = Arc::new(Barrier::new(5));
118 ///
119 /// for _ in 0..5 {
120 /// let b = barrier.clone();
121 /// thread::spawn(move || {
122 /// // The same messages will be printed together.
123 /// // There will NOT be interleaving of "before" and "after".
124 /// println!("before wait");
125 /// b.wait_blocking();
126 /// println!("after wait");
127 /// });
128 /// }
129 /// # // Wait for threads to stop.
130 /// # std::thread::sleep(std::time::Duration::from_secs(1));
131 /// ```
132 #[cfg(all(feature = "std", not(target_family = "wasm")))]
133 pub fn wait_blocking(&self) -> BarrierWaitResult {
134 self.wait().wait()
135 }
136}
137
easy_wrapper! {
    /// The future returned by [`Barrier::wait()`].
    ///
    /// It resolves to a [`BarrierWaitResult`] once the barrier has been released.
    pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult);
    // Blocking entry point, used by `Barrier::wait_blocking()`.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
144
pin_project_lite::pin_project! {
    /// The state machine behind the future returned by [`Barrier::wait()`].
    ///
    /// It is wrapped by the public [`BarrierWait`] type.
    struct BarrierWaitInner<'a> {
        // The barrier to wait on.
        barrier: &'a Barrier,

        // The in-progress acquisition of `barrier.state`, if any. It is present while in the
        // `Initial` and `Reacquiring` states and is cleared as soon as the guard is obtained.
        #[pin]
        lock: Option<Lock<'a, State>>,

        // A listener for the `barrier.event` event. It is registered while the state lock is
        // held, before this task enters the `Waiting` state.
        evl: Option<EventListener>,

        // The current state of the future.
        state: WaitState,
    }
}
162
163impl fmt::Debug for BarrierWait<'_> {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.write_str("BarrierWait { .. }")
166 }
167}
168
/// The stages that `BarrierWaitInner` moves through while waiting on the barrier.
enum WaitState {
    /// Acquiring the lock for the first time in order to record this task's arrival.
    Initial,

    /// Arrival recorded; waiting for the barrier's event to fire.
    Waiting {
        /// The generation that was current when this task arrived.
        local_gen: u64,
    },

    /// The event fired; re-acquiring the lock to check whether the generation has advanced.
    Reacquiring {
        /// The generation that was current when this task arrived.
        local_gen: u64,
    },
}
179
180impl EventListenerFuture for BarrierWaitInner<'_> {
181 type Output = BarrierWaitResult;
182
183 fn poll_with_strategy<'a, S: Strategy<'a>>(
184 self: Pin<&mut Self>,
185 strategy: &mut S,
186 cx: &mut S::Context,
187 ) -> Poll<Self::Output> {
188 let mut this = self.project();
189
190 loop {
191 match this.state {
192 WaitState::Initial => {
193 // See if the lock is ready yet.
194 let mut state = ready!(this
195 .lock
196 .as_mut()
197 .as_pin_mut()
198 .unwrap()
199 .poll_with_strategy(strategy, cx));
200 this.lock.as_mut().set(None);
201
202 let local_gen = state.generation_id;
203 state.count += 1;
204
205 if state.count < this.barrier.n {
206 // We need to wait for the event.
207 *this.evl = Some(this.barrier.event.listen());
208 *this.state = WaitState::Waiting { local_gen };
209 } else {
210 // We are the last one.
211 state.count = 0;
212 state.generation_id = state.generation_id.wrapping_add(1);
213 this.barrier.event.notify(core::usize::MAX);
214 return Poll::Ready(BarrierWaitResult { is_leader: true });
215 }
216 }
217
218 WaitState::Waiting { local_gen } => {
219 ready!(strategy.poll(this.evl, cx));
220
221 // We are now re-acquiring the mutex.
222 this.lock.as_mut().set(Some(this.barrier.state.lock()));
223 *this.state = WaitState::Reacquiring {
224 local_gen: *local_gen,
225 };
226 }
227
228 WaitState::Reacquiring { local_gen } => {
229 // Acquire the local state again.
230 let state = ready!(this
231 .lock
232 .as_mut()
233 .as_pin_mut()
234 .unwrap()
235 .poll_with_strategy(strategy, cx));
236 this.lock.set(None);
237
238 if *local_gen == state.generation_id && state.count < this.barrier.n {
239 // We need to wait for the event again.
240 *this.evl = Some(this.barrier.event.listen());
241 *this.state = WaitState::Waiting {
242 local_gen: *local_gen,
243 };
244 } else {
245 // We are ready, but not the leader.
246 return Poll::Ready(BarrierWaitResult { is_leader: false });
247 }
248 }
249 }
250 }
251 }
252}
253
/// The result of waiting on a [`Barrier`], produced once all tasks have called
/// [`Barrier::wait()`] (or `Barrier::wait_blocking()`).
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// // A barrier for a single task releases immediately, and its only caller is the leader.
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    // Whether this task was the one whose arrival released the barrier.
    is_leader: bool,
}
270
271impl BarrierWaitResult {
272 /// Returns `true` if this task was the last to call to [`Barrier::wait()`].
273 ///
274 /// # Examples
275 ///
276 /// ```
277 /// # futures_lite::future::block_on(async {
278 /// use async_lock::Barrier;
279 /// use futures_lite::future;
280 ///
281 /// let barrier = Barrier::new(2);
282 /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
283 /// assert_eq!(a.is_leader(), false);
284 /// assert_eq!(b.is_leader(), true);
285 /// # });
286 /// ```
287 pub fn is_leader(&self) -> bool {
288 self.is_leader
289 }
290}