1use crate::*;
2use core::ops::DerefMut;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6fn poll_multiple_step<I, P, S>(
7 streams: I,
8 cx: &mut Context<'_>,
9 before: Option<&S::Ordering>,
10 mut retry: Option<&mut Option<S::Ordering>>,
11) -> Poll<PollResult<S::Ordering, S::Data>>
12where
13 I: IntoIterator<Item = Pin<P>>,
14 P: DerefMut<Target = Peekable<S>>,
15 S: OrderedStream,
16 S::Ordering: Clone,
17{
18 let mut best: Option<Pin<P>> = None;
20 let mut has_data = false;
22 let mut has_pending = false;
23 let mut skip_retry = false;
24 for mut stream in streams {
25 let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0));
26 let current_bound = match (before, best_before) {
27 (Some(given), Some(best)) if given <= best => Some(given),
28 (_, Some(best)) => Some(best),
29 (given, None) => given,
30 };
31 match stream.as_mut().poll_peek_before(cx, current_bound) {
34 Poll::Pending => {
35 has_pending = true;
36 skip_retry = true;
37 }
38 Poll::Ready(PollResult::Terminated) => continue,
39 Poll::Ready(PollResult::NoneBefore) => {
40 has_data = true;
41 }
42 Poll::Ready(PollResult::Item { ordering, .. }) => {
43 has_data = true;
44 match current_bound {
45 Some(max) if max < ordering => continue,
46 _ => {}
47 }
48 match (&mut retry, before, has_pending) {
49 (Some(retry), Some(initial_bound), true) if ordering < initial_bound => {
50 **retry = Some(ordering.clone());
55 skip_retry = false;
56 }
57 (Some(retry), None, true) => {
58 **retry = Some(ordering.clone());
59 skip_retry = false;
60 }
61 _ => {}
62 }
63 best = Some(stream);
64 }
65 }
66 }
67 if skip_retry {
68 retry.map(|r| *r = None);
69 }
70 match best {
71 _ if has_pending => Poll::Pending,
72 None if has_data => Poll::Ready(PollResult::NoneBefore),
73 None => Poll::Ready(PollResult::Terminated),
74 Some(mut stream) => stream.as_mut().poll_next_before(cx, before),
76 }
77}
78
/// Joins a collection of ordered streams into one stream that yields the
/// earliest available item across all of them.
///
/// `C` is the collection itself. The `OrderedStream` implementation requires
/// that `&mut C` iterates over `&mut Peekable<S>` for an `Unpin` stream `S`,
/// for example a `Vec<Peekable<S>>`.
#[derive(Clone, Debug, Default)]
pub struct JoinMultiple<C>(pub C);

// `C` is never pinned structurally: its streams are reached through
// `Pin::get_mut` and re-pinned one by one (which requires `S: Unpin`), so the
// wrapper can be `Unpin` regardless of `C`.
impl<C> Unpin for JoinMultiple<C> {}
99
100impl<C, S> OrderedStream for JoinMultiple<C>
101where
102 for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
103 S: OrderedStream + Unpin,
104 S::Ordering: Clone,
105{
106 type Ordering = S::Ordering;
107 type Data = S::Data;
108 fn poll_next_before(
109 mut self: Pin<&mut Self>,
110 cx: &mut Context<'_>,
111 before: Option<&S::Ordering>,
112 ) -> Poll<PollResult<S::Ordering, S::Data>> {
113 let mut retry = None;
114 let rv = poll_multiple_step(
115 self.as_mut().get_mut().0.into_iter().map(Pin::new),
116 cx,
117 before,
118 Some(&mut retry),
119 );
120 if rv.is_pending() && retry.is_some() {
121 poll_multiple_step(
122 self.get_mut().0.into_iter().map(Pin::new),
123 cx,
124 retry.as_ref(),
125 None,
126 )
127 } else {
128 rv
129 }
130 }
131}
132
133impl<C, S> FusedOrderedStream for JoinMultiple<C>
134where
135 for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
136 for<'a> &'a C: IntoIterator<Item = &'a Peekable<S>>,
137 S: OrderedStream + Unpin,
138 S::Ordering: Clone,
139{
140 fn is_terminated(&self) -> bool {
141 self.0.into_iter().all(|peekable| peekable.is_terminated())
142 }
143}
144
145pin_project_lite::pin_project! {
146 #[derive(Debug,Default,Clone)]
155 pub struct JoinMultiplePin<C> {
156 #[pin]
157 pub streams: C,
158 }
159}
160
161impl<C> JoinMultiplePin<C> {
162 pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut C> {
163 self.project().streams
164 }
165}
166
167impl<C, S> OrderedStream for JoinMultiplePin<C>
168where
169 for<'a> Pin<&'a mut C>: IntoIterator<Item = Pin<&'a mut Peekable<S>>>,
170 S: OrderedStream,
171 S::Ordering: Clone,
172{
173 type Ordering = S::Ordering;
174 type Data = S::Data;
175 fn poll_next_before(
176 mut self: Pin<&mut Self>,
177 cx: &mut Context<'_>,
178 before: Option<&S::Ordering>,
179 ) -> Poll<PollResult<S::Ordering, S::Data>> {
180 let mut retry = None;
181 let rv = poll_multiple_step(self.as_mut().as_pin_mut(), cx, before, Some(&mut retry));
182 if rv.is_pending() && retry.is_some() {
183 poll_multiple_step(self.as_pin_mut(), cx, retry.as_ref(), None)
184 } else {
185 rv
186 }
187 }
188}
189
#[cfg(test)]
mod test {
    extern crate alloc;

    use crate::{FromStream, JoinMultiple, OrderedStream, OrderedStreamExt, PollResult};
    use alloc::{boxed::Box, rc::Rc, vec, vec::Vec};
    use core::{cell::Cell, pin::Pin, task::Context, task::Poll};
    use futures_core::Stream;
    use futures_util::{pin_mut, stream::iter};

    #[derive(Debug, PartialEq)]
    pub struct Message {
        serial: u32,
    }

    /// A boxed ordered stream of `Message`s ordered by serial number.
    type BoxedStream = Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>>;

    /// A stream scripted by a shared counter:
    /// - `0`: always `Pending`;
    /// - `1`: `NoneBefore` when polled with `before == Some(&1)`, else `Pending`;
    /// - `2`: yields serial 4 once and advances the counter to `3`;
    /// - anything else: `Terminated`.
    struct DelayStream(Rc<Cell<u8>>);

    impl OrderedStream for DelayStream {
        type Ordering = u32;
        type Data = Message;
        fn poll_next_before(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            before: Option<&Self::Ordering>,
        ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
            match self.0.get() {
                0 => Poll::Pending,
                1 if matches!(before, Some(&1)) => Poll::Ready(PollResult::NoneBefore),
                1 => Poll::Pending,
                2 => {
                    self.0.set(3);
                    Poll::Ready(PollResult::Item {
                        data: Message { serial: 4 },
                        ordering: 4,
                    })
                }
                _ => Poll::Ready(PollResult::Terminated),
            }
        }
    }

    /// Joins a `DelayStream` driven by `go` with an eager stream yielding
    /// serials 1, 3 and 5; `slow_first` selects which one is polled first.
    fn join_with_delay(
        go: &Rc<Cell<u8>>,
        slow_first: bool,
    ) -> impl OrderedStream<Ordering = u32, Data = Message> {
        let fast = FromStream::with_ordering(
            iter([
                Message { serial: 1 },
                Message { serial: 3 },
                Message { serial: 5 },
            ]),
            |m| m.serial,
        );
        let fast: BoxedStream = Box::pin(fast);
        let slow: BoxedStream = Box::pin(DelayStream(Rc::clone(go)));
        let streams = if slow_first {
            vec![slow.peekable(), fast.peekable()]
        } else {
            vec![fast.peekable(), slow.peekable()]
        };
        JoinMultiple(streams)
    }

    /// Expected poll result for an item with the given serial.
    fn item(serial: u32) -> Poll<PollResult<u32, Message>> {
        Poll::Ready(PollResult::Item {
            data: Message { serial },
            ordering: serial,
        })
    }

    /// Drives the scripted join through its whole life, checking each result.
    fn check_one_slow(slow_first: bool) {
        let go = Rc::new(Cell::new(0));
        let join = join_with_delay(&go, slow_first);
        pin_mut!(join);
        let waker = futures_util::task::noop_waker();
        let mut ctx = Context::from_waker(&waker);

        // The slow stream might still produce something before serial 1.
        assert_eq!(join.as_mut().poll_next_before(&mut ctx, None), Poll::Pending);

        go.set(1);
        assert_eq!(join.as_mut().poll_next_before(&mut ctx, None), item(1));
        // Before serial 3 the slow stream can only say `Pending`.
        assert_eq!(join.as_mut().poll_next_before(&mut ctx, None), Poll::Pending);

        go.set(2);
        assert_eq!(join.as_mut().poll_next_before(&mut ctx, None), item(3));
        assert_eq!(join.as_mut().poll_next_before(&mut ctx, None), item(4));
        assert_eq!(join.as_mut().poll_next_before(&mut ctx, None), item(5));
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Terminated)
        );
    }

    #[test]
    fn join_multiple() {
        futures_executor::block_on(async {
            struct RemoteLogSource {
                stream: Pin<Box<dyn Stream<Item = Message>>>,
            }

            let mut logs = [
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 1 },
                        Message { serial: 4 },
                        Message { serial: 5 },
                    ])),
                },
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 2 },
                        Message { serial: 3 },
                        Message { serial: 6 },
                    ])),
                },
            ];
            let streams: Vec<_> = logs
                .iter_mut()
                .map(|s| FromStream::with_ordering(&mut s.stream, |m| m.serial).peekable())
                .collect();
            let mut joined = JoinMultiple(streams);
            for expected in 1..=6 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
            assert!(joined.next().await.is_none());
        });
    }

    #[test]
    fn join_one_slow() {
        check_one_slow(false);
    }

    /// With the slow stream polled first, the first pass peeks it with no
    /// bound, so delivering serial 1 depends on the retry pass re-polling it
    /// with `before == Some(&1)`.
    #[test]
    fn join_one_slow_polled_first() {
        check_one_slow(true);
    }
}