tokio_util/sync/
poll_semaphore.rs

1use futures_core::Stream;
2use std::fmt;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{ready, Context, Poll};
6use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
7
8use super::ReusableBoxFuture;
9
10/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
11///
12/// [`Semaphore`]: tokio::sync::Semaphore
13pub struct PollSemaphore {
14    semaphore: Arc<Semaphore>,
15    permit_fut: Option<(
16        u32, // The number of permits requested.
17        ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
18    )>,
19}
20
21impl PollSemaphore {
22    /// Create a new `PollSemaphore`.
23    pub fn new(semaphore: Arc<Semaphore>) -> Self {
24        Self {
25            semaphore,
26            permit_fut: None,
27        }
28    }
29
30    /// Closes the semaphore.
31    pub fn close(&self) {
32        self.semaphore.close();
33    }
34
35    /// Obtain a clone of the inner semaphore.
36    pub fn clone_inner(&self) -> Arc<Semaphore> {
37        self.semaphore.clone()
38    }
39
40    /// Get back the inner semaphore.
41    pub fn into_inner(self) -> Arc<Semaphore> {
42        self.semaphore
43    }
44
45    /// Poll to acquire a permit from the semaphore.
46    ///
47    /// This can return the following values:
48    ///
49    ///  - `Poll::Pending` if a permit is not currently available.
50    ///  - `Poll::Ready(Some(permit))` if a permit was acquired.
51    ///  - `Poll::Ready(None)` if the semaphore has been closed.
52    ///
53    /// When this method returns `Poll::Pending`, the current task is scheduled
54    /// to receive a wakeup when a permit becomes available, or when the
55    /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
56    /// the `Waker` from the `Context` passed to the most recent call is
57    /// scheduled to receive a wakeup.
58    pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
59        self.poll_acquire_many(cx, 1)
60    }
61
62    /// Poll to acquire many permits from the semaphore.
63    ///
64    /// This can return the following values:
65    ///
66    ///  - `Poll::Pending` if a permit is not currently available.
67    ///  - `Poll::Ready(Some(permit))` if a permit was acquired.
68    ///  - `Poll::Ready(None)` if the semaphore has been closed.
69    ///
70    /// When this method returns `Poll::Pending`, the current task is scheduled
71    /// to receive a wakeup when the permits become available, or when the
72    /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
73    /// the `Waker` from the `Context` passed to the most recent call is
74    /// scheduled to receive a wakeup.
75    pub fn poll_acquire_many(
76        &mut self,
77        cx: &mut Context<'_>,
78        permits: u32,
79    ) -> Poll<Option<OwnedSemaphorePermit>> {
80        let permit_future = match self.permit_fut.as_mut() {
81            Some((prev_permits, fut)) if *prev_permits == permits => fut,
82            Some((old_permits, fut_box)) => {
83                // We're requesting a different number of permits, so replace the future
84                // and record the new amount.
85                let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
86                fut_box.set(fut);
87                *old_permits = permits;
88                fut_box
89            }
90            None => {
91                // avoid allocations completely if we can grab a permit immediately
92                match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
93                    Ok(permit) => return Poll::Ready(Some(permit)),
94                    Err(TryAcquireError::Closed) => return Poll::Ready(None),
95                    Err(TryAcquireError::NoPermits) => {}
96                }
97
98                let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
99                &mut self
100                    .permit_fut
101                    .get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
102                    .1
103            }
104        };
105
106        let result = ready!(permit_future.poll(cx));
107
108        // Assume we'll request the same amount of permits in a subsequent call.
109        let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
110        permit_future.set(next_fut);
111
112        match result {
113            Ok(permit) => Poll::Ready(Some(permit)),
114            Err(_closed) => {
115                self.permit_fut = None;
116                Poll::Ready(None)
117            }
118        }
119    }
120
121    /// Returns the current number of available permits.
122    ///
123    /// This is equivalent to the [`Semaphore::available_permits`] method on the
124    /// `tokio::sync::Semaphore` type.
125    ///
126    /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
127    pub fn available_permits(&self) -> usize {
128        self.semaphore.available_permits()
129    }
130
131    /// Adds `n` new permits to the semaphore.
132    ///
133    /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function
134    /// will panic if the limit is exceeded.
135    ///
136    /// This is equivalent to the [`Semaphore::add_permits`] method on the
137    /// `tokio::sync::Semaphore` type.
138    ///
139    /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
140    pub fn add_permits(&self, n: usize) {
141        self.semaphore.add_permits(n);
142    }
143}
144
145impl Stream for PollSemaphore {
146    type Item = OwnedSemaphorePermit;
147
148    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
149        Pin::into_inner(self).poll_acquire(cx)
150    }
151}
152
153impl Clone for PollSemaphore {
154    fn clone(&self) -> PollSemaphore {
155        PollSemaphore::new(self.clone_inner())
156    }
157}
158
159impl fmt::Debug for PollSemaphore {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        f.debug_struct("PollSemaphore")
162            .field("semaphore", &self.semaphore)
163            .finish()
164    }
165}
166
167impl AsRef<Semaphore> for PollSemaphore {
168    fn as_ref(&self) -> &Semaphore {
169        &self.semaphore
170    }
171}