// tokio_util/io/stream_reader.rs

1use bytes::Buf;
2use futures_core::stream::Stream;
3use futures_sink::Sink;
4use std::io;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
8
/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
///
/// This type performs the inverse operation of [`ReaderStream`].
///
/// This type also implements the [`AsyncBufRead`] trait, so you can use it
/// to read a `Stream` of byte chunks line-by-line. See the examples below.
///
/// # Example
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::{AsyncReadExt, Result};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
///     Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Read the next chunk.
/// assert_eq!(read.read(&mut buf).await?, 4);
/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// If the stream produces errors which are not [`std::io::Error`] (and are
/// not convertible into one), the errors can be converted using
/// [`StreamExt`] to map each element.
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::AsyncReadExt;
/// use tokio_util::io::StreamReader;
/// use tokio_stream::StreamExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator, including an error.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
///     Result::Err("Something bad happened!")
/// ]);
///
/// // Use StreamExt to map the stream and error to a std::io::Error
/// let stream = stream.map(|result| result.map_err(|err| {
///     std::io::Error::new(std::io::ErrorKind::Other, err)
/// }));
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Reading the next chunk will produce an error
/// let error = read.read(&mut buf).await.unwrap_err();
/// assert_eq!(error.kind(), std::io::ErrorKind::Other);
/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
/// line-by-line. Note that you will usually also need to convert the error
/// type when doing this. See the second example for an explanation of how
/// to do this.
///
/// ```
/// use tokio::io::{Result, AsyncBufReadExt};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream of byte chunks.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(b"The first line.\n".as_slice()),
///     Result::Ok(b"The second line.".as_slice()),
///     Result::Ok(b"\nThe third".as_slice()),
///     Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
/// ]);
///
/// // Convert it to an AsyncBufRead.
/// let mut read = StreamReader::new(stream);
///
/// // Loop through the lines from the `StreamReader`.
/// let mut line = String::new();
/// let mut lines = Vec::new();
/// loop {
///     line.clear();
///     let len = read.read_line(&mut line).await?;
///     if len == 0 { break; }
///     lines.push(line.clone());
/// }
///
/// // Verify that we got the lines we expected.
/// assert_eq!(
///     lines,
///     vec![
///         "The first line.\n",
///         "The second line.\n",
///         "The third line.\n",
///         "The fourth line.\n",
///         "The fifth line.\n",
///     ]
/// );
/// # Ok(())
/// # }
/// ```
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`AsyncBufRead`]: tokio::io::AsyncBufRead
/// [`Stream`]: futures_core::Stream
/// [`ReaderStream`]: crate::io::ReaderStream
/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
#[derive(Debug)]
pub struct StreamReader<S, B> {
    // This field is structurally pinned: it is only ever accessed through
    // `Pin<&mut S>` when `Self` is pinned (see the manual `project` method).
    inner: S,
    // This field is not pinned. It holds the most recently received chunk,
    // which may have been partially consumed.
    chunk: Option<B>,
}
162
163impl<S, B, E> StreamReader<S, B>
164where
165    S: Stream<Item = Result<B, E>>,
166    B: Buf,
167    E: Into<std::io::Error>,
168{
169    /// Convert a stream of byte chunks into an [`AsyncRead`].
170    ///
171    /// The item should be a [`Result`] with the ok variant being something that
172    /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
173    /// should be convertible into an [io error].
174    ///
175    /// [`Result`]: std::result::Result
176    /// [`Buf`]: bytes::Buf
177    /// [io error]: std::io::Error
178    pub fn new(stream: S) -> Self {
179        Self {
180            inner: stream,
181            chunk: None,
182        }
183    }
184
185    /// Do we have a chunk and is it non-empty?
186    fn has_chunk(&self) -> bool {
187        if let Some(ref chunk) = self.chunk {
188            chunk.remaining() > 0
189        } else {
190            false
191        }
192    }
193
194    /// Consumes this `StreamReader`, returning a Tuple consisting
195    /// of the underlying stream and an Option of the internal buffer,
196    /// which is Some in case the buffer contains elements.
197    pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
198        if self.has_chunk() {
199            (self.inner, self.chunk)
200        } else {
201            (self.inner, None)
202        }
203    }
204}
205
206impl<S, B> StreamReader<S, B> {
207    /// Gets a reference to the underlying stream.
208    ///
209    /// It is inadvisable to directly read from the underlying stream.
210    pub fn get_ref(&self) -> &S {
211        &self.inner
212    }
213
214    /// Gets a mutable reference to the underlying stream.
215    ///
216    /// It is inadvisable to directly read from the underlying stream.
217    pub fn get_mut(&mut self) -> &mut S {
218        &mut self.inner
219    }
220
221    /// Gets a pinned mutable reference to the underlying stream.
222    ///
223    /// It is inadvisable to directly read from the underlying stream.
224    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
225        self.project().inner
226    }
227
228    /// Consumes this `BufWriter`, returning the underlying stream.
229    ///
230    /// Note that any leftover data in the internal buffer is lost.
231    /// If you additionally want access to the internal buffer use
232    /// [`into_inner_with_chunk`].
233    ///
234    /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
235    pub fn into_inner(self) -> S {
236        self.inner
237    }
238}
239
240impl<S, B, E> AsyncRead for StreamReader<S, B>
241where
242    S: Stream<Item = Result<B, E>>,
243    B: Buf,
244    E: Into<std::io::Error>,
245{
246    fn poll_read(
247        mut self: Pin<&mut Self>,
248        cx: &mut Context<'_>,
249        buf: &mut ReadBuf<'_>,
250    ) -> Poll<io::Result<()>> {
251        if buf.remaining() == 0 {
252            return Poll::Ready(Ok(()));
253        }
254
255        let inner_buf = match self.as_mut().poll_fill_buf(cx) {
256            Poll::Ready(Ok(buf)) => buf,
257            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
258            Poll::Pending => return Poll::Pending,
259        };
260        let len = std::cmp::min(inner_buf.len(), buf.remaining());
261        buf.put_slice(&inner_buf[..len]);
262
263        self.consume(len);
264        Poll::Ready(Ok(()))
265    }
266}
267
268impl<S, B, E> AsyncBufRead for StreamReader<S, B>
269where
270    S: Stream<Item = Result<B, E>>,
271    B: Buf,
272    E: Into<std::io::Error>,
273{
274    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
275        loop {
276            if self.as_mut().has_chunk() {
277                // This unwrap is very sad, but it can't be avoided.
278                let buf = self.project().chunk.as_ref().unwrap().chunk();
279                return Poll::Ready(Ok(buf));
280            } else {
281                match self.as_mut().project().inner.poll_next(cx) {
282                    Poll::Ready(Some(Ok(chunk))) => {
283                        // Go around the loop in case the chunk is empty.
284                        *self.as_mut().project().chunk = Some(chunk);
285                    }
286                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
287                    Poll::Ready(None) => return Poll::Ready(Ok(&[])),
288                    Poll::Pending => return Poll::Pending,
289                }
290            }
291        }
292    }
293    fn consume(self: Pin<&mut Self>, amt: usize) {
294        if amt > 0 {
295            self.project()
296                .chunk
297                .as_mut()
298                .expect("No chunk present")
299                .advance(amt);
300        }
301    }
302}
303
304// The code below is a manual expansion of the code that pin-project-lite would
305// generate. This is done because pin-project-lite fails by hitting the recursion
306// limit on this struct. (Every line of documentation is handled recursively by
307// the macro.)
308
309impl<S: Unpin, B> Unpin for StreamReader<S, B> {}
310
/// Field-by-field view of a pinned `StreamReader`, as produced by `project`.
struct StreamReaderProject<'pin, S, B> {
    /// The wrapped stream, which stays pinned.
    inner: Pin<&'pin mut S>,
    /// The buffered chunk, which is not pinned and so is a plain mutable borrow.
    chunk: &'pin mut Option<B>,
}
315
316impl<S, B> StreamReader<S, B> {
317    #[inline]
318    fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
319        // SAFETY: We define that only `inner` should be pinned when `Self` is
320        // and have an appropriate `impl Unpin` for this.
321        let me = unsafe { Pin::into_inner_unchecked(self) };
322        StreamReaderProject {
323            inner: unsafe { Pin::new_unchecked(&mut me.inner) },
324            chunk: &mut me.chunk,
325        }
326    }
327}
328
329impl<S: Sink<T, Error = E>, B, E, T> Sink<T> for StreamReader<S, B> {
330    type Error = E;
331    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
332        self.project().inner.poll_ready(cx)
333    }
334
335    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
336        self.project().inner.start_send(item)
337    }
338
339    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
340        self.project().inner.poll_flush(cx)
341    }
342
343    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344        self.project().inner.poll_close(cx)
345    }
346}