broker_tokio/io/util/
buf_writer.rs1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
3
4use pin_project_lite::pin_project;
5use std::fmt;
6use std::io::{self, Write};
7use std::mem::MaybeUninit;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11pin_project! {
12 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
33 pub struct BufWriter<W> {
34 #[pin]
35 pub(super) inner: W,
36 pub(super) buf: Vec<u8>,
37 pub(super) written: usize,
38 }
39}
40
41impl<W: AsyncWrite> BufWriter<W> {
42 pub fn new(inner: W) -> Self {
45 Self::with_capacity(DEFAULT_BUF_SIZE, inner)
46 }
47
48 pub fn with_capacity(cap: usize, inner: W) -> Self {
50 Self {
51 inner,
52 buf: Vec::with_capacity(cap),
53 written: 0,
54 }
55 }
56
57 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
58 let mut me = self.project();
59
60 let len = me.buf.len();
61 let mut ret = Ok(());
62 while *me.written < len {
63 match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) {
64 Ok(0) => {
65 ret = Err(io::Error::new(
66 io::ErrorKind::WriteZero,
67 "failed to write the buffered data",
68 ));
69 break;
70 }
71 Ok(n) => *me.written += n,
72 Err(e) => {
73 ret = Err(e);
74 break;
75 }
76 }
77 }
78 if *me.written > 0 {
79 me.buf.drain(..*me.written);
80 }
81 *me.written = 0;
82 Poll::Ready(ret)
83 }
84
85 pub fn get_ref(&self) -> &W {
87 &self.inner
88 }
89
90 pub fn get_mut(&mut self) -> &mut W {
94 &mut self.inner
95 }
96
97 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
101 self.project().inner
102 }
103
104 pub fn into_inner(self) -> W {
108 self.inner
109 }
110
111 pub fn buffer(&self) -> &[u8] {
113 &self.buf
114 }
115}
116
117impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
118 fn poll_write(
119 mut self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &[u8],
122 ) -> Poll<io::Result<usize>> {
123 if self.buf.len() + buf.len() > self.buf.capacity() {
124 ready!(self.as_mut().flush_buf(cx))?;
125 }
126
127 let me = self.project();
128 if buf.len() >= me.buf.capacity() {
129 me.inner.poll_write(cx, buf)
130 } else {
131 Poll::Ready(me.buf.write(buf))
132 }
133 }
134
135 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136 ready!(self.as_mut().flush_buf(cx))?;
137 self.get_pin_mut().poll_flush(cx)
138 }
139
140 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
141 ready!(self.as_mut().flush_buf(cx))?;
142 self.get_pin_mut().poll_shutdown(cx)
143 }
144}
145
146impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
147 fn poll_read(
148 self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 buf: &mut [u8],
151 ) -> Poll<io::Result<usize>> {
152 self.get_pin_mut().poll_read(cx, buf)
153 }
154
155 unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
157 self.get_ref().prepare_uninitialized_buffer(buf)
158 }
159}
160
161impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
162 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
163 self.get_pin_mut().poll_fill_buf(cx)
164 }
165
166 fn consume(self: Pin<&mut Self>, amt: usize) {
167 self.get_pin_mut().consume(amt)
168 }
169}
170
171impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_struct("BufWriter")
174 .field("writer", &self.inner)
175 .field(
176 "buffer",
177 &format_args!("{}/{}", self.buf.len(), self.buf.capacity()),
178 )
179 .field("written", &self.written)
180 .finish()
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::BufWriter;

    /// Checks that `BufWriter<()>` is `Unpin`. The check relies on
    /// `crate::is_unpin`, presumably a no-op function with a `T: Unpin` bound,
    /// so it is enforced at compile time.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>();
    }
}