futures_time/stream/
throttle.rs

1use pin_project_lite::pin_project;
2
3use futures_core::stream::Stream;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
pin_project! {
    /// Forwards the first item of a stream and drops any further items until the
    /// next tick of an interval stream.
    ///
    /// This `struct` is created by the [`throttle`] method on [`StreamExt`]. See its
    /// documentation for more.
    ///
    /// [`throttle`]: crate::stream::StreamExt::throttle
    /// [`StreamExt`]: crate::stream::StreamExt
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled or .awaited"]
    pub struct Throttle<S: Stream, I> {
        // The stream whose items are being throttled.
        #[pin]
        stream: S,
        // Each item this stream yields starts a new window, resetting the count of
        // items accepted in the current window.
        #[pin]
        interval: I,
        // Progress of the throttled stream; while streaming it also carries the
        // number of items accepted in the current window.
        state: State,
        // How many items are accepted per window before the rest are dropped.
        budget: usize,
    }
}
26
27impl<S: Stream, I> Throttle<S, I> {
28    pub(crate) fn new(stream: S, interval: I) -> Self {
29        Self {
30            state: State::Streaming(0),
31            stream,
32            interval,
33            budget: 1,
34        }
35    }
36}
37
/// Lifecycle of a `Throttle` stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// The underlying stream may still yield items; the value counts the items
    /// accepted since the interval last ticked.
    Streaming(usize),
    /// The underlying stream has finished; the next poll yields `Ready(None)`.
    StreamDone,
    /// The closing `Ready(None)` has been yielded; polling again panics.
    AllDone,
}
47
48impl<S: Stream, I: Stream> Stream for Throttle<S, I> {
49    type Item = S::Item;
50
51    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52        let mut this = self.project();
53
54        let mut slot = None;
55
56        match this.state {
57            // The underlying stream is yielding items.
58            State::Streaming(count) => {
59                // Poll the underlying stream until we get to `Poll::Pending`.
60                loop {
61                    match this.stream.as_mut().poll_next(cx) {
62                        Poll::Ready(Some(value)) => {
63                            if count < this.budget {
64                                slot = Some(value);
65                                *count += 1;
66                            }
67                        }
68                        Poll::Ready(None) => {
69                            *this.state = State::StreamDone;
70                            break;
71                        }
72                        Poll::Pending => break,
73                    }
74                }
75
76                // After the stream, always poll the interval timer.
77                let _ = this
78                    .interval
79                    .as_mut()
80                    .poll_next(cx)
81                    .map(move |_| match this.state {
82                        State::Streaming(count) => *count = 0, // reset the counter
83                        State::StreamDone => cx.waker().wake_by_ref(),
84                        State::AllDone => {}
85                    });
86                match slot {
87                    Some(item) => Poll::Ready(Some(item)),
88                    None => Poll::Pending,
89                }
90            }
91
92            // All streams have completed and all data has been yielded.
93            State::StreamDone => {
94                *this.state = State::AllDone;
95                Poll::Ready(None)
96            }
97
98            // The closing `Ready(None)` has been yielded.
99            State::AllDone => panic!("stream polled after completion"),
100        }
101    }
102}
103
104#[cfg(test)]
105mod test {
106    use crate::prelude::*;
107    use crate::time::Duration;
108    use futures_lite::prelude::*;
109
110    #[test]
111    fn smoke() {
112        async_io::block_on(async {
113            let interval = Duration::from_millis(100);
114            let throttle = Duration::from_millis(300);
115
116            let take = 4;
117            let expected = 2;
118
119            let mut counter = 0;
120            crate::stream::interval(interval)
121                .take(take)
122                .throttle(throttle)
123                .for_each(|_| counter += 1)
124                .await;
125
126            assert_eq!(counter, expected);
127        })
128    }
129}