futures_bounded/
stream_map.rs

1use std::mem;
2use std::pin::Pin;
3use std::task::{Context, Poll, Waker};
4use std::time::Duration;
5
6use futures_timer::Delay;
7use futures_util::stream::{BoxStream, SelectAll};
8use futures_util::{stream, FutureExt, Stream, StreamExt};
9
10use crate::{PushError, Timeout};
11
/// Represents a map of [`Stream`]s.
///
/// Each stream must finish within the specified time and the map never outgrows its capacity.
pub struct StreamMap<ID, O> {
    /// Time budget handed to the [`Delay`] created for every pushed stream; also reported in the
    /// [`Timeout`] error.
    timeout: Duration,
    /// Upper bound that `inner.len()` is compared against before accepting a new stream.
    capacity: usize,
    /// All registered streams, each tagged with its ID and wrapped with a timeout.
    inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
    /// Waker stored by `poll_next_unpin` when `inner` had nothing left to yield; woken by
    /// `try_push` when a stream is accepted.
    empty_waker: Option<Waker>,
    /// Waker stored by `poll_ready_unpin` when the map is at capacity.
    full_waker: Option<Waker>,
}
22
23impl<ID, O> StreamMap<ID, O>
24where
25    ID: Clone + Unpin,
26{
27    pub fn new(timeout: Duration, capacity: usize) -> Self {
28        Self {
29            timeout,
30            capacity,
31            inner: Default::default(),
32            empty_waker: None,
33            full_waker: None,
34        }
35    }
36}
37
38impl<ID, O> StreamMap<ID, O>
39where
40    ID: Clone + PartialEq + Send + Unpin + 'static,
41    O: Send + 'static,
42{
43    /// Push a stream into the map.
44    pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
45    where
46        F: Stream<Item = O> + Send + 'static,
47    {
48        if self.inner.len() >= self.capacity {
49            return Err(PushError::BeyondCapacity(stream.boxed()));
50        }
51
52        if let Some(waker) = self.empty_waker.take() {
53            waker.wake();
54        }
55
56        let old = self.remove(id.clone());
57        self.inner.push(TaggedStream::new(
58            id,
59            TimeoutStream {
60                inner: stream.boxed(),
61                timeout: Delay::new(self.timeout),
62            },
63        ));
64
65        match old {
66            None => Ok(()),
67            Some(old) => Err(PushError::Replaced(old)),
68        }
69    }
70
71    pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
72        let tagged = self.inner.iter_mut().find(|s| s.key == id)?;
73
74        let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
75        tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.
76
77        Some(inner)
78    }
79
80    pub fn len(&self) -> usize {
81        self.inner.len()
82    }
83
84    pub fn is_empty(&self) -> bool {
85        self.inner.is_empty()
86    }
87
88    #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic.
89    pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
90        if self.inner.len() < self.capacity {
91            return Poll::Ready(());
92        }
93
94        self.full_waker = Some(cx.waker().clone());
95
96        Poll::Pending
97    }
98
99    pub fn poll_next_unpin(
100        &mut self,
101        cx: &mut Context<'_>,
102    ) -> Poll<(ID, Option<Result<O, Timeout>>)> {
103        match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
104            None => {
105                self.empty_waker = Some(cx.waker().clone());
106                Poll::Pending
107            }
108            Some((id, Some(Ok(output)))) => Poll::Ready((id, Some(Ok(output)))),
109            Some((id, Some(Err(())))) => {
110                self.remove(id.clone()); // Remove stream, otherwise we keep reporting the timeout.
111
112                Poll::Ready((id, Some(Err(Timeout::new(self.timeout)))))
113            }
114            Some((id, None)) => Poll::Ready((id, None)),
115        }
116    }
117}
118
/// A stream paired with a single [`Delay`].
///
/// Once the delay has elapsed, polling yields `Err(())` instead of forwarding to `inner`.
struct TimeoutStream<S> {
    /// The wrapped stream.
    inner: S,
    /// Timer checked before every poll of `inner`; nothing in this file resets it.
    timeout: Delay,
}
123
124impl<F> Stream for TimeoutStream<F>
125where
126    F: Stream + Unpin,
127{
128    type Item = Result<F::Item, ()>;
129
130    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
131        if self.timeout.poll_unpin(cx).is_ready() {
132            return Poll::Ready(Some(Err(())));
133        }
134
135        self.inner.poll_next_unpin(cx).map(|a| a.map(Ok))
136    }
137}
138
/// A stream whose events are tagged with a key.
struct TaggedStream<K, S> {
    /// Key cloned into every event yielded by this stream.
    key: K,
    /// The wrapped stream.
    inner: S,

    /// Once set, polling returns `Poll::Ready(None)` without touching `inner`.
    exhausted: bool,
}
145
146impl<K, S> TaggedStream<K, S> {
147    fn new(key: K, inner: S) -> Self {
148        Self {
149            key,
150            inner,
151            exhausted: false,
152        }
153    }
154}
155
156impl<K, S> Stream for TaggedStream<K, S>
157where
158    K: Clone + Unpin,
159    S: Stream + Unpin,
160{
161    type Item = (K, Option<S::Item>);
162
163    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
164        if self.exhausted {
165            return Poll::Ready(None);
166        }
167
168        match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
169            Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))),
170            None => {
171                self.exhausted = true;
172
173                Poll::Ready(Some((self.key.clone(), None)))
174            }
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use futures::channel::mpsc;
    use futures_util::stream::{once, pending};
    use futures_util::SinkExt;
    use std::future::{poll_fn, ready, Future};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut streams = StreamMap::new(Duration::from_secs(10), 1);

        assert!(streams.try_push("ID_1", once(ready(()))).is_ok());
        // `matches!` only evaluates to a `bool`; it has to be asserted to check anything.
        assert!(matches!(
            streams.try_push("ID_2", once(ready(()))),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut streams = StreamMap::new(Duration::from_secs(10), 5);

        assert!(streams.try_push("ID", once(ready(()))).is_ok());
        // `matches!` only evaluates to a `bool`; it has to be asserted to check anything.
        assert!(matches!(
            streams.try_push("ID", once(ready(()))),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn streams_timeout() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 1);

        let _ = streams.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert!(result.unwrap().is_err())
    }

    #[tokio::test]
    async fn timed_out_stream_gets_removed() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 1);

        let _ = streams.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));
        assert!(poll.is_pending())
    }

    #[test]
    fn removing_stream() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 1);

        let _ = streams.try_push("ID", stream::once(ready(())));

        {
            let cancelled_stream = streams.remove("ID");
            assert!(cancelled_stream.is_some());
        }

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));

        assert!(poll.is_pending());
        assert_eq!(
            streams.len(),
            0,
            "resources of cancelled streams are cleaned up properly"
        );
    }

    #[tokio::test]
    async fn replaced_stream_is_still_registered() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 3);

        let (mut tx1, rx1) = mpsc::channel(5);
        let (mut tx2, rx2) = mpsc::channel(5);

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2).await;
        let _ = tx1.send(1).await;
        let _ = tx2.send(3).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 1);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 3);

        let (mut new_tx1, new_rx1) = mpsc::channel(5);
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    // Each stream emits 1 item with delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
    // We stop after NUM_STREAMS tasks, meaning the overall execution must at least take DELAY * NUM_STREAMS.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_STREAMS: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_STREAMS, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_STREAMS);
    }

    /// Test driver: keeps pushing single-item streams into a [`StreamMap`] until `num_streams`
    /// items have been processed.
    struct Task {
        item_delay: Duration,
        num_streams: usize,
        num_processed: usize,
        inner: StreamMap<u8, ()>,
    }

    impl Task {
        fn new(item_delay: Duration, num_streams: u32, capacity: usize) -> Self {
            Self {
                item_delay,
                num_streams: num_streams as usize,
                num_processed: 0,
                inner: StreamMap::new(Duration::from_secs(60), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_streams {
                match this.inner.poll_next_unpin(cx) {
                    Poll::Ready((_, Some(result))) => {
                        if result.is_err() {
                            panic!("Timeout is greater than item delay")
                        }

                        this.num_processed += 1;
                        continue;
                    }
                    Poll::Ready((_, None)) => {
                        continue;
                    }
                    Poll::Pending => {}
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // We push the constant ID to prove that user can use the same ID if the stream was finished
                    let maybe_future = this.inner.try_push(1u8, once(Delay::new(this.item_delay)));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}