fuel_core_services/
yield_stream.rs1use futures::{
4 ready,
5 stream::Fuse,
6 Stream,
7 StreamExt,
8};
9use std::{
10 pin::Pin,
11 task::{
12 Context,
13 Poll,
14 },
15};
16
pin_project_lite::pin_project! {
    /// Stream adapter that steps aside after every `batch_size` items taken from
    /// the wrapped stream.
    ///
    /// When a batch is full, the next item is stashed, the task's waker is
    /// notified, and `Poll::Pending` is returned; the stashed item is handed out
    /// on the following poll and starts a new batch. Build one with
    /// [`YieldStream::new`] or [`StreamYieldExt::yield_each`].
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled"]
    pub struct YieldStream<St: Stream> {
        // Wrapped stream, fused via `StreamExt::fuse` so polling it again after
        // it has finished is allowed.
        #[pin]
        stream: Fuse<St>,
        // Item pulled from `stream` while the batch was already full; returned
        // first on the next poll.
        item: Option<St::Item>,
        // Number of items returned since the last forced `Pending`.
        counter: usize,
        // How many items may be returned before a forced `Pending`.
        batch_size: usize,
    }
}
29
30impl<St: Stream> YieldStream<St> {
31 pub fn new(stream: St, batch_size: usize) -> Self {
33 assert!(batch_size > 0);
34
35 Self {
36 stream: stream.fuse(),
37 item: None,
38 counter: 0,
39 batch_size,
40 }
41 }
42}
43
44impl<St: Stream> Stream for YieldStream<St> {
45 type Item = St::Item;
46
47 fn poll_next(
48 mut self: Pin<&mut Self>,
49 cx: &mut Context<'_>,
50 ) -> Poll<Option<Self::Item>> {
51 let mut this = self.as_mut().project();
52
53 if let Some(item) = this.item.take() {
55 *this.counter = 1;
56 return Poll::Ready(Some(item));
57 }
58
59 match ready!(this.stream.as_mut().poll_next(cx)) {
60 Some(item) => {
63 if this.counter < this.batch_size {
64 *this.counter = this.counter.saturating_add(1);
65
66 Poll::Ready(Some(item))
67 } else {
68 *this.item = Some(item);
69
70 cx.waker().wake_by_ref();
71
72 Poll::Pending
73 }
74 }
75
76 None => Poll::Ready(None),
78 }
79 }
80
81 fn size_hint(&self) -> (usize, Option<usize>) {
82 let cached_len = usize::from(self.item.is_some());
83 let (lower, upper) = self.stream.size_hint();
84 let lower = lower.saturating_add(cached_len);
85 let upper = match upper {
86 Some(x) => x.checked_add(cached_len),
87 None => None,
88 };
89 (lower, upper)
90 }
91}
92
93pub trait StreamYieldExt: Stream {
95 fn yield_each(self, batch_size: usize) -> YieldStream<Self>
97 where
98 Self: Sized,
99 {
100 YieldStream::new(self, batch_size)
101 }
102}
103
104impl<St> StreamYieldExt for St where St: Stream {}
105
#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn yield_stream__works_with_10_elements_loop() {
        // Given: ten items pushed through batches of three.
        let source = futures::stream::iter(0..10);
        let mut sut = YieldStream::new(source, 3);

        // When: the stream is drained by hand.
        let mut received = Vec::new();
        while let Some(value) = sut.next().await {
            received.push(value);
        }

        // Then: forced yields neither drop nor reorder items.
        let expected: Vec<i32> = (0..10).collect();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn yield_stream__works_with_10_elements__collect() {
        // Given / When: built via the extension trait and drained with `collect`.
        let received: Vec<i32> = futures::stream::iter(0..10)
            .yield_each(3)
            .collect()
            .await;

        // Then
        let expected: Vec<i32> = (0..10).collect();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn yield_stream__passed_control_to_another_future() {
        // Always-ready competitor: with `biased;` it only wins `select!` when the
        // yield stream returns `Pending`.
        async fn second_future() -> i32 {
            -1
        }

        let mut sut = YieldStream::new(futures::stream::iter(0..10), 3);
        let mut received = Vec::new();

        loop {
            tokio::select! {
                // Branches are polled top to bottom, so the stream gets first pick.
                biased;

                next = sut.next() => {
                    match next {
                        Some(value) => received.push(value),
                        None => break,
                    }
                }

                marker = second_future() => {
                    received.push(marker);
                }
            }
        }

        // A `-1` marker follows every full batch of three items.
        assert_eq!(received, vec![0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9]);
    }
}