sqlx_core/net/socket/
buffered.rs

1use crate::error::Error;
2use crate::net::Socket;
3use bytes::BytesMut;
4use std::ops::ControlFlow;
5use std::{cmp, io};
6
7use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
8
/// Default capacity for both the read and write buffers.
///
// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
const DEFAULT_BUF_SIZE: usize = 8192;
11
/// A [`Socket`] wrapper that adds read and write buffering, so many small
/// protocol-level reads and writes are amortized over fewer socket calls.
pub struct BufferedSocket<S> {
    socket: S,
    // Encoded data waiting to be written to (or partially written to) `socket`.
    write_buf: WriteBuffer,
    // Data read from `socket` that has not yet been consumed by a decoder.
    read_buf: ReadBuffer,
}
17
/// Write-side buffer with two cursors into `buf`:
/// `buf[bytes_flushed..bytes_written]` is data still waiting to be flushed.
///
/// Invariant (checked by `sanity_check()`):
/// `bytes_flushed <= bytes_written <= buf.len()`, and `buf` is never
/// deallocated (`buf.capacity() != 0`).
pub struct WriteBuffer {
    buf: Vec<u8>,
    // End of valid, pending data in `buf`; bytes past this are scratch space.
    bytes_written: usize,
    // How much of the pending data (from the start of `buf`) has already
    // been written to the socket.
    bytes_flushed: usize,
}
23
/// Read-side buffer, kept as two `BytesMut` handles that (normally) share one
/// allocation: `read` holds received data awaiting decode, `available` is the
/// spare space the socket reads into (see `ReadBuffer::advance`).
pub struct ReadBuffer {
    // Bytes received from the socket, not yet consumed by a decoder.
    read: BytesMut,
    // Spare buffer space for the next socket read; filled bytes are moved
    // onto the end of `read` via `advance()`.
    available: BytesMut,
}
28
impl<S: Socket> BufferedSocket<S> {
    /// Wrap `socket`, allocating read and write buffers of [`DEFAULT_BUF_SIZE`].
    pub fn new(socket: S) -> Self
    where
        S: Sized,
    {
        BufferedSocket {
            socket,
            write_buf: WriteBuffer {
                buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
                bytes_written: 0,
                bytes_flushed: 0,
            },
            read_buf: ReadBuffer {
                read: BytesMut::new(),
                // `read` starts empty; capacity lives in `available` and is
                // shuttled between the two as data arrives.
                available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
            },
        }
    }

    /// Read and return exactly `len` bytes, reading from the socket as needed.
    ///
    /// Implemented on top of [`Self::try_read`], so it is cancel-safe:
    /// the buffer is only split once `len` bytes are available.
    pub async fn read_buffered(&mut self, len: usize) -> Result<BytesMut, Error> {
        self.try_read(|buf| {
            Ok(if buf.len() < len {
                // Not enough data yet; ask `try_read` to fill to `len` total.
                ControlFlow::Continue(len)
            } else {
                ControlFlow::Break(buf.split_to(len))
            })
        })
        .await
    }

    /// Retryable read operation.
    ///
    /// The callback should check the contents of the buffer passed to it and either:
    ///
    /// * Remove a full message from the buffer and return [`ControlFlow::Break`], or:
    /// * Return [`ControlFlow::Continue`] with the expected _total_ length of the buffer,
    ///   _without_ modifying it.
    ///
    /// Cancel-safe as long as the callback does not modify the passed `BytesMut`
    /// before returning [`ControlFlow::Continue`].
    pub async fn try_read<F, R>(&mut self, mut try_read: F) -> Result<R, Error>
    where
        F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
    {
        loop {
            // Give the callback a chance to consume a complete message first;
            // only touch the socket if it reports more data is needed.
            let read_len = match try_read(&mut self.read_buf.read)? {
                ControlFlow::Continue(read_len) => read_len,
                ControlFlow::Break(ret) => return Ok(ret),
            };

            self.read_buf.read(read_len, &mut self.socket).await?;
        }
    }

    /// Shared access to the write buffer (e.g. for inspection).
    pub fn write_buffer(&self) -> &WriteBuffer {
        &self.write_buf
    }

    /// Mutable access to the write buffer, for encoding data directly into it.
    pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer {
        &mut self.write_buf
    }

    /// Read `byte_len` bytes and decode them as `T` (with a `()` context).
    pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result<T, Error>
    where
        T: ProtocolDecode<'de, ()>,
    {
        self.read_with(byte_len, ()).await
    }

    /// Read `byte_len` bytes and decode them as `T` using `context`.
    pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result<T, Error>
    where
        T: ProtocolDecode<'de, C>,
    {
        T::decode_with(self.read_buffered(byte_len).await?.freeze(), context)
    }

    /// Encode `value` into the write buffer (with a `()` context).
    ///
    /// No I/O is performed; call [`Self::flush`] to actually send the data.
    #[inline(always)]
    pub fn write<'en, T>(&mut self, value: T) -> Result<(), Error>
    where
        T: ProtocolEncode<'en, ()>,
    {
        self.write_with(value, ())
    }

    /// Encode `value` into the write buffer using `context`.
    ///
    /// No I/O is performed; call [`Self::flush`] to actually send the data.
    #[inline(always)]
    pub fn write_with<'en, T, C>(&mut self, value: T, context: C) -> Result<(), Error>
    where
        T: ProtocolEncode<'en, C>,
    {
        value.encode_with(self.write_buf.buf_mut(), context)?;
        // Everything the encoder appended is now pending data.
        self.write_buf.bytes_written = self.write_buf.buf.len();
        self.write_buf.sanity_check();

        Ok(())
    }

    /// Write all pending data to the socket, then flush the socket itself.
    pub async fn flush(&mut self) -> io::Result<()> {
        while !self.write_buf.is_empty() {
            // `write` may accept fewer bytes than requested; loop until done.
            let written = self.socket.write(self.write_buf.get()).await?;
            self.write_buf.consume(written);
            self.write_buf.sanity_check();
        }

        self.socket.flush().await?;

        Ok(())
    }

    /// Flush any pending data, then shut down the underlying socket.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        self.flush().await?;
        self.socket.shutdown().await
    }

    /// Release excess buffer memory (e.g. after an unusually large message).
    pub fn shrink_buffers(&mut self) {
        // Won't drop data still in the buffer.
        self.write_buf.shrink();
        self.read_buf.shrink();
    }

    /// Unwrap the underlying socket, discarding both buffers.
    ///
    /// NOTE(review): any unflushed or unconsumed buffered data is lost —
    /// callers presumably flush first; confirm at call sites.
    pub fn into_inner(self) -> S {
        self.socket
    }

    /// Type-erase the socket behind `Box<dyn Socket>`, keeping the buffers.
    pub fn boxed(self) -> BufferedSocket<Box<dyn Socket>> {
        BufferedSocket {
            socket: Box::new(self.socket),
            write_buf: self.write_buf,
            read_buf: self.read_buf,
        }
    }
}
160
161impl WriteBuffer {
162    fn sanity_check(&self) {
163        assert_ne!(self.buf.capacity(), 0);
164        assert!(self.bytes_written <= self.buf.len());
165        assert!(self.bytes_flushed <= self.bytes_written);
166    }
167
168    pub fn buf_mut(&mut self) -> &mut Vec<u8> {
169        self.buf.truncate(self.bytes_written);
170        self.sanity_check();
171        &mut self.buf
172    }
173
174    pub fn init_remaining_mut(&mut self) -> &mut [u8] {
175        self.buf.resize(self.buf.capacity(), 0);
176        self.sanity_check();
177        &mut self.buf[self.bytes_written..]
178    }
179
180    pub fn put_slice(&mut self, slice: &[u8]) {
181        // If we already have an initialized area that can fit the slice,
182        // don't change `self.buf.len()`
183        if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) {
184            dest.copy_from_slice(slice);
185        } else {
186            self.buf.truncate(self.bytes_written);
187            self.buf.extend_from_slice(slice);
188        }
189        self.advance(slice.len());
190        self.sanity_check();
191    }
192
193    pub fn advance(&mut self, amt: usize) {
194        let new_bytes_written = self
195            .bytes_written
196            .checked_add(amt)
197            .expect("self.bytes_written + amt overflowed");
198
199        assert!(new_bytes_written <= self.buf.len());
200
201        self.bytes_written = new_bytes_written;
202
203        self.sanity_check();
204    }
205
206    /// Read into the buffer from `source`, returning the number of bytes read.
207    ///
208    /// The buffer is automatically advanced by the number of bytes read.
209    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> io::Result<usize> {
210        let read = match () {
211            // Tokio lets us read into the buffer without zeroing first
212            #[cfg(feature = "_rt-tokio")]
213            _ => source.read_buf(self.buf_mut()).await?,
214            #[cfg(not(feature = "_rt-tokio"))]
215            _ => source.read(self.init_remaining_mut()).await?,
216        };
217
218        if read > 0 {
219            self.advance(read);
220        }
221
222        Ok(read)
223    }
224
225    pub fn is_empty(&self) -> bool {
226        self.bytes_flushed >= self.bytes_written
227    }
228
229    pub fn is_full(&self) -> bool {
230        self.bytes_written == self.buf.len()
231    }
232
233    pub fn get(&self) -> &[u8] {
234        &self.buf[self.bytes_flushed..self.bytes_written]
235    }
236
237    pub fn get_mut(&mut self) -> &mut [u8] {
238        &mut self.buf[self.bytes_flushed..self.bytes_written]
239    }
240
241    pub fn shrink(&mut self) {
242        if self.bytes_flushed > 0 {
243            // Move any data that remains to be flushed to the beginning of the buffer,
244            // if necessary.
245            self.buf
246                .copy_within(self.bytes_flushed..self.bytes_written, 0);
247            self.bytes_written -= self.bytes_flushed;
248            self.bytes_flushed = 0
249        }
250
251        // Drop excess capacity.
252        self.buf
253            .truncate(cmp::max(self.bytes_written, DEFAULT_BUF_SIZE));
254        self.buf.shrink_to_fit();
255    }
256
257    fn consume(&mut self, amt: usize) {
258        let new_bytes_flushed = self
259            .bytes_flushed
260            .checked_add(amt)
261            .expect("self.bytes_flushed + amt overflowed");
262
263        assert!(new_bytes_flushed <= self.bytes_written);
264
265        self.bytes_flushed = new_bytes_flushed;
266
267        if self.bytes_flushed == self.bytes_written {
268            // Reset cursors to zero if we've consumed the whole buffer
269            self.bytes_flushed = 0;
270            self.bytes_written = 0;
271        }
272
273        self.sanity_check();
274    }
275}
276
impl ReadBuffer {
    /// Read from `socket` until `self.read` holds at least `len` bytes.
    ///
    /// Returns `UnexpectedEof` if the socket reaches EOF first.
    async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> {
        // Because of how `BytesMut` works, we should only be shifting capacity back and forth
        // between `read` and `available` unless we have to read an oversize message.
        while self.read.len() < len {
            // Make sure `available` can hold at least the bytes still missing.
            self.reserve(len - self.read.len());

            let read = socket.read(&mut self.available).await?;

            if read == 0 {
                // EOF before we got the full `len` bytes.
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    format!(
                        "expected to read {} bytes, got {} bytes at EOF",
                        len,
                        self.read.len()
                    ),
                ));
            }

            self.advance(read);
        }

        Ok(())
    }

    /// Ensure `available` has capacity for at least `amt` bytes.
    //
    // NOTE(review): this only reserves the shortfall past the *current*
    // capacity, so `available.capacity()` may still end up below `amt`;
    // the read loop above tolerates that by looping — confirm intent.
    fn reserve(&mut self, amt: usize) {
        if let Some(additional) = amt.checked_sub(self.available.capacity()) {
            self.available.reserve(additional);
        }
    }

    /// Move the first `amt` bytes of `available` onto the end of `read`.
    //
    // When both handles view the same allocation, `unsplit` is a cheap
    // pointer adjustment rather than a copy.
    fn advance(&mut self, amt: usize) {
        self.read.unsplit(self.available.split_to(amt));
    }

    /// Release excess capacity held by `available`.
    fn shrink(&mut self) {
        if self.available.capacity() > DEFAULT_BUF_SIZE {
            // `BytesMut` doesn't have a way to shrink its capacity,
            // but we only use `available` for spare capacity anyway so we can just replace it.
            //
            // If `self.read` still contains data on the next call to `advance` then this might
            // force a memcpy as they'll no longer be pointing to the same allocation,
            // but that's kind of unavoidable.
            //
            // The `async-std` impl of `Socket` will also need to re-zero the buffer,
            // but that's also kind of unavoidable.
            //
            // We should be warning the user not to call this often.
            self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
        }
    }
}
329}