tokio_util/util/
poll_buf.rs

1use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2
3use bytes::{Buf, BufMut};
4use std::io::{self, IoSlice};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
8/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
9///
10/// [`BufMut`]: bytes::Buf
11///
12/// # Example
13///
14/// ```
15/// use bytes::{Bytes, BytesMut};
16/// use tokio_stream as stream;
17/// use tokio::io::Result;
18/// use tokio_util::io::{StreamReader, poll_read_buf};
19/// use std::future::poll_fn;
20/// use std::pin::Pin;
21/// # #[tokio::main]
22/// # async fn main() -> std::io::Result<()> {
23///
24/// // Create a reader from an iterator. This particular reader will always be
25/// // ready.
26/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
27///
28/// let mut buf = BytesMut::new();
29/// let mut reads = 0;
30///
31/// loop {
32///     reads += 1;
33///     let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
34///
35///     if n == 0 {
36///         break;
37///     }
38/// }
39///
40/// // one or more reads might be necessary.
41/// assert!(reads >= 1);
42/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
43/// # Ok(())
44/// # }
45/// ```
46#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
47pub fn poll_read_buf<T: AsyncRead + ?Sized, B: BufMut>(
48    io: Pin<&mut T>,
49    cx: &mut Context<'_>,
50    buf: &mut B,
51) -> Poll<io::Result<usize>> {
52    if !buf.has_remaining_mut() {
53        return Poll::Ready(Ok(0));
54    }
55
56    let n = {
57        let dst = buf.chunk_mut();
58
59        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
60        // transparent wrapper around `[MaybeUninit<u8>]`.
61        let dst = unsafe { dst.as_uninit_slice_mut() };
62        let mut buf = ReadBuf::uninit(dst);
63        let ptr = buf.filled().as_ptr();
64        ready!(io.poll_read(cx, &mut buf)?);
65
66        // Ensure the pointer does not change from under us
67        assert_eq!(ptr, buf.filled().as_ptr());
68        buf.filled().len()
69    };
70
71    // Safety: This is guaranteed to be the number of initialized (and read)
72    // bytes due to the invariants provided by `ReadBuf::filled`.
73    unsafe {
74        buf.advance_mut(n);
75    }
76
77    Poll::Ready(Ok(n))
78}
79
80/// Try to write data from an implementer of the [`Buf`] trait to an
81/// [`AsyncWrite`], advancing the buffer's internal cursor.
82///
83/// This function will use [vectored writes] when the [`AsyncWrite`] supports
84/// vectored writes.
85///
86/// # Examples
87///
88/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
89/// [`Buf`]:
90///
91/// ```no_run
92/// use tokio_util::io::poll_write_buf;
93/// use tokio::io;
94/// use tokio::fs::File;
95///
96/// use bytes::Buf;
97/// use std::future::poll_fn;
98/// use std::io::Cursor;
99/// use std::pin::Pin;
100///
101/// #[tokio::main]
102/// async fn main() -> io::Result<()> {
103///     let mut file = File::create("foo.txt").await?;
104///     let mut buf = Cursor::new(b"data to write");
105///
106///     // Loop until the entire contents of the buffer are written to
107///     // the file.
108///     while buf.has_remaining() {
109///         poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
110///     }
111///
112///     Ok(())
113/// }
114/// ```
115///
116/// [`Buf`]: bytes::Buf
117/// [`AsyncWrite`]: tokio::io::AsyncWrite
118/// [`File`]: tokio::fs::File
119/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
120#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
121pub fn poll_write_buf<T: AsyncWrite + ?Sized, B: Buf>(
122    io: Pin<&mut T>,
123    cx: &mut Context<'_>,
124    buf: &mut B,
125) -> Poll<io::Result<usize>> {
126    const MAX_BUFS: usize = 64;
127
128    if !buf.has_remaining() {
129        return Poll::Ready(Ok(0));
130    }
131
132    let n = if io.is_write_vectored() {
133        let mut slices = [IoSlice::new(&[]); MAX_BUFS];
134        let cnt = buf.chunks_vectored(&mut slices);
135        ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
136    } else {
137        ready!(io.poll_write(cx, buf.chunk()))?
138    };
139
140    buf.advance(n);
141
142    Poll::Ready(Ok(n))
143}