use crate::net::Socket;
use bytes::BytesMut;
use std::{cmp, io};
use crate::error::Error;
use crate::io::{Decode, Encode};
/// Capacity, in bytes, that the read and write buffers are allocated with in
/// `BufferedSocket::new`, and the size their `shrink` methods trim back toward.
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
/// A socket wrapper that stages outgoing data in a write buffer until it is flushed, and
/// accumulates incoming data in a read buffer so callers can request an exact byte count.
pub struct BufferedSocket<S> {
    // The wrapped transport.
    socket: S,
    // Outgoing bytes that have been encoded but not yet written to `socket`.
    write_buf: WriteBuffer,
    // Incoming bytes received from `socket` that have not yet been handed to a caller.
    read_buf: ReadBuffer,
}
/// Outgoing byte buffer addressed by two cursors into `buf`.
///
/// `buf[bytes_flushed..bytes_written]` is the data still waiting to be sent. The invariant
/// `bytes_flushed <= bytes_written <= buf.len()` (plus non-zero capacity) is asserted by
/// `sanity_check`.
pub struct WriteBuffer {
    // Backing storage; bytes in `bytes_written..buf.len()` are initialized but not pending.
    buf: Vec<u8>,
    // End (exclusive) of the data written into the buffer so far.
    bytes_written: usize,
    // Start of the data that has not yet been written to the socket.
    bytes_flushed: usize,
}
/// Incoming byte buffer split across two `BytesMut` handles.
pub struct ReadBuffer {
    // Bytes received from the socket and not yet returned to a caller.
    read: BytesMut,
    // Spare capacity the socket reads into; received bytes are moved onto `read`.
    available: BytesMut,
}
impl<S: Socket> BufferedSocket<S> {
    /// Wraps `socket` with empty read and write buffers, each pre-allocated with
    /// `DEFAULT_BUF_SIZE` bytes of capacity.
    pub fn new(socket: S) -> Self
    where
        S: Sized,
    {
        BufferedSocket {
            socket,
            write_buf: WriteBuffer {
                buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
                bytes_written: 0,
                bytes_flushed: 0,
            },
            read_buf: ReadBuffer {
                read: BytesMut::new(),
                available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
            },
        }
    }

    /// Returns exactly `len` bytes received from the socket.
    ///
    /// Bytes left in the read buffer by earlier socket reads are used first. The socket is
    /// only read while fewer than `len` bytes are buffered, and any surplus stays buffered for
    /// the next call.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::UnexpectedEof`] error if the socket reports EOF before
    /// `len` bytes are available, or any error reported by the socket itself.
    pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> {
        self.read_buf.read(len, &mut self.socket).await
    }

    /// Returns a shared reference to the outgoing write buffer.
    pub fn write_buffer(&self) -> &WriteBuffer {
        &self.write_buf
    }

    /// Returns a mutable reference to the outgoing write buffer.
    pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer {
        &mut self.write_buf
    }

    /// Reads exactly `byte_len` bytes and decodes them as `T` with a unit context.
    ///
    /// Shorthand for `self.read_with(byte_len, ())`.
    pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result<T, Error>
    where
        T: Decode<'de, ()>,
    {
        self.read_with(byte_len, ()).await
    }

    /// Reads exactly `byte_len` bytes and decodes them as `T` with the given `context`.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from [`read_buffered`](Self::read_buffered) (converted into
    /// [`Error`] by `?`) and any error returned by `T::decode_with`.
    pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result<T, Error>
    where
        T: Decode<'de, C>,
    {
        T::decode_with(self.read_buffered(byte_len).await?.freeze(), context)
    }

    /// Encodes `value` into the write buffer with a unit context.
    ///
    /// Nothing is sent until [`flush`](Self::flush) is called.
    pub fn write<'en, T>(&mut self, value: T)
    where
        T: Encode<'en, ()>,
    {
        self.write_with(value, ())
    }

    /// Encodes `value` into the write buffer with the given `context`.
    ///
    /// The buffer is first truncated to the bytes already pending. Whatever the encoder leaves
    /// in it becomes pending data for the next [`flush`](Self::flush).
    pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
    where
        T: Encode<'en, C>,
    {
        value.encode_with(self.write_buf.buf_mut(), context);
        self.write_buf.bytes_written = self.write_buf.buf.len();
        self.write_buf.sanity_check();
    }

    /// Writes every pending byte in the write buffer to the socket, then flushes the socket.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the socket. If the socket accepts zero bytes while data
    /// is still pending, returns an [`io::ErrorKind::WriteZero`] error instead of retrying
    /// forever. This follows the same convention as `std::io::Write::write_all`.
    pub async fn flush(&mut self) -> io::Result<()> {
        while !self.write_buf.is_empty() {
            let written = self.socket.write(self.write_buf.get()).await?;

            // `Ok(0)` for a non-empty buffer means the socket can no longer accept data.
            // Without this check, nothing would be consumed and the loop would spin forever.
            if written == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write buffered data to socket",
                ));
            }

            self.write_buf.consume(written);
            self.write_buf.sanity_check();
        }

        self.socket.flush().await?;

        Ok(())
    }

    /// Flushes any pending writes, then shuts down the socket.
    ///
    /// # Errors
    ///
    /// Returns the first error from [`flush`](Self::flush) or from the socket's shutdown.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        self.flush().await?;
        self.socket.shutdown().await
    }

    /// Releases excess capacity held by both the write and read buffers.
    pub fn shrink_buffers(&mut self) {
        self.write_buf.shrink();
        self.read_buf.shrink();
    }

    /// Returns the underlying socket.
    ///
    /// Any unflushed writes and unread buffered bytes are dropped.
    pub fn into_inner(self) -> S {
        self.socket
    }

    /// Erases the socket's concrete type by boxing it.
    ///
    /// Both buffers, and any data they hold, are carried over unchanged.
    pub fn boxed(self) -> BufferedSocket<Box<dyn Socket>> {
        BufferedSocket {
            socket: Box::new(self.socket),
            write_buf: self.write_buf,
            read_buf: self.read_buf,
        }
    }
}
impl WriteBuffer {
    /// Asserts the buffer's invariants.
    ///
    /// The invariants are that the backing `Vec` has non-zero capacity and that
    /// `bytes_flushed <= bytes_written <= buf.len()`.
    ///
    /// # Panics
    ///
    /// Panics if any invariant is violated.
    fn sanity_check(&self) {
        assert_ne!(self.buf.capacity(), 0);
        assert!(self.bytes_written <= self.buf.len());
        assert!(self.bytes_flushed <= self.bytes_written);
    }

    /// Returns the backing `Vec`, truncated so that it ends exactly at `bytes_written`.
    ///
    /// Bytes pushed onto the returned `Vec` are not counted as pending until the write cursor
    /// is moved past them, e.g. with [`advance`](Self::advance).
    pub fn buf_mut(&mut self) -> &mut Vec<u8> {
        self.buf.truncate(self.bytes_written);
        self.sanity_check();
        &mut self.buf
    }

    /// Returns the writable region after `bytes_written`, first growing the buffer's length to
    /// its full capacity.
    ///
    /// Newly exposed bytes are zero. Bytes that were already initialized past `bytes_written`
    /// keep their previous contents. Call [`advance`](Self::advance) with the number of bytes
    /// filled in to mark them as pending.
    pub fn init_remaining_mut(&mut self) -> &mut [u8] {
        self.buf.resize(self.buf.capacity(), 0);
        self.sanity_check();
        &mut self.buf[self.bytes_written..]
    }

    /// Appends `slice` to the pending data.
    ///
    /// If the already-initialized space after `bytes_written` can hold `slice`, it is copied
    /// there without changing `buf.len()`. Otherwise the buffer is truncated to
    /// `bytes_written` and extended.
    pub fn put_slice(&mut self, slice: &[u8]) {
        if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) {
            dest.copy_from_slice(slice);
        } else {
            self.buf.truncate(self.bytes_written);
            self.buf.extend_from_slice(slice);
        }
        self.advance(slice.len());
        self.sanity_check();
    }

    /// Moves the write cursor forward by `amt` bytes, marking them as pending.
    ///
    /// # Panics
    ///
    /// Panics if `bytes_written + amt` overflows or exceeds the initialized length of the
    /// buffer.
    pub fn advance(&mut self, amt: usize) {
        let new_bytes_written = self
            .bytes_written
            .checked_add(amt)
            .expect("self.bytes_written + amt overflowed");
        assert!(new_bytes_written <= self.buf.len());
        self.bytes_written = new_bytes_written;
        self.sanity_check();
    }

    /// Returns `true` if no bytes are waiting to be flushed.
    pub fn is_empty(&self) -> bool {
        self.bytes_flushed >= self.bytes_written
    }

    /// Returns `true` if no initialized space is left after the pending data
    /// (`bytes_written == buf.len()`).
    ///
    /// The `Vec` may still have spare, uninitialized capacity.
    pub fn is_full(&self) -> bool {
        self.bytes_written == self.buf.len()
    }

    /// Returns the bytes that have been written but not yet flushed.
    pub fn get(&self) -> &[u8] {
        &self.buf[self.bytes_flushed..self.bytes_written]
    }

    /// Returns a mutable view of the bytes that have been written but not yet flushed.
    pub fn get_mut(&mut self) -> &mut [u8] {
        &mut self.buf[self.bytes_flushed..self.bytes_written]
    }

    /// Moves pending data to the front of the buffer and releases excess memory.
    ///
    /// Pending bytes are preserved. The buffer always keeps at least `DEFAULT_BUF_SIZE` bytes
    /// of capacity, so it remains usable for the next write.
    pub fn shrink(&mut self) {
        if self.bytes_flushed > 0 {
            // Move the still-pending bytes to the front so the flushed prefix can be reused.
            self.buf
                .copy_within(self.bytes_flushed..self.bytes_written, 0);
            self.bytes_written -= self.bytes_flushed;
            self.bytes_flushed = 0;
        }

        // Drop initialized-but-unused bytes beyond the default size, then release excess
        // capacity. `shrink_to` keeps capacity >= max(len, DEFAULT_BUF_SIZE).
        // `shrink_to_fit` on a short or empty buffer would instead cut capacity below the
        // default, down to zero when empty, which `sanity_check` rejects on the next write.
        self.buf
            .truncate(cmp::max(self.bytes_written, DEFAULT_BUF_SIZE));
        self.buf.shrink_to(DEFAULT_BUF_SIZE);

        self.sanity_check();
    }

    /// Marks `amt` more bytes as flushed.
    ///
    /// Once everything written has been flushed, both cursors are reset to zero so the space
    /// is reused from the start.
    ///
    /// # Panics
    ///
    /// Panics if `bytes_flushed + amt` overflows or exceeds `bytes_written`.
    fn consume(&mut self, amt: usize) {
        let new_bytes_flushed = self
            .bytes_flushed
            .checked_add(amt)
            .expect("self.bytes_flushed + amt overflowed");
        assert!(new_bytes_flushed <= self.bytes_written);
        self.bytes_flushed = new_bytes_flushed;
        if self.bytes_flushed == self.bytes_written {
            self.bytes_flushed = 0;
            self.bytes_written = 0;
        }
        self.sanity_check();
    }
}
impl ReadBuffer {
    /// Returns exactly `len` bytes, reading from `socket` only while fewer than `len` bytes
    /// are buffered in `self.read`.
    ///
    /// Any surplus received from the socket stays in `self.read` for the next call.
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` if the socket yields 0 bytes before `len` bytes are buffered.
    /// Otherwise returns whatever error the socket reports.
    async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> {
        loop {
            let buffered = self.read.len();
            if buffered >= len {
                break;
            }

            self.reserve(len - buffered);

            let received = socket.read(&mut self.available).await?;
            if received == 0 {
                let message = format!(
                    "expected to read {} bytes, got {} bytes at EOF",
                    len,
                    self.read.len()
                );
                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, message));
            }

            self.advance(received);
        }

        Ok(self.drain(len))
    }

    /// Asks `available` for more room when `amt` is at least its current capacity.
    ///
    /// `BytesMut::reserve(n)` guarantees room for `n` bytes beyond `len()`, so passing
    /// `amt - capacity` does not by itself guarantee `capacity >= amt`.
    /// NOTE(review): confirm whether `reserve(amt)` was intended. The read loop stays correct
    /// either way, because it keeps reading until `len` bytes have arrived.
    fn reserve(&mut self, amt: usize) {
        let capacity = self.available.capacity();
        if amt >= capacity {
            self.available.reserve(amt - capacity);
        }
    }

    /// Moves the first `amt` bytes of `available` (just filled by the socket) onto the end of
    /// `read`.
    ///
    /// `BytesMut::unsplit` is O(1) when the two handles are contiguous in one allocation;
    /// otherwise it copies.
    fn advance(&mut self, amt: usize) {
        let filled = self.available.split_to(amt);
        self.read.unsplit(filled);
    }

    /// Removes and returns the first `amt` buffered bytes.
    fn drain(&mut self, amt: usize) -> BytesMut {
        self.read.split_to(amt)
    }

    /// Replaces `available` with a fresh `DEFAULT_BUF_SIZE` buffer if it has grown beyond that.
    ///
    /// `BytesMut` cannot shrink an existing allocation in place. `available` is only used as
    /// spare capacity (received bytes are moved out by `advance`), so swapping in a new
    /// buffer is sufficient.
    fn shrink(&mut self) {
        if self.available.capacity() <= DEFAULT_BUF_SIZE {
            return;
        }
        self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
    }
}