sqlx_core/net/socket/
buffered.rs1use crate::error::Error;
2use crate::net::Socket;
3use bytes::BytesMut;
4use std::ops::ControlFlow;
5use std::{cmp, io};
6
7use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
8
/// Initial capacity (8 KiB) of both the write buffer and the read buffer's spare region,
/// and the size the buffer-shrinking methods in this file aim to keep.
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
11
12pub struct BufferedSocket<S> {
13 socket: S,
14 write_buf: WriteBuffer,
15 read_buf: ReadBuffer,
16}
17
/// Outgoing byte buffer tracked by two cursors into `buf`.
///
/// `buf[bytes_flushed..bytes_written]` holds data still waiting to be sent, while
/// `buf[bytes_written..]` is initialized space that is not currently in use.
pub struct WriteBuffer {
    /// Backing storage; its length may exceed `bytes_written`.
    buf: Vec<u8>,
    /// End offset of the data written into `buf` so far.
    bytes_written: usize,
    /// End offset of the data already written out to the socket.
    bytes_flushed: usize,
}
23
24pub struct ReadBuffer {
25 read: BytesMut,
26 available: BytesMut,
27}
28
29impl<S: Socket> BufferedSocket<S> {
30 pub fn new(socket: S) -> Self
31 where
32 S: Sized,
33 {
34 BufferedSocket {
35 socket,
36 write_buf: WriteBuffer {
37 buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
38 bytes_written: 0,
39 bytes_flushed: 0,
40 },
41 read_buf: ReadBuffer {
42 read: BytesMut::new(),
43 available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
44 },
45 }
46 }
47
48 pub async fn read_buffered(&mut self, len: usize) -> Result<BytesMut, Error> {
49 self.try_read(|buf| {
50 Ok(if buf.len() < len {
51 ControlFlow::Continue(len)
52 } else {
53 ControlFlow::Break(buf.split_to(len))
54 })
55 })
56 .await
57 }
58
59 pub async fn try_read<F, R>(&mut self, mut try_read: F) -> Result<R, Error>
70 where
71 F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
72 {
73 loop {
74 let read_len = match try_read(&mut self.read_buf.read)? {
75 ControlFlow::Continue(read_len) => read_len,
76 ControlFlow::Break(ret) => return Ok(ret),
77 };
78
79 self.read_buf.read(read_len, &mut self.socket).await?;
80 }
81 }
82
83 pub fn write_buffer(&self) -> &WriteBuffer {
84 &self.write_buf
85 }
86
87 pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer {
88 &mut self.write_buf
89 }
90
91 pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result<T, Error>
92 where
93 T: ProtocolDecode<'de, ()>,
94 {
95 self.read_with(byte_len, ()).await
96 }
97
98 pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result<T, Error>
99 where
100 T: ProtocolDecode<'de, C>,
101 {
102 T::decode_with(self.read_buffered(byte_len).await?.freeze(), context)
103 }
104
105 #[inline(always)]
106 pub fn write<'en, T>(&mut self, value: T) -> Result<(), Error>
107 where
108 T: ProtocolEncode<'en, ()>,
109 {
110 self.write_with(value, ())
111 }
112
113 #[inline(always)]
114 pub fn write_with<'en, T, C>(&mut self, value: T, context: C) -> Result<(), Error>
115 where
116 T: ProtocolEncode<'en, C>,
117 {
118 value.encode_with(self.write_buf.buf_mut(), context)?;
119 self.write_buf.bytes_written = self.write_buf.buf.len();
120 self.write_buf.sanity_check();
121
122 Ok(())
123 }
124
125 pub async fn flush(&mut self) -> io::Result<()> {
126 while !self.write_buf.is_empty() {
127 let written = self.socket.write(self.write_buf.get()).await?;
128 self.write_buf.consume(written);
129 self.write_buf.sanity_check();
130 }
131
132 self.socket.flush().await?;
133
134 Ok(())
135 }
136
137 pub async fn shutdown(&mut self) -> io::Result<()> {
138 self.flush().await?;
139 self.socket.shutdown().await
140 }
141
142 pub fn shrink_buffers(&mut self) {
143 self.write_buf.shrink();
145 self.read_buf.shrink();
146 }
147
148 pub fn into_inner(self) -> S {
149 self.socket
150 }
151
152 pub fn boxed(self) -> BufferedSocket<Box<dyn Socket>> {
153 BufferedSocket {
154 socket: Box::new(self.socket),
155 write_buf: self.write_buf,
156 read_buf: self.read_buf,
157 }
158 }
159}
160
161impl WriteBuffer {
162 fn sanity_check(&self) {
163 assert_ne!(self.buf.capacity(), 0);
164 assert!(self.bytes_written <= self.buf.len());
165 assert!(self.bytes_flushed <= self.bytes_written);
166 }
167
168 pub fn buf_mut(&mut self) -> &mut Vec<u8> {
169 self.buf.truncate(self.bytes_written);
170 self.sanity_check();
171 &mut self.buf
172 }
173
174 pub fn init_remaining_mut(&mut self) -> &mut [u8] {
175 self.buf.resize(self.buf.capacity(), 0);
176 self.sanity_check();
177 &mut self.buf[self.bytes_written..]
178 }
179
180 pub fn put_slice(&mut self, slice: &[u8]) {
181 if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) {
184 dest.copy_from_slice(slice);
185 } else {
186 self.buf.truncate(self.bytes_written);
187 self.buf.extend_from_slice(slice);
188 }
189 self.advance(slice.len());
190 self.sanity_check();
191 }
192
193 pub fn advance(&mut self, amt: usize) {
194 let new_bytes_written = self
195 .bytes_written
196 .checked_add(amt)
197 .expect("self.bytes_written + amt overflowed");
198
199 assert!(new_bytes_written <= self.buf.len());
200
201 self.bytes_written = new_bytes_written;
202
203 self.sanity_check();
204 }
205
206 pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> io::Result<usize> {
210 let read = match () {
211 #[cfg(feature = "_rt-tokio")]
213 _ => source.read_buf(self.buf_mut()).await?,
214 #[cfg(not(feature = "_rt-tokio"))]
215 _ => source.read(self.init_remaining_mut()).await?,
216 };
217
218 if read > 0 {
219 self.advance(read);
220 }
221
222 Ok(read)
223 }
224
225 pub fn is_empty(&self) -> bool {
226 self.bytes_flushed >= self.bytes_written
227 }
228
229 pub fn is_full(&self) -> bool {
230 self.bytes_written == self.buf.len()
231 }
232
233 pub fn get(&self) -> &[u8] {
234 &self.buf[self.bytes_flushed..self.bytes_written]
235 }
236
237 pub fn get_mut(&mut self) -> &mut [u8] {
238 &mut self.buf[self.bytes_flushed..self.bytes_written]
239 }
240
241 pub fn shrink(&mut self) {
242 if self.bytes_flushed > 0 {
243 self.buf
246 .copy_within(self.bytes_flushed..self.bytes_written, 0);
247 self.bytes_written -= self.bytes_flushed;
248 self.bytes_flushed = 0
249 }
250
251 self.buf
253 .truncate(cmp::max(self.bytes_written, DEFAULT_BUF_SIZE));
254 self.buf.shrink_to_fit();
255 }
256
257 fn consume(&mut self, amt: usize) {
258 let new_bytes_flushed = self
259 .bytes_flushed
260 .checked_add(amt)
261 .expect("self.bytes_flushed + amt overflowed");
262
263 assert!(new_bytes_flushed <= self.bytes_written);
264
265 self.bytes_flushed = new_bytes_flushed;
266
267 if self.bytes_flushed == self.bytes_written {
268 self.bytes_flushed = 0;
270 self.bytes_written = 0;
271 }
272
273 self.sanity_check();
274 }
275}
276
277impl ReadBuffer {
278 async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> {
279 while self.read.len() < len {
282 self.reserve(len - self.read.len());
283
284 let read = socket.read(&mut self.available).await?;
285
286 if read == 0 {
287 return Err(io::Error::new(
288 io::ErrorKind::UnexpectedEof,
289 format!(
290 "expected to read {} bytes, got {} bytes at EOF",
291 len,
292 self.read.len()
293 ),
294 ));
295 }
296
297 self.advance(read);
298 }
299
300 Ok(())
301 }
302
303 fn reserve(&mut self, amt: usize) {
304 if let Some(additional) = amt.checked_sub(self.available.capacity()) {
305 self.available.reserve(additional);
306 }
307 }
308
309 fn advance(&mut self, amt: usize) {
310 self.read.unsplit(self.available.split_to(amt));
311 }
312
313 fn shrink(&mut self) {
314 if self.available.capacity() > DEFAULT_BUF_SIZE {
315 self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
327 }
328 }
329}