// async_lock/semaphore.rs
1use core::fmt;
2use core::marker::PhantomPinned;
3use core::mem;
4use core::pin::Pin;
5use core::task::Poll;
6
7use crate::sync::atomic::{AtomicUsize, Ordering};
8
9use alloc::sync::Arc;
10
11use event_listener::{Event, EventListener};
12use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
13
14/// A counter for limiting the number of concurrent operations.
15#[derive(Debug)]
16pub struct Semaphore {
17 count: AtomicUsize,
18 event: Event,
19}
20
21impl Semaphore {
22 const_fn! {
23 const_if: #[cfg(not(loom))];
24 /// Creates a new semaphore with a limit of `n` concurrent operations.
25 ///
26 /// # Examples
27 ///
28 /// ```
29 /// use async_lock::Semaphore;
30 ///
31 /// let s = Semaphore::new(5);
32 /// ```
33 pub const fn new(n: usize) -> Semaphore {
34 Semaphore {
35 count: AtomicUsize::new(n),
36 event: Event::new(),
37 }
38 }
39 }
40
41 /// Attempts to get a permit for a concurrent operation.
42 ///
43 /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, a
44 /// guard is returned that releases the mutex when dropped.
45 ///
46 /// # Examples
47 ///
48 /// ```
49 /// use async_lock::Semaphore;
50 ///
51 /// let s = Semaphore::new(2);
52 ///
53 /// let g1 = s.try_acquire().unwrap();
54 /// let g2 = s.try_acquire().unwrap();
55 ///
56 /// assert!(s.try_acquire().is_none());
57 /// drop(g2);
58 /// assert!(s.try_acquire().is_some());
59 /// ```
60 pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
61 let mut count = self.count.load(Ordering::Acquire);
62 loop {
63 if count == 0 {
64 return None;
65 }
66
67 match self.count.compare_exchange_weak(
68 count,
69 count - 1,
70 Ordering::AcqRel,
71 Ordering::Acquire,
72 ) {
73 Ok(_) => return Some(SemaphoreGuard(self)),
74 Err(c) => count = c,
75 }
76 }
77 }
78
79 /// Waits for a permit for a concurrent operation.
80 ///
81 /// Returns a guard that releases the permit when dropped.
82 ///
83 /// # Examples
84 ///
85 /// ```
86 /// # futures_lite::future::block_on(async {
87 /// use async_lock::Semaphore;
88 ///
89 /// let s = Semaphore::new(2);
90 /// let guard = s.acquire().await;
91 /// # });
92 /// ```
93 pub fn acquire(&self) -> Acquire<'_> {
94 Acquire::_new(AcquireInner {
95 semaphore: self,
96 listener: None,
97 _pin: PhantomPinned,
98 })
99 }
100
101 /// Waits for a permit for a concurrent operation.
102 ///
103 /// Returns a guard that releases the permit when dropped.
104 ///
105 /// # Blocking
106 ///
107 /// Rather than using asynchronous waiting, like the [`acquire`][Semaphore::acquire] method,
108 /// this method will block the current thread until the permit is acquired.
109 ///
110 /// This method should not be used in an asynchronous context. It is intended to be
111 /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts.
112 /// Calling this method in an asynchronous context may result in a deadlock.
113 ///
114 /// # Examples
115 ///
116 /// ```
117 /// use async_lock::Semaphore;
118 ///
119 /// let s = Semaphore::new(2);
120 /// let guard = s.acquire_blocking();
121 /// ```
122 #[cfg(all(feature = "std", not(target_family = "wasm")))]
123 #[inline]
124 pub fn acquire_blocking(&self) -> SemaphoreGuard<'_> {
125 self.acquire().wait()
126 }
127
128 /// Attempts to get an owned permit for a concurrent operation.
129 ///
130 /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an
131 /// owned guard is returned that releases the mutex when dropped.
132 ///
133 /// # Examples
134 ///
135 /// ```
136 /// use async_lock::Semaphore;
137 /// use std::sync::Arc;
138 ///
139 /// let s = Arc::new(Semaphore::new(2));
140 ///
141 /// let g1 = s.try_acquire_arc().unwrap();
142 /// let g2 = s.try_acquire_arc().unwrap();
143 ///
144 /// assert!(s.try_acquire_arc().is_none());
145 /// drop(g2);
146 /// assert!(s.try_acquire_arc().is_some());
147 /// ```
148 pub fn try_acquire_arc(self: &Arc<Self>) -> Option<SemaphoreGuardArc> {
149 let mut count = self.count.load(Ordering::Acquire);
150 loop {
151 if count == 0 {
152 return None;
153 }
154
155 match self.count.compare_exchange_weak(
156 count,
157 count - 1,
158 Ordering::AcqRel,
159 Ordering::Acquire,
160 ) {
161 Ok(_) => return Some(SemaphoreGuardArc(Some(self.clone()))),
162 Err(c) => count = c,
163 }
164 }
165 }
166
167 /// Waits for an owned permit for a concurrent operation.
168 ///
169 /// Returns a guard that releases the permit when dropped.
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// # futures_lite::future::block_on(async {
175 /// use async_lock::Semaphore;
176 /// use std::sync::Arc;
177 ///
178 /// let s = Arc::new(Semaphore::new(2));
179 /// let guard = s.acquire_arc().await;
180 /// # });
181 /// ```
182 pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc {
183 AcquireArc::_new(AcquireArcInner {
184 semaphore: self.clone(),
185 listener: None,
186 _pin: PhantomPinned,
187 })
188 }
189
190 /// Waits for an owned permit for a concurrent operation.
191 ///
192 /// Returns a guard that releases the permit when dropped.
193 ///
194 /// # Blocking
195 ///
196 /// Rather than using asynchronous waiting, like the [`acquire_arc`][Semaphore::acquire_arc] method,
197 /// this method will block the current thread until the permit is acquired.
198 ///
199 /// This method should not be used in an asynchronous context. It is intended to be
200 /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts.
201 /// Calling this method in an asynchronous context may result in a deadlock.
202 ///
203 /// # Examples
204 ///
205 /// ```
206 /// use std::sync::Arc;
207 /// use async_lock::Semaphore;
208 ///
209 /// let s = Arc::new(Semaphore::new(2));
210 /// let guard = s.acquire_arc_blocking();
211 /// ```
212 #[cfg(all(feature = "std", not(target_family = "wasm")))]
213 #[inline]
214 pub fn acquire_arc_blocking(self: &Arc<Self>) -> SemaphoreGuardArc {
215 self.acquire_arc().wait()
216 }
217
218 /// Adds `n` additional permits to the semaphore.
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// use async_lock::Semaphore;
224 ///
225 /// # futures_lite::future::block_on(async {
226 /// let s = Semaphore::new(1);
227 ///
228 /// let _guard = s.acquire().await;
229 /// assert!(s.try_acquire().is_none());
230 ///
231 /// s.add_permits(2);
232 ///
233 /// let _guard = s.acquire().await;
234 /// let _guard = s.acquire().await;
235 /// # });
236 /// ```
237 pub fn add_permits(&self, n: usize) {
238 self.count.fetch_add(n, Ordering::AcqRel);
239 self.event.notify(n);
240 }
241}
242
243easy_wrapper! {
244 /// The future returned by [`Semaphore::acquire`].
245 pub struct Acquire<'a>(AcquireInner<'a> => SemaphoreGuard<'a>);
246 #[cfg(all(feature = "std", not(target_family = "wasm")))]
247 pub(crate) wait();
248}
249
250pin_project_lite::pin_project! {
251 struct AcquireInner<'a> {
252 // The semaphore being acquired.
253 semaphore: &'a Semaphore,
254
255 // The listener waiting on the semaphore.
256 listener: Option<EventListener>,
257
258 // Keeping this future `!Unpin` enables future optimizations.
259 #[pin]
260 _pin: PhantomPinned
261 }
262}
263
264impl fmt::Debug for Acquire<'_> {
265 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 f.write_str("Acquire { .. }")
267 }
268}
269
270impl<'a> EventListenerFuture for AcquireInner<'a> {
271 type Output = SemaphoreGuard<'a>;
272
273 fn poll_with_strategy<'x, S: Strategy<'x>>(
274 self: Pin<&mut Self>,
275 strategy: &mut S,
276 cx: &mut S::Context,
277 ) -> Poll<Self::Output> {
278 let this = self.project();
279
280 loop {
281 match this.semaphore.try_acquire() {
282 Some(guard) => return Poll::Ready(guard),
283 None => {
284 // Wait on the listener.
285 if this.listener.is_none() {
286 *this.listener = Some(this.semaphore.event.listen());
287 } else {
288 ready!(strategy.poll(this.listener, cx));
289 }
290 }
291 }
292 }
293 }
294}
295
296easy_wrapper! {
297 /// The future returned by [`Semaphore::acquire_arc`].
298 pub struct AcquireArc(AcquireArcInner => SemaphoreGuardArc);
299 #[cfg(all(feature = "std", not(target_family = "wasm")))]
300 pub(crate) wait();
301}
302
303pin_project_lite::pin_project! {
304 struct AcquireArcInner {
305 // The semaphore being acquired.
306 semaphore: Arc<Semaphore>,
307
308 // The listener waiting on the semaphore.
309 listener: Option<EventListener>,
310
311 // Keeping this future `!Unpin` enables future optimizations.
312 #[pin]
313 _pin: PhantomPinned
314 }
315}
316
317impl fmt::Debug for AcquireArc {
318 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319 f.write_str("AcquireArc { .. }")
320 }
321}
322
323impl EventListenerFuture for AcquireArcInner {
324 type Output = SemaphoreGuardArc;
325
326 fn poll_with_strategy<'x, S: Strategy<'x>>(
327 self: Pin<&mut Self>,
328 strategy: &mut S,
329 cx: &mut S::Context,
330 ) -> Poll<Self::Output> {
331 let this = self.project();
332
333 loop {
334 match this.semaphore.try_acquire_arc() {
335 Some(guard) => return Poll::Ready(guard),
336 None => {
337 // Wait on the listener.
338 if this.listener.is_none() {
339 *this.listener = Some(this.semaphore.event.listen());
340 } else {
341 ready!(strategy.poll(this.listener, cx));
342 }
343 }
344 }
345 }
346 }
347}
348
349/// A guard that releases the acquired permit.
350#[clippy::has_significant_drop]
351#[derive(Debug)]
352pub struct SemaphoreGuard<'a>(&'a Semaphore);
353
354impl SemaphoreGuard<'_> {
355 /// Drops the guard _without_ releasing the acquired permit.
356 #[inline]
357 pub fn forget(self) {
358 mem::forget(self);
359 }
360}
361
362impl Drop for SemaphoreGuard<'_> {
363 fn drop(&mut self) {
364 self.0.count.fetch_add(1, Ordering::AcqRel);
365 self.0.event.notify(1);
366 }
367}
368
369/// An owned guard that releases the acquired permit.
370#[clippy::has_significant_drop]
371#[derive(Debug)]
372pub struct SemaphoreGuardArc(Option<Arc<Semaphore>>);
373
374impl SemaphoreGuardArc {
375 /// Drops the guard _without_ releasing the acquired permit.
376 /// (Will still decrement the `Arc` reference count.)
377 #[inline]
378 pub fn forget(mut self) {
379 // Drop the inner `Arc` in order to decrement the reference count.
380 // FIXME: get rid of the `Option` once RFC 3466 or equivalent becomes available.
381 drop(self.0.take());
382 mem::forget(self);
383 }
384}
385
386impl Drop for SemaphoreGuardArc {
387 fn drop(&mut self) {
388 let opt = self.0.take().unwrap();
389 opt.count.fetch_add(1, Ordering::AcqRel);
390 opt.event.notify(1);
391 }
392}