madsim_real_tokio/io/util/buf_writer.rs

1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3
4use pin_project_lite::pin_project;
5use std::fmt;
6use std::io::{self, IoSlice, SeekFrom, Write};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10pin_project! {
11    /// Wraps a writer and buffers its output.
12    ///
13    /// It can be excessively inefficient to work directly with something that
14    /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and
15    /// writes it to an underlying writer in large, infrequent batches.
16    ///
17    /// `BufWriter` can improve the speed of programs that make *small* and
18    /// *repeated* write calls to the same file or network socket. It does not
19    /// help when writing very large amounts at once, or writing just one or a few
20    /// times. It also provides no advantage when writing to a destination that is
21    /// in memory, like a `Vec<u8>`.
22    ///
23    /// When the `BufWriter` is dropped, the contents of its buffer will be
24    /// discarded. Creating multiple instances of a `BufWriter` on the same
25    /// stream can cause data loss. If you need to write out the contents of its
26    /// buffer, you must manually call flush before the writer is dropped.
27    ///
28    /// [`AsyncWrite`]: AsyncWrite
29    /// [`flush`]: super::AsyncWriteExt::flush
30    ///
31    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
32    pub struct BufWriter<W> {
33        #[pin]
34        pub(super) inner: W,
35        pub(super) buf: Vec<u8>,
36        pub(super) written: usize,
37        pub(super) seek_state: SeekState,
38    }
39}
40
41impl<W: AsyncWrite> BufWriter<W> {
42    /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
43    /// but may change in the future.
44    pub fn new(inner: W) -> Self {
45        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
46    }
47
48    /// Creates a new `BufWriter` with the specified buffer capacity.
49    pub fn with_capacity(cap: usize, inner: W) -> Self {
50        Self {
51            inner,
52            buf: Vec::with_capacity(cap),
53            written: 0,
54            seek_state: SeekState::Init,
55        }
56    }
57
58    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
59        let mut me = self.project();
60
61        let len = me.buf.len();
62        let mut ret = Ok(());
63        while *me.written < len {
64            match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) {
65                Ok(0) => {
66                    ret = Err(io::Error::new(
67                        io::ErrorKind::WriteZero,
68                        "failed to write the buffered data",
69                    ));
70                    break;
71                }
72                Ok(n) => *me.written += n,
73                Err(e) => {
74                    ret = Err(e);
75                    break;
76                }
77            }
78        }
79        if *me.written > 0 {
80            me.buf.drain(..*me.written);
81        }
82        *me.written = 0;
83        Poll::Ready(ret)
84    }
85
86    /// Gets a reference to the underlying writer.
87    pub fn get_ref(&self) -> &W {
88        &self.inner
89    }
90
91    /// Gets a mutable reference to the underlying writer.
92    ///
93    /// It is inadvisable to directly write to the underlying writer.
94    pub fn get_mut(&mut self) -> &mut W {
95        &mut self.inner
96    }
97
98    /// Gets a pinned mutable reference to the underlying writer.
99    ///
100    /// It is inadvisable to directly write to the underlying writer.
101    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
102        self.project().inner
103    }
104
105    /// Consumes this `BufWriter`, returning the underlying writer.
106    ///
107    /// Note that any leftover data in the internal buffer is lost.
108    pub fn into_inner(self) -> W {
109        self.inner
110    }
111
112    /// Returns a reference to the internally buffered data.
113    pub fn buffer(&self) -> &[u8] {
114        &self.buf
115    }
116}
117
118impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
119    fn poll_write(
120        mut self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122        buf: &[u8],
123    ) -> Poll<io::Result<usize>> {
124        if self.buf.len() + buf.len() > self.buf.capacity() {
125            ready!(self.as_mut().flush_buf(cx))?;
126        }
127
128        let me = self.project();
129        if buf.len() >= me.buf.capacity() {
130            me.inner.poll_write(cx, buf)
131        } else {
132            Poll::Ready(me.buf.write(buf))
133        }
134    }
135
136    fn poll_write_vectored(
137        mut self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        mut bufs: &[IoSlice<'_>],
140    ) -> Poll<io::Result<usize>> {
141        if self.inner.is_write_vectored() {
142            let total_len = bufs
143                .iter()
144                .fold(0usize, |acc, b| acc.saturating_add(b.len()));
145            if total_len > self.buf.capacity() - self.buf.len() {
146                ready!(self.as_mut().flush_buf(cx))?;
147            }
148            let me = self.as_mut().project();
149            if total_len >= me.buf.capacity() {
150                // It's more efficient to pass the slices directly to the
151                // underlying writer than to buffer them.
152                // The case when the total_len calculation saturates at
153                // usize::MAX is also handled here.
154                me.inner.poll_write_vectored(cx, bufs)
155            } else {
156                bufs.iter().for_each(|b| me.buf.extend_from_slice(b));
157                Poll::Ready(Ok(total_len))
158            }
159        } else {
160            // Remove empty buffers at the beginning of bufs.
161            while bufs.first().map(|buf| buf.len()) == Some(0) {
162                bufs = &bufs[1..];
163            }
164            if bufs.is_empty() {
165                return Poll::Ready(Ok(0));
166            }
167            // Flush if the first buffer doesn't fit.
168            let first_len = bufs[0].len();
169            if first_len > self.buf.capacity() - self.buf.len() {
170                ready!(self.as_mut().flush_buf(cx))?;
171                debug_assert!(self.buf.is_empty());
172            }
173            let me = self.as_mut().project();
174            if first_len >= me.buf.capacity() {
175                // The slice is at least as large as the buffering capacity,
176                // so it's better to write it directly, bypassing the buffer.
177                debug_assert!(me.buf.is_empty());
178                return me.inner.poll_write(cx, &bufs[0]);
179            } else {
180                me.buf.extend_from_slice(&bufs[0]);
181                bufs = &bufs[1..];
182            }
183            let mut total_written = first_len;
184            debug_assert!(total_written != 0);
185            // Append the buffers that fit in the internal buffer.
186            for buf in bufs {
187                if buf.len() > me.buf.capacity() - me.buf.len() {
188                    break;
189                } else {
190                    me.buf.extend_from_slice(buf);
191                    total_written += buf.len();
192                }
193            }
194            Poll::Ready(Ok(total_written))
195        }
196    }
197
198    fn is_write_vectored(&self) -> bool {
199        true
200    }
201
202    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
203        ready!(self.as_mut().flush_buf(cx))?;
204        self.get_pin_mut().poll_flush(cx)
205    }
206
207    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        ready!(self.as_mut().flush_buf(cx))?;
209        self.get_pin_mut().poll_shutdown(cx)
210    }
211}
212
213#[derive(Debug, Clone, Copy)]
214pub(super) enum SeekState {
215    /// `start_seek` has not been called.
216    Init,
217    /// `start_seek` has been called, but `poll_complete` has not yet been called.
218    Start(SeekFrom),
219    /// Waiting for completion of `poll_complete`.
220    Pending,
221}
222
223/// Seek to the offset, in bytes, in the underlying writer.
224///
225/// Seeking always writes out the internal buffer before seeking.
226impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> {
227    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
228        // We need to flush the internal buffer before seeking.
229        // It receives a `Context` and returns a `Poll`, so it cannot be called
230        // inside `start_seek`.
231        *self.project().seek_state = SeekState::Start(pos);
232        Ok(())
233    }
234
235    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
236        let pos = match self.seek_state {
237            SeekState::Init => {
238                return self.project().inner.poll_complete(cx);
239            }
240            SeekState::Start(pos) => Some(pos),
241            SeekState::Pending => None,
242        };
243
244        // Flush the internal buffer before seeking.
245        ready!(self.as_mut().flush_buf(cx))?;
246
247        let mut me = self.project();
248        if let Some(pos) = pos {
249            // Ensure previous seeks have finished before starting a new one
250            ready!(me.inner.as_mut().poll_complete(cx))?;
251            if let Err(e) = me.inner.as_mut().start_seek(pos) {
252                *me.seek_state = SeekState::Init;
253                return Poll::Ready(Err(e));
254            }
255        }
256        match me.inner.poll_complete(cx) {
257            Poll::Ready(res) => {
258                *me.seek_state = SeekState::Init;
259                Poll::Ready(res)
260            }
261            Poll::Pending => {
262                *me.seek_state = SeekState::Pending;
263                Poll::Pending
264            }
265        }
266    }
267}
268
269impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
270    fn poll_read(
271        self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273        buf: &mut ReadBuf<'_>,
274    ) -> Poll<io::Result<()>> {
275        self.get_pin_mut().poll_read(cx, buf)
276    }
277}
278
279impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
280    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
281        self.get_pin_mut().poll_fill_buf(cx)
282    }
283
284    fn consume(self: Pin<&mut Self>, amt: usize) {
285        self.get_pin_mut().consume(amt);
286    }
287}
288
289impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        f.debug_struct("BufWriter")
292            .field("writer", &self.inner)
293            .field(
294                "buffer",
295                &format_args!("{}/{}", self.buf.len(), self.buf.capacity()),
296            )
297            .field("written", &self.written)
298            .finish()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::BufWriter;

    /// With an `Unpin` writer (here `()`), `BufWriter` itself must be `Unpin`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>();
    }
}