async_lock/
semaphore.rs

1use core::fmt;
2use core::marker::PhantomPinned;
3use core::mem;
4use core::pin::Pin;
5use core::task::Poll;
6
7use crate::sync::atomic::{AtomicUsize, Ordering};
8
9use alloc::sync::Arc;
10
11use event_listener::{Event, EventListener};
12use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
13
14/// A counter for limiting the number of concurrent operations.
15#[derive(Debug)]
16pub struct Semaphore {
17    count: AtomicUsize,
18    event: Event,
19}
20
21impl Semaphore {
22    const_fn! {
23        const_if: #[cfg(not(loom))];
24        /// Creates a new semaphore with a limit of `n` concurrent operations.
25        ///
26        /// # Examples
27        ///
28        /// ```
29        /// use async_lock::Semaphore;
30        ///
31        /// let s = Semaphore::new(5);
32        /// ```
33        pub const fn new(n: usize) -> Semaphore {
34            Semaphore {
35                count: AtomicUsize::new(n),
36                event: Event::new(),
37            }
38        }
39    }
40
41    /// Attempts to get a permit for a concurrent operation.
42    ///
43    /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, a
44    /// guard is returned that releases the mutex when dropped.
45    ///
46    /// # Examples
47    ///
48    /// ```
49    /// use async_lock::Semaphore;
50    ///
51    /// let s = Semaphore::new(2);
52    ///
53    /// let g1 = s.try_acquire().unwrap();
54    /// let g2 = s.try_acquire().unwrap();
55    ///
56    /// assert!(s.try_acquire().is_none());
57    /// drop(g2);
58    /// assert!(s.try_acquire().is_some());
59    /// ```
60    pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
61        let mut count = self.count.load(Ordering::Acquire);
62        loop {
63            if count == 0 {
64                return None;
65            }
66
67            match self.count.compare_exchange_weak(
68                count,
69                count - 1,
70                Ordering::AcqRel,
71                Ordering::Acquire,
72            ) {
73                Ok(_) => return Some(SemaphoreGuard(self)),
74                Err(c) => count = c,
75            }
76        }
77    }
78
79    /// Waits for a permit for a concurrent operation.
80    ///
81    /// Returns a guard that releases the permit when dropped.
82    ///
83    /// # Examples
84    ///
85    /// ```
86    /// # futures_lite::future::block_on(async {
87    /// use async_lock::Semaphore;
88    ///
89    /// let s = Semaphore::new(2);
90    /// let guard = s.acquire().await;
91    /// # });
92    /// ```
93    pub fn acquire(&self) -> Acquire<'_> {
94        Acquire::_new(AcquireInner {
95            semaphore: self,
96            listener: None,
97            _pin: PhantomPinned,
98        })
99    }
100
101    /// Waits for a permit for a concurrent operation.
102    ///
103    /// Returns a guard that releases the permit when dropped.
104    ///
105    /// # Blocking
106    ///
107    /// Rather than using asynchronous waiting, like the [`acquire`][Semaphore::acquire] method,
108    /// this method will block the current thread until the permit is acquired.
109    ///
110    /// This method should not be used in an asynchronous context. It is intended to be
111    /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts.
112    /// Calling this method in an asynchronous context may result in a deadlock.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use async_lock::Semaphore;
118    ///
119    /// let s = Semaphore::new(2);
120    /// let guard = s.acquire_blocking();
121    /// ```
122    #[cfg(all(feature = "std", not(target_family = "wasm")))]
123    #[inline]
124    pub fn acquire_blocking(&self) -> SemaphoreGuard<'_> {
125        self.acquire().wait()
126    }
127
128    /// Attempts to get an owned permit for a concurrent operation.
129    ///
130    /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an
131    /// owned guard is returned that releases the mutex when dropped.
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use async_lock::Semaphore;
137    /// use std::sync::Arc;
138    ///
139    /// let s = Arc::new(Semaphore::new(2));
140    ///
141    /// let g1 = s.try_acquire_arc().unwrap();
142    /// let g2 = s.try_acquire_arc().unwrap();
143    ///
144    /// assert!(s.try_acquire_arc().is_none());
145    /// drop(g2);
146    /// assert!(s.try_acquire_arc().is_some());
147    /// ```
148    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<SemaphoreGuardArc> {
149        let mut count = self.count.load(Ordering::Acquire);
150        loop {
151            if count == 0 {
152                return None;
153            }
154
155            match self.count.compare_exchange_weak(
156                count,
157                count - 1,
158                Ordering::AcqRel,
159                Ordering::Acquire,
160            ) {
161                Ok(_) => return Some(SemaphoreGuardArc(Some(self.clone()))),
162                Err(c) => count = c,
163            }
164        }
165    }
166
167    /// Waits for an owned permit for a concurrent operation.
168    ///
169    /// Returns a guard that releases the permit when dropped.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// # futures_lite::future::block_on(async {
175    /// use async_lock::Semaphore;
176    /// use std::sync::Arc;
177    ///
178    /// let s = Arc::new(Semaphore::new(2));
179    /// let guard = s.acquire_arc().await;
180    /// # });
181    /// ```
182    pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc {
183        AcquireArc::_new(AcquireArcInner {
184            semaphore: self.clone(),
185            listener: None,
186            _pin: PhantomPinned,
187        })
188    }
189
190    /// Waits for an owned permit for a concurrent operation.
191    ///
192    /// Returns a guard that releases the permit when dropped.
193    ///
194    /// # Blocking
195    ///
196    /// Rather than using asynchronous waiting, like the [`acquire_arc`][Semaphore::acquire_arc] method,
197    /// this method will block the current thread until the permit is acquired.
198    ///
199    /// This method should not be used in an asynchronous context. It is intended to be
200    /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts.
201    /// Calling this method in an asynchronous context may result in a deadlock.
202    ///
203    /// # Examples
204    ///
205    /// ```
206    /// use std::sync::Arc;
207    /// use async_lock::Semaphore;
208    ///
209    /// let s = Arc::new(Semaphore::new(2));
210    /// let guard = s.acquire_arc_blocking();
211    /// ```
212    #[cfg(all(feature = "std", not(target_family = "wasm")))]
213    #[inline]
214    pub fn acquire_arc_blocking(self: &Arc<Self>) -> SemaphoreGuardArc {
215        self.acquire_arc().wait()
216    }
217
218    /// Adds `n` additional permits to the semaphore.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use async_lock::Semaphore;
224    ///
225    /// # futures_lite::future::block_on(async {
226    /// let s = Semaphore::new(1);
227    ///
228    /// let _guard = s.acquire().await;
229    /// assert!(s.try_acquire().is_none());
230    ///
231    /// s.add_permits(2);
232    ///
233    /// let _guard = s.acquire().await;
234    /// let _guard = s.acquire().await;
235    /// # });
236    /// ```
237    pub fn add_permits(&self, n: usize) {
238        self.count.fetch_add(n, Ordering::AcqRel);
239        self.event.notify(n);
240    }
241}
242
// Declares the public `Acquire` future wrapping `AcquireInner`, with output `SemaphoreGuard`.
// `Acquire::_new` (used by `Semaphore::acquire`) builds it. The cfg-gated `wait()` item provides
// the crate-visible blocking `wait` method that `Semaphore::acquire_blocking` calls, so both
// carry the same `std`/non-wasm cfg.
easy_wrapper! {
    /// The future returned by [`Semaphore::acquire`].
    pub struct Acquire<'a>(AcquireInner<'a> => SemaphoreGuard<'a>);
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
249
// State behind the `Acquire` future. `listener` starts as `None` (see `Semaphore::acquire`)
// and is filled in after an acquisition attempt fails.
pin_project_lite::pin_project! {
    struct AcquireInner<'a> {
        // The semaphore being acquired.
        semaphore: &'a Semaphore,

        // The listener waiting on the semaphore.
        listener: Option<EventListener>,

        // Keeping this future `!Unpin` enables future optimizations.
        #[pin]
        _pin: PhantomPinned
    }
}
263
264impl fmt::Debug for Acquire<'_> {
265    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266        f.write_str("Acquire { .. }")
267    }
268}
269
270impl<'a> EventListenerFuture for AcquireInner<'a> {
271    type Output = SemaphoreGuard<'a>;
272
273    fn poll_with_strategy<'x, S: Strategy<'x>>(
274        self: Pin<&mut Self>,
275        strategy: &mut S,
276        cx: &mut S::Context,
277    ) -> Poll<Self::Output> {
278        let this = self.project();
279
280        loop {
281            match this.semaphore.try_acquire() {
282                Some(guard) => return Poll::Ready(guard),
283                None => {
284                    // Wait on the listener.
285                    if this.listener.is_none() {
286                        *this.listener = Some(this.semaphore.event.listen());
287                    } else {
288                        ready!(strategy.poll(this.listener, cx));
289                    }
290                }
291            }
292        }
293    }
294}
295
// Declares the public `AcquireArc` future wrapping `AcquireArcInner`, with output
// `SemaphoreGuardArc`. `AcquireArc::_new` (used by `Semaphore::acquire_arc`) builds it.
// The cfg-gated `wait()` item provides the crate-visible blocking `wait` method that
// `Semaphore::acquire_arc_blocking` calls under the same cfg.
easy_wrapper! {
    /// The future returned by [`Semaphore::acquire_arc`].
    pub struct AcquireArc(AcquireArcInner => SemaphoreGuardArc);
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
302
// State behind the `AcquireArc` future. It owns an `Arc` handle to the semaphore, so the future
// has no borrowed lifetime. `listener` starts as `None` (see `Semaphore::acquire_arc`) and is
// filled in after an acquisition attempt fails.
pin_project_lite::pin_project! {
    struct AcquireArcInner {
        // The semaphore being acquired.
        semaphore: Arc<Semaphore>,

        // The listener waiting on the semaphore.
        listener: Option<EventListener>,

        // Keeping this future `!Unpin` enables future optimizations.
        #[pin]
        _pin: PhantomPinned
    }
}
316
317impl fmt::Debug for AcquireArc {
318    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319        f.write_str("AcquireArc { .. }")
320    }
321}
322
323impl EventListenerFuture for AcquireArcInner {
324    type Output = SemaphoreGuardArc;
325
326    fn poll_with_strategy<'x, S: Strategy<'x>>(
327        self: Pin<&mut Self>,
328        strategy: &mut S,
329        cx: &mut S::Context,
330    ) -> Poll<Self::Output> {
331        let this = self.project();
332
333        loop {
334            match this.semaphore.try_acquire_arc() {
335                Some(guard) => return Poll::Ready(guard),
336                None => {
337                    // Wait on the listener.
338                    if this.listener.is_none() {
339                        *this.listener = Some(this.semaphore.event.listen());
340                    } else {
341                        ready!(strategy.poll(this.listener, cx));
342                    }
343                }
344            }
345        }
346    }
347}
348
349/// A guard that releases the acquired permit.
350#[clippy::has_significant_drop]
351#[derive(Debug)]
352pub struct SemaphoreGuard<'a>(&'a Semaphore);
353
354impl SemaphoreGuard<'_> {
355    /// Drops the guard _without_ releasing the acquired permit.
356    #[inline]
357    pub fn forget(self) {
358        mem::forget(self);
359    }
360}
361
362impl Drop for SemaphoreGuard<'_> {
363    fn drop(&mut self) {
364        self.0.count.fetch_add(1, Ordering::AcqRel);
365        self.0.event.notify(1);
366    }
367}
368
369/// An owned guard that releases the acquired permit.
370#[clippy::has_significant_drop]
371#[derive(Debug)]
372pub struct SemaphoreGuardArc(Option<Arc<Semaphore>>);
373
374impl SemaphoreGuardArc {
375    /// Drops the guard _without_ releasing the acquired permit.
376    /// (Will still decrement the `Arc` reference count.)
377    #[inline]
378    pub fn forget(mut self) {
379        // Drop the inner `Arc` in order to decrement the reference count.
380        // FIXME: get rid of the `Option` once RFC 3466 or equivalent becomes available.
381        drop(self.0.take());
382        mem::forget(self);
383    }
384}
385
386impl Drop for SemaphoreGuardArc {
387    fn drop(&mut self) {
388        let opt = self.0.take().unwrap();
389        opt.count.fetch_add(1, Ordering::AcqRel);
390        opt.event.notify(1);
391    }
392}