// fuel_core_services/yield_stream.rs

//! Stream adapter that hands control back to the executor after every
//! `batch_size` items, allowing other tasks to make progress.

use futures::{
    ready,
    stream::Fuse,
    Stream,
    StreamExt,
};
use std::{
    pin::Pin,
    task::{
        Context,
        Poll,
    },
};

pin_project_lite::pin_project! {
    /// Stream adapter that forwards the items of an inner stream and, after
    /// every `batch_size` items, returns `Poll::Pending` once (after waking
    /// itself) so that other tasks get a chance to run.
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled"]
    pub struct YieldStream<St: Stream> {
        // The wrapped stream, fused by the constructor.
        #[pin]
        stream: Fuse<St>,
        // Item pulled from the inner stream at the moment the batch limit was
        // hit; it is held here and returned on the next poll.
        item: Option<St::Item>,
        // Number of items emitted since the last forced yield.
        counter: usize,
        // Maximum number of items emitted before a forced yield.
        batch_size: usize,
    }
}

impl<St: Stream> YieldStream<St> {
    /// Wraps `stream` so that it yields to the executor after every
    /// `batch_size` items.
    ///
    /// The inner stream is fused before being stored, and the adapter starts
    /// with an empty item cache and a zeroed batch counter.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero.
    pub fn new(stream: St, batch_size: usize) -> Self {
        // A zero batch size is rejected up front, with a message that tells
        // the caller which argument was wrong.
        assert!(batch_size > 0, "`batch_size` must be greater than zero");

        Self {
            stream: stream.fuse(),
            item: None,
            counter: 0,
            batch_size,
        }
    }
}

impl<St: Stream> Stream for YieldStream<St> {
    type Item = St::Item;

    /// Forwards items from the inner stream, inserting one `Poll::Pending`
    /// (preceded by a self-wake) each time `batch_size` items have been
    /// emitted in the current batch.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut this = self.as_mut().project();

        // A stashed item means the previous poll forced a yield; hand the item
        // out now and count it as the first one of a fresh batch.
        if let Some(stashed) = this.item.take() {
            *this.counter = 1;
            return Poll::Ready(Some(stashed));
        }

        // Propagates `Pending` from the inner stream unchanged.
        let polled = ready!(this.stream.as_mut().poll_next(cx));

        match polled {
            // The inner stream is exhausted, so this stream ends as well.
            None => Poll::Ready(None),

            // Still inside the current batch: pass the item straight through.
            Some(item) if *this.counter < *this.batch_size => {
                *this.counter = this.counter.saturating_add(1);
                Poll::Ready(Some(item))
            }

            // Batch is full: keep the item for the next poll, request an
            // immediate re-poll, and yield to the executor.
            Some(item) => {
                *this.item = Some(item);
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    /// Returns the inner stream's size hint, widened by one when an item is
    /// currently stashed. The upper bound becomes `None` if adding the stashed
    /// item would overflow.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let stashed = usize::from(self.item.is_some());
        let (inner_lower, inner_upper) = self.stream.size_hint();

        (
            inner_lower.saturating_add(stashed),
            inner_upper.and_then(|bound| bound.checked_add(stashed)),
        )
    }
}

/// Extension trait that adds cooperative yielding to any [`Stream`].
pub trait StreamYieldExt: Stream {
    /// Wraps `self` in a [`YieldStream`] that hands control back to the
    /// executor after every `batch_size` items, allowing other tasks to work.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero, because `YieldStream::new` asserts
    /// that it is positive.
    fn yield_each(self, batch_size: usize) -> YieldStream<Self>
    where
        Self: Sized,
    {
        YieldStream::new(self, batch_size)
    }
}

// Blanket implementation: every `Stream` gets `yield_each`.
impl<St: Stream> StreamYieldExt for St {}

#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Draining the adapter item by item with `next()` must produce every
    /// element of the inner stream, in order, despite the forced yields.
    #[tokio::test]
    async fn yield_stream__works_with_10_elements_loop() {
        let stream = futures::stream::iter(0..10);
        let mut yield_stream = YieldStream::new(stream, 3);

        let mut items = Vec::new();
        while let Some(item) = yield_stream.next().await {
            items.push(item);
        }

        assert_eq!(items, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// Same check as above, but built through the `yield_each` extension
    /// method and drained with `collect`.
    #[tokio::test]
    async fn yield_stream__works_with_10_elements__collect() {
        let stream = futures::stream::iter(0..10);
        let yield_stream = stream.yield_each(3);

        let items = yield_stream.collect::<Vec<_>>().await;

        assert_eq!(items, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// Checks that the forced yield really lets another future run: with a
    /// batch size of 3, the competing future's `-1` must show up after every
    /// third stream item.
    #[tokio::test]
    async fn yield_stream__passed_control_to_another_future() {
        let stream = futures::stream::iter(0..10);
        let mut yield_stream = YieldStream::new(stream, 3);

        // Always-ready competitor; its result marks the points where the
        // stream branch did not complete.
        async fn second_future() -> i32 {
            -1
        }

        let mut items = Vec::new();
        loop {
            tokio::select! {
                // `biased` makes `select!` poll the branches in declaration
                // order, so the stream is always tried first and the second
                // branch only wins when the stream returns `Pending`.
                biased;

                item = yield_stream.next() => {
                    if let Some(item) = item {
                        items.push(item);
                    } else {
                        // Stream exhausted: stop before the competitor can
                        // append a trailing `-1`.
                        break;
                    }
                }

                item = second_future() => {
                    items.push(item);
                }
            }
        }

        assert_eq!(items, vec![0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9]);
    }
}