// broker_tokio/io/util/buf_writer.rs

1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
3
4use pin_project_lite::pin_project;
5use std::fmt;
6use std::io::{self, Write};
7use std::mem::MaybeUninit;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11pin_project! {
12    /// Wraps a writer and buffers its output.
13    ///
14    /// It can be excessively inefficient to work directly with something that
15    /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and
16    /// writes it to an underlying writer in large, infrequent batches.
17    ///
18    /// `BufWriter` can improve the speed of programs that make *small* and
19    /// *repeated* write calls to the same file or network socket. It does not
20    /// help when writing very large amounts at once, or writing just one or a few
21    /// times. It also provides no advantage when writing to a destination that is
22    /// in memory, like a `Vec<u8>`.
23    ///
24    /// When the `BufWriter` is dropped, the contents of its buffer will be
25    /// discarded. Creating multiple instances of a `BufWriter` on the same
26    /// stream can cause data loss. If you need to write out the contents of its
27    /// buffer, you must manually call flush before the writer is dropped.
28    ///
29    /// [`AsyncWrite`]: AsyncWrite
30    /// [`flush`]: super::AsyncWriteExt::flush
31    ///
32    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
33    pub struct BufWriter<W> {
34        #[pin]
35        pub(super) inner: W,
36        pub(super) buf: Vec<u8>,
37        pub(super) written: usize,
38    }
39}
40
41impl<W: AsyncWrite> BufWriter<W> {
42    /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
43    /// but may change in the future.
44    pub fn new(inner: W) -> Self {
45        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
46    }
47
48    /// Creates a new `BufWriter` with the specified buffer capacity.
49    pub fn with_capacity(cap: usize, inner: W) -> Self {
50        Self {
51            inner,
52            buf: Vec::with_capacity(cap),
53            written: 0,
54        }
55    }
56
57    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
58        let mut me = self.project();
59
60        let len = me.buf.len();
61        let mut ret = Ok(());
62        while *me.written < len {
63            match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) {
64                Ok(0) => {
65                    ret = Err(io::Error::new(
66                        io::ErrorKind::WriteZero,
67                        "failed to write the buffered data",
68                    ));
69                    break;
70                }
71                Ok(n) => *me.written += n,
72                Err(e) => {
73                    ret = Err(e);
74                    break;
75                }
76            }
77        }
78        if *me.written > 0 {
79            me.buf.drain(..*me.written);
80        }
81        *me.written = 0;
82        Poll::Ready(ret)
83    }
84
85    /// Gets a reference to the underlying writer.
86    pub fn get_ref(&self) -> &W {
87        &self.inner
88    }
89
90    /// Gets a mutable reference to the underlying writer.
91    ///
92    /// It is inadvisable to directly write to the underlying writer.
93    pub fn get_mut(&mut self) -> &mut W {
94        &mut self.inner
95    }
96
97    /// Gets a pinned mutable reference to the underlying writer.
98    ///
99    /// It is inadvisable to directly write to the underlying writer.
100    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
101        self.project().inner
102    }
103
104    /// Consumes this `BufWriter`, returning the underlying writer.
105    ///
106    /// Note that any leftover data in the internal buffer is lost.
107    pub fn into_inner(self) -> W {
108        self.inner
109    }
110
111    /// Returns a reference to the internally buffered data.
112    pub fn buffer(&self) -> &[u8] {
113        &self.buf
114    }
115}
116
117impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
118    fn poll_write(
119        mut self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &[u8],
122    ) -> Poll<io::Result<usize>> {
123        if self.buf.len() + buf.len() > self.buf.capacity() {
124            ready!(self.as_mut().flush_buf(cx))?;
125        }
126
127        let me = self.project();
128        if buf.len() >= me.buf.capacity() {
129            me.inner.poll_write(cx, buf)
130        } else {
131            Poll::Ready(me.buf.write(buf))
132        }
133    }
134
135    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136        ready!(self.as_mut().flush_buf(cx))?;
137        self.get_pin_mut().poll_flush(cx)
138    }
139
140    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
141        ready!(self.as_mut().flush_buf(cx))?;
142        self.get_pin_mut().poll_shutdown(cx)
143    }
144}
145
146impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
147    fn poll_read(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        buf: &mut [u8],
151    ) -> Poll<io::Result<usize>> {
152        self.get_pin_mut().poll_read(cx, buf)
153    }
154
155    // we can't skip unconditionally because of the large buffer case in read.
156    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
157        self.get_ref().prepare_uninitialized_buffer(buf)
158    }
159}
160
161impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
162    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
163        self.get_pin_mut().poll_fill_buf(cx)
164    }
165
166    fn consume(self: Pin<&mut Self>, amt: usize) {
167        self.get_pin_mut().consume(amt)
168    }
169}
170
171impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        f.debug_struct("BufWriter")
174            .field("writer", &self.inner)
175            .field(
176                "buffer",
177                &format_args!("{}/{}", self.buf.len(), self.buf.capacity()),
178            )
179            .field("written", &self.written)
180            .finish()
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::BufWriter;

    // Compile-time check that `BufWriter<()>` is `Unpin` (presumably
    // `crate::is_unpin` only accepts `Unpin` types; confirm against its definition).
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>();
    }
}