futures_bounded/
futures_map.rs

1use std::future::Future;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll, Waker};
5use std::time::Duration;
6use std::{future, mem};
7
8use futures_timer::Delay;
9use futures_util::future::BoxFuture;
10use futures_util::stream::FuturesUnordered;
11use futures_util::{FutureExt, StreamExt};
12
13use crate::{PushError, Timeout};
14
15/// Represents a map of [`Future`]s.
16///
17/// Each future must finish within the specified time and the map never outgrows its capacity.
18pub struct FuturesMap<ID, O> {
19    timeout: Duration,
20    capacity: usize,
21    inner: FuturesUnordered<TaggedFuture<ID, TimeoutFuture<BoxFuture<'static, O>>>>,
22    empty_waker: Option<Waker>,
23    full_waker: Option<Waker>,
24}
25
26impl<ID, O> FuturesMap<ID, O> {
27    pub fn new(timeout: Duration, capacity: usize) -> Self {
28        Self {
29            timeout,
30            capacity,
31            inner: Default::default(),
32            empty_waker: None,
33            full_waker: None,
34        }
35    }
36}
37
38impl<ID, O> FuturesMap<ID, O>
39where
40    ID: Clone + Hash + Eq + Send + Unpin + 'static,
41    O: 'static,
42{
43    /// Push a future into the map.
44    ///
45    /// This method inserts the given future with defined `future_id` to the set.
46    /// If the length of the map is equal to the capacity, this method returns [PushError::BeyondCapacity],
47    /// that contains the passed future. In that case, the future is not inserted to the map.
48    /// If a future with the given `future_id` already exists, then the old future will be replaced by a new one.
49    /// In that case, the returned error [PushError::Replaced] contains the old future.
50    pub fn try_push<F>(&mut self, future_id: ID, future: F) -> Result<(), PushError<BoxFuture<O>>>
51    where
52        F: Future<Output = O> + Send + 'static,
53    {
54        if self.inner.len() >= self.capacity {
55            return Err(PushError::BeyondCapacity(future.boxed()));
56        }
57
58        if let Some(waker) = self.empty_waker.take() {
59            waker.wake();
60        }
61
62        let old = self.remove(future_id.clone());
63        self.inner.push(TaggedFuture {
64            tag: future_id,
65            inner: TimeoutFuture {
66                inner: future.boxed(),
67                timeout: Delay::new(self.timeout),
68                cancelled: false,
69            },
70        });
71        match old {
72            None => Ok(()),
73            Some(old) => Err(PushError::Replaced(old)),
74        }
75    }
76
77    pub fn remove(&mut self, id: ID) -> Option<BoxFuture<'static, O>> {
78        let tagged = self.inner.iter_mut().find(|s| s.tag == id)?;
79
80        let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed());
81        tagged.inner.cancelled = true;
82
83        Some(inner)
84    }
85
86    pub fn contains(&self, id: ID) -> bool {
87        self.inner.iter().any(|f| f.tag == id && !f.inner.cancelled)
88    }
89
90    pub fn len(&self) -> usize {
91        self.inner.len()
92    }
93
94    pub fn is_empty(&self) -> bool {
95        self.inner.is_empty()
96    }
97
98    #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic.
99    pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
100        if self.inner.len() < self.capacity {
101            return Poll::Ready(());
102        }
103
104        self.full_waker = Some(cx.waker().clone());
105
106        Poll::Pending
107    }
108
109    pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
110        loop {
111            let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
112
113            match maybe_result {
114                None => {
115                    self.empty_waker = Some(cx.waker().clone());
116                    return Poll::Pending;
117                }
118                Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))),
119                Some((id, Err(TimeoutError::Timeout))) => {
120                    return Poll::Ready((id, Err(Timeout::new(self.timeout))))
121                }
122                Some((_, Err(TimeoutError::Cancelled))) => continue,
123            }
124        }
125    }
126}
127
128struct TimeoutFuture<F> {
129    inner: F,
130    timeout: Delay,
131
132    cancelled: bool,
133}
134
135impl<F> Future for TimeoutFuture<F>
136where
137    F: Future + Unpin,
138{
139    type Output = Result<F::Output, TimeoutError>;
140
141    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142        if self.cancelled {
143            return Poll::Ready(Err(TimeoutError::Cancelled));
144        }
145
146        if self.timeout.poll_unpin(cx).is_ready() {
147            return Poll::Ready(Err(TimeoutError::Timeout));
148        }
149
150        self.inner.poll_unpin(cx).map(Ok)
151    }
152}
153
/// Reasons why a `TimeoutFuture` resolves without the output of its wrapped future.
enum TimeoutError {
    /// The timer fired before the wrapped future produced an output.
    Timeout,
    /// The future was flagged as cancelled.
    Cancelled,
}
158
/// A future that carries a tag which is handed out alongside its output.
struct TaggedFuture<T, F> {
    /// Identifier returned together with the output of `inner`.
    tag: T,
    /// The wrapped future.
    inner: F,
}
163
164impl<T, F> Future for TaggedFuture<T, F>
165where
166    T: Clone + Unpin,
167    F: Future + Unpin,
168{
169    type Output = (T, F::Output);
170
171    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172        let output = futures_util::ready!(self.inner.poll_unpin(cx));
173
174        Poll::Ready((self.tag.clone(), output))
175    }
176}
177
#[cfg(test)]
mod tests {
    use futures::channel::oneshot;
    use futures_util::task::noop_waker_ref;
    use std::future::{pending, poll_fn, ready};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 1);

        assert!(futures.try_push("ID_1", ready(())).is_ok());
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(matches!(
            futures.try_push("ID_2", ready(())),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 5);

        assert!(futures.try_push("ID", ready(())).is_ok());
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(matches!(
            futures.try_push("ID", ready(())),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn futures_timeout() {
        let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

        let _ = futures.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await;

        assert!(result.is_err())
    }

    #[test]
    fn resources_of_removed_future_are_cleaned_up() {
        let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

        let _ = futures.try_push("ID", pending::<()>());
        futures.remove("ID");

        // Polling discards the cancelled entry and then finds the map empty.
        let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref()));
        assert!(poll.is_pending());

        assert_eq!(futures.len(), 0);
    }

    #[tokio::test]
    async fn replaced_pending_future_is_polled() {
        let mut streams = FuturesMap::new(Duration::from_millis(100), 3);

        let (_tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);

        let (new_tx1, new_rx1) = oneshot::channel();
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    // Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
    // We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_FUTURES: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_FUTURES, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_FUTURES);
    }

    #[test]
    fn contains() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 1);
        _ = futures.try_push("ID", pending::<()>());
        assert!(futures.contains("ID"));
        _ = futures.remove("ID");
        assert!(!futures.contains("ID"));
    }

    /// Test driver that pushes `num_futures` delay futures through a `FuturesMap`, one push
    /// per successful readiness check.
    struct Task {
        /// Delay of every pushed future.
        future: Duration,
        /// Number of futures that have to complete before the task resolves.
        num_futures: usize,
        /// Number of futures that have completed so far.
        num_processed: usize,
        inner: FuturesMap<u8, ()>,
    }

    impl Task {
        fn new(future: Duration, num_futures: u32, capacity: usize) -> Self {
            Self {
                future,
                num_futures: num_futures as usize,
                num_processed: 0,
                inner: FuturesMap::new(Duration::from_secs(60), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_futures {
                if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) {
                    if result.is_err() {
                        panic!("Future timed out although the timeout is greater than the future delay")
                    }

                    this.num_processed += 1;
                    continue;
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // We push the constant future's ID to prove that user can use the same ID
                    // if the future was finished
                    let maybe_future = this.inner.try_push(1u8, Delay::new(this.future));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}