1use std::mem;
2use std::pin::Pin;
3use std::task::{Context, Poll, Waker};
4use std::time::Duration;
5
6use futures_timer::Delay;
7use futures_util::stream::{BoxStream, SelectAll};
8use futures_util::{stream, FutureExt, Stream, StreamExt};
9
10use crate::{PushError, Timeout};
11
12pub struct StreamMap<ID, O> {
16 timeout: Duration,
17 capacity: usize,
18 inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
19 empty_waker: Option<Waker>,
20 full_waker: Option<Waker>,
21}
22
23impl<ID, O> StreamMap<ID, O>
24where
25 ID: Clone + Unpin,
26{
27 pub fn new(timeout: Duration, capacity: usize) -> Self {
28 Self {
29 timeout,
30 capacity,
31 inner: Default::default(),
32 empty_waker: None,
33 full_waker: None,
34 }
35 }
36}
37
38impl<ID, O> StreamMap<ID, O>
39where
40 ID: Clone + PartialEq + Send + Unpin + 'static,
41 O: Send + 'static,
42{
43 pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
45 where
46 F: Stream<Item = O> + Send + 'static,
47 {
48 if self.inner.len() >= self.capacity {
49 return Err(PushError::BeyondCapacity(stream.boxed()));
50 }
51
52 if let Some(waker) = self.empty_waker.take() {
53 waker.wake();
54 }
55
56 let old = self.remove(id.clone());
57 self.inner.push(TaggedStream::new(
58 id,
59 TimeoutStream {
60 inner: stream.boxed(),
61 timeout: Delay::new(self.timeout),
62 },
63 ));
64
65 match old {
66 None => Ok(()),
67 Some(old) => Err(PushError::Replaced(old)),
68 }
69 }
70
71 pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
72 let tagged = self.inner.iter_mut().find(|s| s.key == id)?;
73
74 let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
75 tagged.exhausted = true; Some(inner)
78 }
79
80 pub fn len(&self) -> usize {
81 self.inner.len()
82 }
83
84 pub fn is_empty(&self) -> bool {
85 self.inner.is_empty()
86 }
87
88 #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
90 if self.inner.len() < self.capacity {
91 return Poll::Ready(());
92 }
93
94 self.full_waker = Some(cx.waker().clone());
95
96 Poll::Pending
97 }
98
99 pub fn poll_next_unpin(
100 &mut self,
101 cx: &mut Context<'_>,
102 ) -> Poll<(ID, Option<Result<O, Timeout>>)> {
103 match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
104 None => {
105 self.empty_waker = Some(cx.waker().clone());
106 Poll::Pending
107 }
108 Some((id, Some(Ok(output)))) => Poll::Ready((id, Some(Ok(output)))),
109 Some((id, Some(Err(())))) => {
110 self.remove(id.clone()); Poll::Ready((id, Some(Err(Timeout::new(self.timeout)))))
113 }
114 Some((id, None)) => Poll::Ready((id, None)),
115 }
116 }
117}
118
119struct TimeoutStream<S> {
120 inner: S,
121 timeout: Delay,
122}
123
124impl<F> Stream for TimeoutStream<F>
125where
126 F: Stream + Unpin,
127{
128 type Item = Result<F::Item, ()>;
129
130 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
131 if self.timeout.poll_unpin(cx).is_ready() {
132 return Poll::Ready(Some(Err(())));
133 }
134
135 self.inner.poll_next_unpin(cx).map(|a| a.map(Ok))
136 }
137}
138
/// A stream whose every output is tagged with `key`.
struct TaggedStream<K, S> {
    /// Identifier attached to each output.
    key: K,
    /// The wrapped stream.
    inner: S,

    /// Set once `inner` has ended or the entry was removed; every later poll yields `None`.
    exhausted: bool,
}
145
146impl<K, S> TaggedStream<K, S> {
147 fn new(key: K, inner: S) -> Self {
148 Self {
149 key,
150 inner,
151 exhausted: false,
152 }
153 }
154}
155
156impl<K, S> Stream for TaggedStream<K, S>
157where
158 K: Clone + Unpin,
159 S: Stream + Unpin,
160{
161 type Item = (K, Option<S::Item>);
162
163 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
164 if self.exhausted {
165 return Poll::Ready(None);
166 }
167
168 match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
169 Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))),
170 None => {
171 self.exhausted = true;
172
173 Poll::Ready(Some((self.key.clone(), None)))
174 }
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use futures::channel::mpsc;
    use futures_util::stream::{once, pending};
    use futures_util::SinkExt;
    use std::future::{poll_fn, ready, Future};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut streams = StreamMap::new(Duration::from_secs(10), 1);

        assert!(streams.try_push("ID_1", once(ready(()))).is_ok());
        // `matches!` only evaluates to a bool; it must be wrapped in `assert!` to check anything.
        assert!(matches!(
            streams.try_push("ID_2", once(ready(()))),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut streams = StreamMap::new(Duration::from_secs(10), 5);

        assert!(streams.try_push("ID", once(ready(()))).is_ok());
        assert!(matches!(
            streams.try_push("ID", once(ready(()))),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn streams_timeout() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 1);

        let _ = streams.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert!(result.unwrap().is_err())
    }

    #[tokio::test]
    async fn timed_out_stream_gets_removed() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 1);

        let _ = streams.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));
        assert!(poll.is_pending())
    }

    #[test]
    fn removing_stream() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 1);

        let _ = streams.try_push("ID", stream::once(ready(())));

        {
            let cancelled_stream = streams.remove("ID");
            assert!(cancelled_stream.is_some());
        }

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));

        assert!(poll.is_pending());
        assert_eq!(
            streams.len(),
            0,
            "resources of cancelled streams are cleaned up properly"
        );
    }

    #[tokio::test]
    async fn replaced_stream_is_still_registered() {
        let mut streams = StreamMap::new(Duration::from_millis(100), 3);

        let (mut tx1, rx1) = mpsc::channel(5);
        let (mut tx2, rx2) = mpsc::channel(5);

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2).await;
        let _ = tx1.send(1).await;
        let _ = tx2.send(3).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 1);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 3);

        let (mut new_tx1, new_rx1) = mpsc::channel(5);
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_STREAMS: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_STREAMS, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_STREAMS);
    }

    struct Task {
        item_delay: Duration,
        num_streams: usize,
        num_processed: usize,
        inner: StreamMap<u8, ()>,
    }

    impl Task {
        fn new(item_delay: Duration, num_streams: u32, capacity: usize) -> Self {
            Self {
                item_delay,
                num_streams: num_streams as usize,
                num_processed: 0,
                inner: StreamMap::new(Duration::from_secs(60), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_streams {
                match this.inner.poll_next_unpin(cx) {
                    Poll::Ready((_, Some(result))) => {
                        // An `Err` means the map's timeout elapsed before the item was produced.
                        assert!(result.is_ok(), "Timeout is shorter than item delay");

                        this.num_processed += 1;
                        continue;
                    }
                    Poll::Ready((_, None)) => {
                        continue;
                    }
                    _ => {}
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    let maybe_future = this.inner.try_push(1u8, once(Delay::new(this.item_delay)));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}