ordered_stream/
join.rs

1use crate::*;
2use core::mem;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6pin_project_lite::pin_project! {
7    /// A stream for the [`join`](fn.join.html) function.
8    #[derive(Debug)]
9    pub struct Join<A, B>
10    where
11        A: OrderedStream,
12        B: OrderedStream<Data = A::Data, Ordering=A::Ordering>,
13    {
14        #[pin]
15        stream_a: A,
16        #[pin]
17        stream_b: B,
18        state: JoinState<A::Data, B::Data, A::Ordering>,
19    }
20}
21
22/// Join two streams while preserving the overall ordering of elements.
23///
24/// You can think of this as implementing the "merge" step of a merge sort on the two streams,
25/// producing a single stream that is sorted given two sorted streams.  If the streams return
26/// [`PollResult::NoneBefore`] as intended, then the joined stream will be able to produce items
27/// when only one of the sources has unblocked.
28pub fn join<A, B>(stream_a: A, stream_b: B) -> Join<A, B>
29where
30    A: OrderedStream,
31    B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
32{
33    Join {
34        stream_a,
35        stream_b,
36        state: JoinState::None,
37    }
38}
39
/// Internal state a [`Join`] keeps between polls.
///
/// `A` and `B` are the data types of the two sources and `T` is their shared ordering type.
#[derive(Debug)]
enum JoinState<A, B, T> {
    /// Nothing is buffered; both sources are polled on the next call.
    None,
    /// An item from `stream_a` (with its ordering) is buffered; `stream_b` is still polled.
    A(A, T),
    /// An item from `stream_b` (with its ordering) is buffered; `stream_a` is still polled.
    B(B, T),
    /// `stream_b` has terminated, so only `stream_a` is polled.
    OnlyPollA,
    /// `stream_a` has terminated, so only `stream_b` is polled.
    OnlyPollB,
    /// Both sources have terminated.
    Terminated,
}
49
50impl<A, B, T> JoinState<A, B, T> {
51    fn take_split(&mut self) -> (PollState<A, T>, PollState<B, T>) {
52        match mem::replace(self, JoinState::None) {
53            JoinState::None => (PollState::Pending, PollState::Pending),
54            JoinState::A(a, t) => (PollState::Item(a, t), PollState::Pending),
55            JoinState::B(b, t) => (PollState::Pending, PollState::Item(b, t)),
56            JoinState::OnlyPollA => (PollState::Pending, PollState::Terminated),
57            JoinState::OnlyPollB => (PollState::Terminated, PollState::Pending),
58            JoinState::Terminated => (PollState::Terminated, PollState::Terminated),
59        }
60    }
61}
62
/// A helper equivalent to `Poll<PollResult<T, I>>` but easier to match.
///
/// `I` is the item data type and `T` the ordering type.  A [`Join`] tracks one of these per
/// source during each `poll_next_before` call.
pub(crate) enum PollState<I, T> {
    /// The source produced (or had buffered) an item with this ordering.
    Item(I, T),
    /// No result yet for this side.
    Pending,
    /// The source reported no items before the ordering it was polled with.
    NoneBefore,
    /// The source has terminated.
    Terminated,
}
70
71impl<I, T: Ord> PollState<I, T> {
72    fn ordering(&self) -> Option<&T> {
73        match self {
74            Self::Item(_, t) => Some(t),
75            _ => None,
76        }
77    }
78
79    fn update(
80        &mut self,
81        before: Option<&T>,
82        other_token: Option<&T>,
83        retry: bool,
84        run: impl FnOnce(Option<&T>) -> Poll<PollResult<T, I>>,
85    ) -> bool {
86        match self {
87            // Do not re-poll if we have an item already or if we are terminated
88            Self::Item { .. } | Self::Terminated => return false,
89
90            // No need to re-poll if we already declared no items <= before
91            Self::NoneBefore if retry => return false,
92
93            _ => {}
94        }
95
96        // Run the poll with the earlier of the two tokens to avoid transitioning to Pending (which
97        // will stall the Join) when we could have transitioned to NoneBefore.
98        let ordering = match (before, other_token) {
99            (Some(u), Some(o)) => {
100                if *u > *o {
101                    // The other ordering is earlier - so a retry might let us upgrade a Pending to a
102                    // NoneBefore
103                    Some(o)
104                } else if retry {
105                    // A retry will not improve matters, so don't bother
106                    return false;
107                } else {
108                    Some(u)
109                }
110            }
111            (Some(t), None) | (None, Some(t)) => Some(t),
112            (None, None) => None,
113        };
114
115        *self = run(ordering).into();
116        matches!(self, Self::Item { .. })
117    }
118}
119
120impl<I, T> From<PollState<I, T>> for Poll<PollResult<T, I>> {
121    fn from(poll: PollState<I, T>) -> Self {
122        match poll {
123            PollState::Item(data, ordering) => Poll::Ready(PollResult::Item { data, ordering }),
124            PollState::Pending => Poll::Pending,
125            PollState::NoneBefore => Poll::Ready(PollResult::NoneBefore),
126            PollState::Terminated => Poll::Ready(PollResult::Terminated),
127        }
128    }
129}
130
131impl<I, T> From<Poll<PollResult<T, I>>> for PollState<I, T> {
132    fn from(poll: Poll<PollResult<T, I>>) -> Self {
133        match poll {
134            Poll::Ready(PollResult::Item { data, ordering }) => Self::Item(data, ordering),
135            Poll::Ready(PollResult::NoneBefore) => Self::NoneBefore,
136            Poll::Ready(PollResult::Terminated) => Self::Terminated,
137            Poll::Pending => Self::Pending,
138        }
139    }
140}
141
142impl<A, B> Join<A, B>
143where
144    A: OrderedStream,
145    B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
146{
147    /// Split into the source streams.
148    ///
149    /// This method returns the source streams along with any buffered item and its
150    /// ordering.
151    pub fn into_inner(self) -> (A, B, Option<(A::Data, A::Ordering)>) {
152        let item = match self.state {
153            JoinState::A(a, o) => Some((a, o)),
154            JoinState::B(b, o) => Some((b, o)),
155            _ => None,
156        };
157
158        (self.stream_a, self.stream_b, item)
159    }
160
161    /// Provide direct access to the underlying stream.
162    ///
163    /// This may be useful if the stream provides APIs beyond [OrderedStream].  Note that the join
164    /// itself may be buffering an item from this stream, so you should consult
165    /// [Self::peek_buffered] and, if needed, [Self::take_buffered] before polling it directly.
166    pub fn stream_a(self: Pin<&mut Self>) -> Pin<&mut A> {
167        self.project().stream_a
168    }
169
170    /// Provide direct access to the underlying stream.
171    ///
172    /// This may be useful if the stream provides APIs beyond [OrderedStream].  Note that the join
173    /// itself may be buffering an item from this stream, so you should consult
174    /// [Self::peek_buffered] and, if needed, [Self::take_buffered] before polling it directly.
175    pub fn stream_b(self: Pin<&mut Self>) -> Pin<&mut B> {
176        self.project().stream_b
177    }
178
179    /// Allow access to the buffered item, if any.
180    ///
181    /// At most one of the two sides will be `Some`.  The returned item is a candidate for being
182    /// the next item returned by the joined stream, but it could not be returned by the most
183    /// recent [`OrderedStream::poll_next_before`] call.
184    pub fn peek_buffered(
185        self: Pin<&mut Self>,
186    ) -> (
187        Option<(&mut A::Data, &A::Ordering)>,
188        Option<(&mut B::Data, &B::Ordering)>,
189    ) {
190        match self.project().state {
191            JoinState::A(a, o) => (Some((a, o)), None),
192            JoinState::B(b, o) => (None, Some((b, o))),
193            _ => (None, None),
194        }
195    }
196
197    /// Remove the buffered item, if one is present.
198    ///
199    /// This does not poll either underlying stream.  See [Self::peek_buffered] for details on why
200    /// buffering exists.
201    pub fn take_buffered(self: Pin<&mut Self>) -> Option<(A::Data, A::Ordering)> {
202        let state = self.project().state;
203        match mem::replace(state, JoinState::None) {
204            JoinState::A(a, o) => Some((a, o)),
205            JoinState::B(b, o) => Some((b, o)),
206            other => {
207                *state = other;
208                None
209            }
210        }
211    }
212}
213
214impl<A, B> OrderedStream for Join<A, B>
215where
216    A: OrderedStream,
217    B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
218{
219    type Data = A::Data;
220    type Ordering = A::Ordering;
221
222    fn poll_next_before(
223        self: Pin<&mut Self>,
224        cx: &mut Context<'_>,
225        before: Option<&Self::Ordering>,
226    ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
227        let mut this = self.project();
228        let (mut poll_a, mut poll_b) = this.state.take_split();
229
230        poll_a.update(before, poll_b.ordering(), false, |ordering| {
231            this.stream_a.as_mut().poll_next_before(cx, ordering)
232        });
233        if poll_b.update(before, poll_a.ordering(), false, |ordering| {
234            this.stream_b.as_mut().poll_next_before(cx, ordering)
235        }) {
236            // If B just got an item, it's possible that A already knows that it won't have any
237            // items before that item; we couldn't ask that question before.  Ask it now.
238            poll_a.update(before, poll_b.ordering(), true, |ordering| {
239                this.stream_a.as_mut().poll_next_before(cx, ordering)
240            });
241        }
242
243        match (poll_a, poll_b) {
244            // Both are ready - we can judge ordering directly (simplest case).  The first one is
245            // returned while the other one is buffered for the next poll.
246            (PollState::Item(a, ta), PollState::Item(b, tb)) => {
247                if ta <= tb {
248                    *this.state = JoinState::B(b, tb);
249                    Poll::Ready(PollResult::Item {
250                        data: a,
251                        ordering: ta,
252                    })
253                } else {
254                    *this.state = JoinState::A(a, ta);
255                    Poll::Ready(PollResult::Item {
256                        data: b,
257                        ordering: tb,
258                    })
259                }
260            }
261
262            // If both sides are terminated, so are we.
263            (PollState::Terminated, PollState::Terminated) => {
264                *this.state = JoinState::Terminated;
265                Poll::Ready(PollResult::Terminated)
266            }
267
268            // If one side is terminated, we can produce items directly from the other side.
269            (a, PollState::Terminated) => {
270                *this.state = JoinState::OnlyPollA;
271                a.into()
272            }
273            (PollState::Terminated, b) => {
274                *this.state = JoinState::OnlyPollB;
275                b.into()
276            }
277
278            // If one side is pending, we can't return Ready until that gets resolved.  Because we
279            // have already requested that our child streams wake us when it is possible to make
280            // any kind of progress, we meet the requirements to return Poll::Pending.
281            (PollState::Item(a, t), PollState::Pending) => {
282                *this.state = JoinState::A(a, t);
283                Poll::Pending
284            }
285            (PollState::Pending, PollState::Item(b, t)) => {
286                *this.state = JoinState::B(b, t);
287                Poll::Pending
288            }
289            (PollState::Pending, PollState::Pending) => Poll::Pending,
290            (PollState::Pending, PollState::NoneBefore) => Poll::Pending,
291            (PollState::NoneBefore, PollState::Pending) => Poll::Pending,
292
293            // If both sides report NoneBefore, so can we.
294            (PollState::NoneBefore, PollState::NoneBefore) => Poll::Ready(PollResult::NoneBefore),
295
296            (PollState::Item(data, ordering), PollState::NoneBefore) => {
297                // B was polled using either the Some value of (before) or using A's ordering.
298                //
299                // If before is set and is earlier than A's ordering, then B might later produce a
300                // value with (bt >= before && bt < at), so we can't return A's item yet and must
301                // buffer it.  However, we can return None because neither stream will produce
302                // items before the ordering passed in before.
303                //
304                // If before is either None or after A's ordering, B's NoneBefore return represents a
305                // promise to not produce an item before A's, so we can return A's item now.
306                match before {
307                    Some(before) if ordering > *before => {
308                        *this.state = JoinState::A(data, ordering);
309                        Poll::Ready(PollResult::NoneBefore)
310                    }
311                    _ => Poll::Ready(PollResult::Item { data, ordering }),
312                }
313            }
314
315            (PollState::NoneBefore, PollState::Item(data, ordering)) => {
316                // A was polled using either the Some value of (before) or using B's ordering.
317                //
318                // By a mirror of the above argument, this NoneBefore result gives us permission to
319                // produce either B's item or NoneBefore.
320                match before {
321                    Some(before) if ordering > *before => {
322                        *this.state = JoinState::B(data, ordering);
323                        Poll::Ready(PollResult::NoneBefore)
324                    }
325                    _ => Poll::Ready(PollResult::Item { data, ordering }),
326                }
327            }
328        }
329    }
330
331    fn position_hint(&self) -> Option<MaybeBorrowed<'_, Self::Ordering>> {
332        let (a, b) = match &self.state {
333            JoinState::None => (self.stream_a.position_hint(), self.stream_b.position_hint()),
334            JoinState::A(_, t) => (
335                Some(MaybeBorrowed::Borrowed(t)),
336                self.stream_b.position_hint(),
337            ),
338            JoinState::B(_, t) => (
339                self.stream_b.position_hint(),
340                Some(MaybeBorrowed::Borrowed(t)),
341            ),
342            JoinState::OnlyPollA => return self.stream_a.position_hint(),
343            JoinState::OnlyPollB => return self.stream_b.position_hint(),
344            JoinState::Terminated => return None,
345        };
346        // We can only provide a hint if we have a valid hint for both sides
347        match (a, b) {
348            (Some(a), Some(b)) if *a <= *b => Some(a),
349            (Some(_), Some(b)) => Some(b),
350            _ => None,
351        }
352    }
353
354    fn size_hint(&self) -> (usize, Option<usize>) {
355        let extra = match &self.state {
356            JoinState::None => 0,
357            JoinState::A { .. } => 1,
358            JoinState::B { .. } => 1,
359            JoinState::OnlyPollA => return self.stream_a.size_hint(),
360            JoinState::OnlyPollB => return self.stream_b.size_hint(),
361            JoinState::Terminated => return (0, Some(0)),
362        };
363        let (al, ah) = self.stream_a.size_hint();
364        let (bl, bh) = self.stream_b.size_hint();
365        let min = al.saturating_add(bl).saturating_add(extra);
366        let max = ah
367            .and_then(|a| bh.and_then(|b| a.checked_add(b)))
368            .and_then(|h| h.checked_add(extra));
369        (min, max)
370    }
371}
372
373impl<A, B> FusedOrderedStream for Join<A, B>
374where
375    A: OrderedStream,
376    B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
377{
378    fn is_terminated(&self) -> bool {
379        matches!(self.state, JoinState::Terminated)
380    }
381}
382
#[cfg(test)]
mod test {
    // Tests for `join`: a plain merge of two always-ready streams, and a step-by-step poll of a
    // join whose second source is advanced manually through a shared counter.
    extern crate alloc;
    use crate::join;
    use crate::FromStream;
    use crate::OrderedStream;
    use crate::OrderedStreamExt;
    use crate::PollResult;
    use alloc::rc::Rc;
    use core::cell::Cell;
    use core::pin::Pin;
    use core::task::{Context, Poll};
    use futures_executor::block_on;
    use futures_util::pin_mut;
    use futures_util::stream::iter;

    /// Test payload; `serial` is also used as the ordering key.
    #[derive(Debug, PartialEq)]
    pub struct Message {
        serial: u32,
    }

    /// Two individually sorted, interleaved streams must merge into serials 1..=6 in order.
    #[test]
    fn join_two() {
        block_on(async {
            let stream1 = iter([
                Message { serial: 1 },
                Message { serial: 4 },
                Message { serial: 5 },
            ]);

            let stream2 = iter([
                Message { serial: 2 },
                Message { serial: 3 },
                Message { serial: 6 },
            ]);
            let mut joined = join(
                FromStream::with_ordering(stream1, |m| m.serial),
                FromStream::with_ordering(stream2, |m| m.serial),
            );
            for i in 0..6 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, i as u32 + 1);
            }
        });
    }

    /// Polls a join by hand while its second source is blocked.  The join can yield from the
    /// first source only once the second source answers `NoneBefore` for that item's ordering.
    #[test]
    fn join_one_slow() {
        futures_executor::block_on(async {
            // Hand-driven source controlled by the shared counter.  It never registers a waker;
            // the test polls it directly with a no-op waker.
            //   0 => Pending
            //   1 => NoneBefore when polled with `before == Some(&1)`, otherwise Pending
            //   2 => yields serial 4 once and advances the counter to 3
            //   anything else => Terminated
            pub struct DelayStream(Rc<Cell<u8>>);

            impl OrderedStream for DelayStream {
                type Ordering = u32;
                type Data = Message;
                fn poll_next_before(
                    self: Pin<&mut Self>,
                    _: &mut Context<'_>,
                    before: Option<&Self::Ordering>,
                ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
                    match self.0.get() {
                        0 => Poll::Pending,
                        1 if matches!(before, Some(&1)) => Poll::Ready(PollResult::NoneBefore),
                        1 => Poll::Pending,

                        2 => {
                            self.0.set(3);
                            Poll::Ready(PollResult::Item {
                                data: Message { serial: 4 },
                                ordering: 4,
                            })
                        }
                        _ => Poll::Ready(PollResult::Terminated),
                    }
                }
            }

            let stream1 = iter([
                Message { serial: 1 },
                Message { serial: 3 },
                Message { serial: 5 },
            ]);

            let stream1 = FromStream::with_ordering(stream1, |m| m.serial);
            let go = Rc::new(Cell::new(0));
            let stream2 = DelayStream(go.clone());

            let join = join(stream1, stream2);
            let waker = futures_util::task::noop_waker();
            let mut ctx = core::task::Context::from_waker(&waker);

            pin_mut!(join);

            // When the DelayStream has no information about what it contains, join returns Pending
            // (since there could be a serial-0 message output of DelayStream)
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Pending
            );

            go.set(1);
            // Now the DelayStream will return NoneBefore on serial 1
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 1 },
                    ordering: 1,
                })
            );
            // however, it does not (yet) do so for serial 3
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Pending
            );

            // The slow source now yields serial 4, so the buffered serial 3 comes out first,
            // then 4, then 5 (after which the slow source reports Terminated).
            go.set(2);
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 3 },
                    ordering: 3,
                })
            );
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 4 },
                    ordering: 4,
                })
            );
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Item {
                    data: Message { serial: 5 },
                    ordering: 5,
                })
            );

            // Both sources are exhausted.
            assert_eq!(
                join.as_mut().poll_next_before(&mut ctx, None),
                Poll::Ready(PollResult::Terminated)
            );
        });
    }
}