fuel_core_services/
yield_stream.rs

//! A stream adapter that yields control after every `batch_size` items, allowing other tasks to work.
2
3use futures::{
4    ready,
5    stream::Fuse,
6    Stream,
7    StreamExt,
8};
9use std::{
10    pin::Pin,
11    task::{
12        Context,
13        Poll,
14    },
15};
16
17pin_project_lite::pin_project! {
18    /// Stream that yields each `batch_size` items.
19    #[derive(Debug)]
20    #[must_use = "streams do nothing unless polled"]
21    pub struct YieldStream<St: Stream> {
22        #[pin]
23        stream: Fuse<St>,
24        item: Option<St::Item>,
25        counter: usize,
26        batch_size: usize,
27    }
28}
29
30impl<St: Stream> YieldStream<St> {
31    /// Create a new `YieldStream` with the given `batch_size`.
32    pub fn new(stream: St, batch_size: usize) -> Self {
33        assert!(batch_size > 0);
34
35        Self {
36            stream: stream.fuse(),
37            item: None,
38            counter: 0,
39            batch_size,
40        }
41    }
42}
43
44impl<St: Stream> Stream for YieldStream<St> {
45    type Item = St::Item;
46
47    fn poll_next(
48        mut self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50    ) -> Poll<Option<Self::Item>> {
51        let mut this = self.as_mut().project();
52
53        // If we have a cached item, return it because that means we were woken up.
54        if let Some(item) = this.item.take() {
55            *this.counter = 1;
56            return Poll::Ready(Some(item));
57        }
58
59        match ready!(this.stream.as_mut().poll_next(cx)) {
60            // Return items, unless we reached the batch size.
61            // after that, we want to yield before returning the next item.
62            Some(item) => {
63                if this.counter < this.batch_size {
64                    *this.counter = this.counter.saturating_add(1);
65
66                    Poll::Ready(Some(item))
67                } else {
68                    *this.item = Some(item);
69
70                    cx.waker().wake_by_ref();
71
72                    Poll::Pending
73                }
74            }
75
76            // Underlying stream ran out of values, so finish this stream as well.
77            None => Poll::Ready(None),
78        }
79    }
80
81    fn size_hint(&self) -> (usize, Option<usize>) {
82        let cached_len = usize::from(self.item.is_some());
83        let (lower, upper) = self.stream.size_hint();
84        let lower = lower.saturating_add(cached_len);
85        let upper = match upper {
86            Some(x) => x.checked_add(cached_len),
87            None => None,
88        };
89        (lower, upper)
90    }
91}
92
93/// Extension trait for `Stream`.
94pub trait StreamYieldExt: Stream {
95    /// Yields each `batch_size` items allowing other tasks to work.
96    fn yield_each(self, batch_size: usize) -> YieldStream<Self>
97    where
98        Self: Sized,
99    {
100        YieldStream::new(self, batch_size)
101    }
102}
103
104impl<St> StreamYieldExt for St where St: Stream {}
105
#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Draining the stream item by item returns every element in order.
    #[tokio::test]
    async fn yield_stream__works_with_10_elements_loop() {
        let mut yielding = YieldStream::new(futures::stream::iter(0..10), 3);

        let mut received = Vec::new();
        while let Some(value) = yielding.next().await {
            received.push(value);
        }

        let expected: Vec<i32> = (0..10).collect();
        assert_eq!(received, expected);
    }

    /// Collecting through the extension method returns every element in order.
    #[tokio::test]
    async fn yield_stream__works_with_10_elements__collect() {
        let received: Vec<_> = futures::stream::iter(0..10)
            .yield_each(3)
            .collect()
            .await;

        let expected: Vec<i32> = (0..10).collect();
        assert_eq!(received, expected);
    }

    /// With a biased `select!`, the second branch only runs when the stream
    /// reports `Pending`, i.e. once after every batch of three items.
    #[tokio::test]
    async fn yield_stream__passed_control_to_another_future() {
        // Always-ready future whose output marks the points where the stream yielded.
        async fn always_ready() -> i32 {
            -1
        }

        let source = futures::stream::iter(0..10);
        let mut yielding = YieldStream::new(source, 3);

        let mut received = Vec::new();
        loop {
            tokio::select! {
                biased;

                next = yielding.next() => {
                    match next {
                        Some(value) => received.push(value),
                        None => break,
                    }
                }

                marker = always_ready() => {
                    received.push(marker);
                }
            }
        }

        assert_eq!(received, vec![0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9]);
    }
}