// ordered_stream/multi.rs

1use crate::*;
2use core::ops::DerefMut;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6fn poll_multiple_step<I, P, S>(
7    streams: I,
8    cx: &mut Context<'_>,
9    before: Option<&S::Ordering>,
10    mut retry: Option<&mut Option<S::Ordering>>,
11) -> Poll<PollResult<S::Ordering, S::Data>>
12where
13    I: IntoIterator<Item = Pin<P>>,
14    P: DerefMut<Target = Peekable<S>>,
15    S: OrderedStream,
16    S::Ordering: Clone,
17{
18    // The stream with the earliest item that is actually before the given point
19    let mut best: Option<Pin<P>> = None;
20    // true if we have a stream that has not terminated
21    let mut has_data = false;
22    let mut has_pending = false;
23    let mut skip_retry = false;
24    for mut stream in streams {
25        let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0));
26        let current_bound = match (before, best_before) {
27            (Some(given), Some(best)) if given <= best => Some(given),
28            (_, Some(best)) => Some(best),
29            (given, None) => given,
30        };
31        // improved is true if have improved the `before` bound from the initial value
32
33        match stream.as_mut().poll_peek_before(cx, current_bound) {
34            Poll::Pending => {
35                has_pending = true;
36                skip_retry = true;
37            }
38            Poll::Ready(PollResult::Terminated) => continue,
39            Poll::Ready(PollResult::NoneBefore) => {
40                has_data = true;
41            }
42            Poll::Ready(PollResult::Item { ordering, .. }) => {
43                has_data = true;
44                match current_bound {
45                    Some(max) if max < ordering => continue,
46                    _ => {}
47                }
48                match (&mut retry, before, has_pending) {
49                    (Some(retry), Some(initial_bound), true) if ordering < initial_bound => {
50                        // We have just improved the initial bound, so the streams that
51                        // previously returned Pending might be able to return NoneBefore in a
52                        // retry.  This is only useful if there are no later Pending returns, so
53                        // those will set skip_retry.
54                        **retry = Some(ordering.clone());
55                        skip_retry = false;
56                    }
57                    (Some(retry), None, true) => {
58                        **retry = Some(ordering.clone());
59                        skip_retry = false;
60                    }
61                    _ => {}
62                }
63                best = Some(stream);
64            }
65        }
66    }
67    if skip_retry {
68        retry.map(|r| *r = None);
69    }
70    match best {
71        _ if has_pending => Poll::Pending,
72        None if has_data => Poll::Ready(PollResult::NoneBefore),
73        None => Poll::Ready(PollResult::Terminated),
74        // This is guaranteed to return PollResult::Item
75        Some(mut stream) => stream.as_mut().poll_next_before(cx, before),
76    }
77}
78
/// Join a collection of [`OrderedStream`]s.
///
/// This is similar to repeatedly using [`join()`] on all the streams in the contained collection.
/// It is not optimized to avoid polling streams that are not ready, so it works best if the number
/// of streams is relatively small.
//
// Unlike `FuturesUnordered` or `SelectAll`, the ordering properties that this struct provides can
// easily require that all items in the collection be consulted before returning any item.  An
// example of such a situation is a series of streams that all generate timestamps (locally) for
// their items and only return `NoneBefore` for past timestamps.  If only one stream produces an
// item for each call to `JoinMultiple::poll_next_before`, that timestamp must be checked against
// every other stream, and no amount of preparatory work or hints will help this.
//
// On the other hand, if all streams provide a position hint that matches their next item, it is
// possible to build a priority queue to sort the streams and reduce the cost of a single poll from
// `n` to `log(n)`.  This does require maintaining a snapshot of the hints (so S::Ordering: Clone),
// and will significantly increase the worst-case workload, so it should be a distinct type.
#[derive(Debug, Default, Clone)]
pub struct JoinMultiple<C>(pub C);

// The collection is never pinned structurally: the `OrderedStream` impl pins each element on the
// fly with `Pin::new`, which is only possible because that impl requires `S: Unpin`.
impl<C> Unpin for JoinMultiple<C> {}
99
100impl<C, S> OrderedStream for JoinMultiple<C>
101where
102    for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
103    S: OrderedStream + Unpin,
104    S::Ordering: Clone,
105{
106    type Ordering = S::Ordering;
107    type Data = S::Data;
108    fn poll_next_before(
109        mut self: Pin<&mut Self>,
110        cx: &mut Context<'_>,
111        before: Option<&S::Ordering>,
112    ) -> Poll<PollResult<S::Ordering, S::Data>> {
113        let mut retry = None;
114        let rv = poll_multiple_step(
115            self.as_mut().get_mut().0.into_iter().map(Pin::new),
116            cx,
117            before,
118            Some(&mut retry),
119        );
120        if rv.is_pending() && retry.is_some() {
121            poll_multiple_step(
122                self.get_mut().0.into_iter().map(Pin::new),
123                cx,
124                retry.as_ref(),
125                None,
126            )
127        } else {
128            rv
129        }
130    }
131}
132
133impl<C, S> FusedOrderedStream for JoinMultiple<C>
134where
135    for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
136    for<'a> &'a C: IntoIterator<Item = &'a Peekable<S>>,
137    S: OrderedStream + Unpin,
138    S::Ordering: Clone,
139{
140    fn is_terminated(&self) -> bool {
141        self.0.into_iter().all(|peekable| peekable.is_terminated())
142    }
143}
144
pin_project_lite::pin_project! {
    /// Join a collection of pinned [`OrderedStream`]s.
    ///
    /// This is identical to [`JoinMultiple`], but accepts [`OrderedStream`]s that are not [`Unpin`]
    /// by requiring that the pinned collection (`Pin<&mut C>`) implement [`IntoIterator`], yielding
    /// pinned streams.
    ///
    /// Most `std` collections do not provide such an implementation.  To use them, wrap each
    /// stream with `Box::pin` to make it [`Unpin`] before inserting it in the collection, and then
    /// use [`JoinMultiple`] on the resulting collection.
    #[derive(Debug,Default,Clone)]
    pub struct JoinMultiplePin<C> {
        // Structurally pinned (`#[pin]`), so projecting a pinned `JoinMultiplePin<C>` yields a
        // `Pin<&mut C>`; this is what allows the contained streams to be `!Unpin`.
        #[pin]
        pub streams: C,
    }
}
160
161impl<C> JoinMultiplePin<C> {
162    pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut C> {
163        self.project().streams
164    }
165}
166
167impl<C, S> OrderedStream for JoinMultiplePin<C>
168where
169    for<'a> Pin<&'a mut C>: IntoIterator<Item = Pin<&'a mut Peekable<S>>>,
170    S: OrderedStream,
171    S::Ordering: Clone,
172{
173    type Ordering = S::Ordering;
174    type Data = S::Data;
175    fn poll_next_before(
176        mut self: Pin<&mut Self>,
177        cx: &mut Context<'_>,
178        before: Option<&S::Ordering>,
179    ) -> Poll<PollResult<S::Ordering, S::Data>> {
180        let mut retry = None;
181        let rv = poll_multiple_step(self.as_mut().as_pin_mut(), cx, before, Some(&mut retry));
182        if rv.is_pending() && retry.is_some() {
183            poll_multiple_step(self.as_pin_mut(), cx, retry.as_ref(), None)
184        } else {
185            rv
186        }
187    }
188}
189
#[cfg(test)]
mod test {
    extern crate alloc;

    use crate::{FromStream, JoinMultiple, OrderedStream, OrderedStreamExt, PollResult};
    use alloc::{boxed::Box, rc::Rc, vec, vec::Vec};
    use core::{cell::Cell, pin::Pin, task::Context, task::Poll};
    use futures_core::Stream;
    use futures_util::{pin_mut, stream::iter};

    /// Test payload; its `serial` is also used as its ordering.
    #[derive(Debug, PartialEq)]
    pub struct Message {
        serial: u32,
    }

    /// Two pre-sorted, interleaved streams are merged into one ascending sequence, after which
    /// the join reports end-of-stream.
    #[test]
    fn join_multiple() {
        futures_executor::block_on(async {
            pub struct RemoteLogSource {
                stream: Pin<Box<dyn Stream<Item = Message>>>,
            }

            let mut logs = [
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 1 },
                        Message { serial: 4 },
                        Message { serial: 5 },
                    ])),
                },
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 2 },
                        Message { serial: 3 },
                        Message { serial: 6 },
                    ])),
                },
            ];
            let streams: Vec<_> = logs
                .iter_mut()
                .map(|s| FromStream::with_ordering(&mut s.stream, |m| m.serial).peekable())
                .collect();
            let mut joined = JoinMultiple(streams);
            for expected in 1..=6u32 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
            // Every input is exhausted, so the join must now report termination.
            assert_eq!(joined.next().await, None);
        });
    }

    #[test]
    fn join_one_slow() {
        futures_executor::block_on(async {
            /// Stream driven by a shared state cell:
            /// - 0: always `Pending`;
            /// - 1: `NoneBefore` when polled with `before == Some(&1)`, otherwise `Pending`;
            /// - 2: yields serial 4 once and moves to state 3;
            /// - anything else: `Terminated`.
            pub struct DelayStream(Rc<Cell<u8>>);

            impl OrderedStream for DelayStream {
                type Ordering = u32;
                type Data = Message;
                fn poll_next_before(
                    self: Pin<&mut Self>,
                    _: &mut Context<'_>,
                    before: Option<&Self::Ordering>,
                ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
                    match self.0.get() {
                        0 => Poll::Pending,
                        1 if matches!(before, Some(&1)) => Poll::Ready(PollResult::NoneBefore),
                        1 => Poll::Pending,

                        2 => {
                            self.0.set(3);
                            Poll::Ready(PollResult::Item {
                                data: Message { serial: 4 },
                                ordering: 4,
                            })
                        }
                        _ => Poll::Ready(PollResult::Terminated),
                    }
                }
            }

            let stream1 = iter([
                Message { serial: 1 },
                Message { serial: 3 },
                Message { serial: 5 },
            ]);

            let stream1 = FromStream::with_ordering(stream1, |m| m.serial);
            let go = Rc::new(Cell::new(0));
            let stream2 = DelayStream(go.clone());

            let stream1: Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>> =
                Box::pin(stream1);
            let stream2: Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>> =
                Box::pin(stream2);
            let streams = vec![stream1.peekable(), stream2.peekable()];
            let join = JoinMultiple(streams);
            let waker = futures_util::task::noop_waker();
            let mut ctx = core::task::Context::from_waker(&waker);

            pin_mut!(join);

            // When the DelayStream has no information about what it contains, join returns Pending
            // (since there could be a serial-0 message output of DelayStream)
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Pending
            );

            go.set(1);
            // Now the DelayStream will return NoneBefore on serial 1
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 1 },
                    ordering: 1,
                })
            );
            // however, it does not (yet) do so for serial 3
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Pending
            );

            go.set(2);
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 3 },
                    ordering: 3,
                })
            );
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 4 },
                    ordering: 4,
                })
            );
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 5 },
                    ordering: 5,
                })
            );

            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Terminated)
            );
        });
    }
}