use std::{
fmt::{self, Debug},
io,
net::SocketAddr,
sync::{Arc, Mutex},
task,
};
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use quinn::{AsyncUdpSocket, EndpointConfig};
use quinn_udp::RecvMeta;
/// State shared behind the mutex of a `FlumeSocket`: the socket's own
/// address plus the two channel halves used to exchange `Packet`s.
struct FlumeSocketInner {
    /// Address reported by `local_addr`, stamped as `from` on every sent
    /// packet, and compared against `to` on every received packet.
    local: SocketAddr,
    /// Incoming packets; entries whose `to` differs from `local` are discarded.
    receiver: flume::r#async::RecvStream<'static, Packet>,
    /// Outgoing packets produced from quinn transmits.
    sender: flume::r#async::SendSink<'static, Packet>,
}
impl fmt::Debug for FlumeSocketInner {
    /// Shows only the local address. The channel halves are not printed, and
    /// the output is marked non-exhaustive (`..`) to signal that.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("FlumeSocketInner");
        builder.field("local", &self.local);
        builder.finish_non_exhaustive()
    }
}
/// A datagram travelling between two `FlumeSocket`s over a flume channel.
///
/// Cloning is cheap: the payload is a `Bytes`, which is reference-counted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Packet {
    /// Address of the sending socket.
    pub from: SocketAddr,
    /// Intended recipient; a receiving socket discards packets whose `to`
    /// differs from its own address.
    pub to: SocketAddr,
    /// Payload bytes. When `segment_size` is `Some`, this may hold several
    /// back-to-back segments.
    pub contents: Bytes,
    /// Segment size copied from the originating `quinn_udp::Transmit`. It is
    /// reported to the receiver as the stride. `None` means all of `contents`
    /// forms a single datagram.
    pub segment_size: Option<usize>,
}
/// An in-memory stand-in for a UDP socket. It implements quinn's
/// `AsyncUdpSocket` on top of a pair of flume channels.
///
/// All state lives behind a single `Mutex`. Every send, receive and address
/// lookup locks it, so those operations are serialized against each other.
#[derive(Debug)]
pub struct FlumeSocket(Arc<Mutex<FlumeSocketInner>>);
impl FlumeSocket {
    /// Creates a socket that claims the address `local`.
    ///
    /// Packets sent through the socket are pushed into `tx`. Packets are read
    /// from `rx`, and only those whose `to` equals `local` are delivered to
    /// quinn.
    pub fn new(
        local: SocketAddr,
        tx: flume::Sender<Packet>,
        rx: flume::Receiver<Packet>,
    ) -> Self {
        let receiver = rx.into_stream();
        let sender = tx.into_sink();
        let state = FlumeSocketInner { local, receiver, sender };
        Self(Arc::new(Mutex::new(state)))
    }
}
impl AsyncUdpSocket for FlumeSocket {
    /// Locks the shared state and forwards the send to it.
    fn poll_send(
        &self,
        state: &quinn_udp::UdpState,
        cx: &mut task::Context,
        transmits: &[quinn_udp::Transmit],
    ) -> task::Poll<Result<usize, io::Error>> {
        let mut guard = self.0.lock().unwrap();
        guard.poll_send(state, cx, transmits)
    }

    /// Locks the shared state and forwards the receive to it.
    fn poll_recv(
        &self,
        cx: &mut task::Context,
        bufs: &mut [io::IoSliceMut<'_>],
        meta: &mut [RecvMeta],
    ) -> task::Poll<io::Result<usize>> {
        let mut guard = self.0.lock().unwrap();
        guard.poll_recv(cx, bufs, meta)
    }

    /// Reports the address this socket was created with.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        let guard = self.0.lock().unwrap();
        guard.local_addr()
    }
}
impl FlumeSocketInner {
    /// Builds the error returned when the other end of a channel is gone.
    fn disconnected(msg: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::UnexpectedEof, msg)
    }

    /// Hands as many of `transmits` as possible to the outgoing channel.
    ///
    /// Each transmit becomes a [`Packet`] whose `from` is this socket's address.
    ///
    /// Returns:
    /// - `Ready(Ok(k))` when the first `k` transmits were accepted by the sink.
    ///   `k` may be smaller than `transmits.len()`.
    /// - `Pending` when the sink could not accept even the first transmit.
    /// - `Ready(Err(UnexpectedEof))` when nothing could be sent because the
    ///   sink reported an error, or when the final flush reports one.
    fn poll_send(
        &mut self,
        _state: &quinn_udp::UdpState,
        cx: &mut task::Context,
        transmits: &[quinn_udp::Transmit],
    ) -> task::Poll<Result<usize, std::io::Error>> {
        if transmits.is_empty() {
            return task::Poll::Ready(Ok(0));
        }
        let mut offset = 0;
        let mut pending = false;
        tracing::trace!("S {} transmits", transmits.len());
        for transmit in transmits {
            match self.sender.poll_ready_unpin(cx) {
                task::Poll::Ready(Ok(())) => {}
                task::Poll::Ready(Err(_)) => break,
                task::Poll::Pending => {
                    pending = true;
                    break;
                }
            }
            // Build the packet only once the sink can take it. Otherwise a
            // packet would be created (and logged as sent) and then thrown away.
            tracing::trace!(
                "S sending {} {:?}",
                transmit.contents.len(),
                transmit.segment_size
            );
            let item = Packet {
                from: self.local,
                to: transmit.destination,
                contents: transmit.contents.clone(),
                segment_size: transmit.segment_size,
            };
            if self.sender.start_send_unpin(item).is_err() {
                break;
            }
            offset += 1;
        }
        if offset > 0 {
            // The flush pushes the last started packet along and surfaces a
            // disconnect. A `Pending` flush is not a failure: the packets were
            // already accepted by `start_send`.
            if let task::Poll::Ready(Err(_)) = self.sender.poll_flush_unpin(cx) {
                return task::Poll::Ready(Err(Self::disconnected("all receivers dropped")));
            }
            task::Poll::Ready(Ok(offset))
        } else if pending {
            task::Poll::Pending
        } else {
            task::Poll::Ready(Err(Self::disconnected("all receivers dropped")))
        }
    }

    /// Fills up to `min(bufs.len(), meta.len())` slots with packets addressed
    /// to this socket.
    ///
    /// Packets addressed elsewhere are silently discarded. A packet larger
    /// than its destination buffer is dropped with a warning. Previously this
    /// case panicked on the slice copy, which also poisoned the socket's mutex.
    ///
    /// Returns:
    /// - `Ready(Ok(k))` when `k` slots were filled.
    /// - `Pending` when no packet is available yet.
    /// - `Ready(Err(UnexpectedEof))` when the channel is closed and nothing
    ///   was delivered.
    fn poll_recv(
        &mut self,
        cx: &mut std::task::Context,
        bufs: &mut [io::IoSliceMut<'_>],
        meta: &mut [quinn_udp::RecvMeta],
    ) -> task::Poll<io::Result<usize>> {
        let n = bufs.len().min(meta.len());
        if n == 0 {
            return task::Poll::Ready(Ok(0));
        }
        let mut offset = 0;
        let mut pending = false;
        while offset < n {
            let packet = match self.receiver.poll_next_unpin(cx) {
                task::Poll::Ready(Some(recv)) => recv,
                task::Poll::Ready(None) => break,
                task::Poll::Pending => {
                    pending = true;
                    break;
                }
            };
            if packet.to != self.local {
                // Not for us: discard it and keep draining the channel.
                continue;
            }
            let len = packet.contents.len();
            let buf = &mut bufs[offset];
            tracing::trace!("R bufs {} bytes, {} slots", buf.len(), n);
            if len > buf.len() {
                // Copying would panic on the out-of-range slice. Drop the
                // packet instead, so this slot stays free for the next one.
                tracing::warn!(
                    "R dropping {} byte packet from {}: receive buffer holds {} bytes",
                    len,
                    packet.from,
                    buf.len()
                );
                continue;
            }
            buf[..len].copy_from_slice(&packet.contents);
            meta[offset] = quinn_udp::RecvMeta {
                addr: packet.from,
                len,
                // With no segment size, the whole payload is one datagram.
                stride: packet.segment_size.unwrap_or(len),
                ecn: None,
                dst_ip: Some(self.local.ip()),
            };
            offset += 1;
        }
        if offset > 0 {
            task::Poll::Ready(Ok(offset))
        } else if pending {
            task::Poll::Pending
        } else {
            task::Poll::Ready(Err(Self::disconnected("all senders dropped")))
        }
    }

    /// Returns the address this socket was created with. This never fails.
    fn local_addr(&self) -> std::io::Result<SocketAddr> {
        Ok(self.local)
    }
}
/// Builds a quinn endpoint on top of `socket`, using the Tokio runtime.
///
/// `config` and `server_config` are passed through to
/// `quinn::Endpoint::new_with_abstract_socket` unchanged. `endpoint_pair`
/// supplies `server_config` only for the server side. Any error from
/// endpoint construction is returned as-is.
pub(crate) fn make_endpoint(
    socket: FlumeSocket,
    config: EndpointConfig,
    server_config: Option<quinn::ServerConfig>,
) -> io::Result<quinn::Endpoint> {
    let runtime = Arc::new(quinn::TokioRuntime);
    quinn::Endpoint::new_with_abstract_socket(config, server_config, socket, runtime)
}
/// Creates a server endpoint and a client endpoint that talk to each other
/// over in-memory flume channels instead of real UDP.
///
/// Each direction uses its own bounded channel of capacity 16. Each endpoint
/// gets a default `EndpointConfig`, and only the server receives
/// `server_config`. The tuple is returned as `(server, client)`.
///
/// NOTE(review): the endpoints are built with `quinn::TokioRuntime`, so this
/// presumably has to be called from inside a Tokio runtime. Confirm this.
pub fn endpoint_pair(
    server_addr: SocketAddr,
    client_addr: SocketAddr,
    server_config: quinn::ServerConfig,
) -> io::Result<(quinn::Endpoint, quinn::Endpoint)> {
    let (to_client_tx, to_client_rx) = flume::bounded(16);
    let (to_server_tx, to_server_rx) = flume::bounded(16);
    let server_socket = FlumeSocket::new(server_addr, to_client_tx, to_server_rx);
    let client_socket = FlumeSocket::new(client_addr, to_server_tx, to_client_rx);
    let server_endpoint_config = EndpointConfig::default();
    let client_endpoint_config = EndpointConfig::default();
    let server = make_endpoint(server_socket, server_endpoint_config, Some(server_config))?;
    let client = make_endpoint(client_socket, client_endpoint_config, None)?;
    Ok((server, client))
}