// tokio_util/util/poll_buf.rs

1use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2
3use bytes::{Buf, BufMut};
4use std::io::{self, IoSlice};
5use std::mem::MaybeUninit;
6use std::pin::Pin;
7use std::task::{ready, Context, Poll};
8
9/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
10///
11/// [`BufMut`]: bytes::Buf
12///
13/// # Example
14///
15/// ```
16/// use bytes::{Bytes, BytesMut};
17/// use tokio_stream as stream;
18/// use tokio::io::Result;
19/// use tokio_util::io::{StreamReader, poll_read_buf};
20/// use std::future::poll_fn;
21/// use std::pin::Pin;
22/// # #[tokio::main]
23/// # async fn main() -> std::io::Result<()> {
24///
25/// // Create a reader from an iterator. This particular reader will always be
26/// // ready.
27/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
28///
29/// let mut buf = BytesMut::new();
30/// let mut reads = 0;
31///
32/// loop {
33///     reads += 1;
34///     let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
35///
36///     if n == 0 {
37///         break;
38///     }
39/// }
40///
41/// // one or more reads might be necessary.
42/// assert!(reads >= 1);
43/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
44/// # Ok(())
45/// # }
46/// ```
47#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
48pub fn poll_read_buf<T: AsyncRead + ?Sized, B: BufMut>(
49    io: Pin<&mut T>,
50    cx: &mut Context<'_>,
51    buf: &mut B,
52) -> Poll<io::Result<usize>> {
53    if !buf.has_remaining_mut() {
54        return Poll::Ready(Ok(0));
55    }
56
57    let n = {
58        let dst = buf.chunk_mut();
59
60        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
61        // transparent wrapper around `[MaybeUninit<u8>]`.
62        let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
63        let mut buf = ReadBuf::uninit(dst);
64        let ptr = buf.filled().as_ptr();
65        ready!(io.poll_read(cx, &mut buf)?);
66
67        // Ensure the pointer does not change from under us
68        assert_eq!(ptr, buf.filled().as_ptr());
69        buf.filled().len()
70    };
71
72    // Safety: This is guaranteed to be the number of initialized (and read)
73    // bytes due to the invariants provided by `ReadBuf::filled`.
74    unsafe {
75        buf.advance_mut(n);
76    }
77
78    Poll::Ready(Ok(n))
79}
80
81/// Try to write data from an implementer of the [`Buf`] trait to an
82/// [`AsyncWrite`], advancing the buffer's internal cursor.
83///
84/// This function will use [vectored writes] when the [`AsyncWrite`] supports
85/// vectored writes.
86///
87/// # Examples
88///
89/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
90/// [`Buf`]:
91///
92/// ```no_run
93/// use tokio_util::io::poll_write_buf;
94/// use tokio::io;
95/// use tokio::fs::File;
96///
97/// use bytes::Buf;
98/// use std::future::poll_fn;
99/// use std::io::Cursor;
100/// use std::pin::Pin;
101///
102/// #[tokio::main]
103/// async fn main() -> io::Result<()> {
104///     let mut file = File::create("foo.txt").await?;
105///     let mut buf = Cursor::new(b"data to write");
106///
107///     // Loop until the entire contents of the buffer are written to
108///     // the file.
109///     while buf.has_remaining() {
110///         poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
111///     }
112///
113///     Ok(())
114/// }
115/// ```
116///
117/// [`Buf`]: bytes::Buf
118/// [`AsyncWrite`]: tokio::io::AsyncWrite
119/// [`File`]: tokio::fs::File
120/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
121#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
122pub fn poll_write_buf<T: AsyncWrite + ?Sized, B: Buf>(
123    io: Pin<&mut T>,
124    cx: &mut Context<'_>,
125    buf: &mut B,
126) -> Poll<io::Result<usize>> {
127    const MAX_BUFS: usize = 64;
128
129    if !buf.has_remaining() {
130        return Poll::Ready(Ok(0));
131    }
132
133    let n = if io.is_write_vectored() {
134        let mut slices = [IoSlice::new(&[]); MAX_BUFS];
135        let cnt = buf.chunks_vectored(&mut slices);
136        ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
137    } else {
138        ready!(io.poll_write(cx, buf.chunk()))?
139    };
140
141    buf.advance(n);
142
143    Poll::Ready(Ok(n))
144}