use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use socket2::Domain;
use socket2::Protocol;
use socket2::Type;
/// Process-wide registry of shared listening sockets.
///
/// Lazily initialized by `TcpListener::bind_load_balanced` (via `get_or_init`) and
/// consulted again when a load-balanced `TcpListener` is dropped. All access goes
/// through the contained `std::sync::Mutex`.
static CONNS: std::sync::OnceLock<std::sync::Mutex<Connections>> =
std::sync::OnceLock::new();
/// Maps each listening address to the shared `TcpConnection` bound to it.
#[derive(Default)]
struct Connections {
/// Keyed by the address passed to `TcpListener::bind_load_balanced`.
tcp: HashMap<SocketAddr, Arc<TcpConnection>>,
}
/// A bound, listening socket that load-balanced listeners on the same address share.
///
/// Created by `TcpConnection::start`; each consumer obtains its own Tokio listener
/// from a duplicate of `sock` (see `TcpConnection::listener`).
pub struct TcpConnection {
/// The owned socket handle that `listener` duplicates with `try_clone`
/// (a file descriptor on Unix targets).
#[cfg(unix)]
sock: std::os::fd::OwnedFd,
/// The owned socket handle used in place of the file descriptor on non-Unix targets.
#[cfg(not(unix))]
sock: std::os::windows::io::OwnedSocket,
/// The address this connection was started with; it is also the key under which
/// the connection is stored in, and removed from, the `CONNS` map.
key: SocketAddr,
}
impl TcpConnection {
    /// Binds `key`, starts listening on it, and keeps the resulting socket as the
    /// handle from which per-listener duplicates are made.
    ///
    /// The socket is created through `bind_socket_and_listen` with `reuse_port`
    /// set to `false`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by `bind_socket_and_listen`.
    pub fn start(key: SocketAddr) -> std::io::Result<Self> {
        let bound = bind_socket_and_listen(key, false)?;
        Ok(Self {
            // Convert the std listener into the platform's owned socket handle.
            sock: bound.into(),
            key,
        })
    }

    /// Produces a new Tokio listener backed by a duplicate of the stored socket.
    ///
    /// # Errors
    ///
    /// Fails if the handle cannot be duplicated or if Tokio rejects the socket.
    fn listener(&self) -> std::io::Result<tokio::net::TcpListener> {
        let duplicate = self.sock.try_clone()?;
        tokio::net::TcpListener::from_std(std::net::TcpListener::from(duplicate))
    }
}
/// A Tokio-based TCP listener that is either bound to its own socket or backed by a
/// `TcpConnection` shared with other listeners on the same address.
pub struct TcpListener {
/// The Tokio listener used by `accept` and `local_addr`. Both constructors in this
/// `impl` (`bind_direct`, `bind_load_balanced`) set it to `Some`.
listener: Option<tokio::net::TcpListener>,
/// `Some` only for listeners created by `bind_load_balanced`; the `Drop` impl uses it
/// to decide when the shared connection can be removed from `CONNS`.
conn: Option<Arc<TcpConnection>>,
}
/// `true` only when compiling for Android or Linux.
///
/// Gates both the load-balanced path in `TcpListener::bind` and the
/// `set_reuse_port` call in `bind_socket_and_listen`; on every other target a
/// `reuse_port` request takes neither path.
// NOTE(review): the name suggests these are the platforms where the kernel's
// port-reuse option distributes connections across sockets — confirm.
const REUSE_PORT_LOAD_BALANCES: bool =
cfg!(any(target_os = "android", target_os = "linux"));
impl TcpListener {
pub fn bind(
socket_addr: SocketAddr,
reuse_port: bool,
) -> std::io::Result<Self> {
if REUSE_PORT_LOAD_BALANCES && reuse_port {
Self::bind_load_balanced(socket_addr)
} else {
Self::bind_direct(socket_addr, reuse_port)
}
}
pub fn bind_direct(
socket_addr: SocketAddr,
reuse_port: bool,
) -> std::io::Result<Self> {
let listener = bind_socket_and_listen(socket_addr, reuse_port)?;
Ok(Self {
listener: Some(tokio::net::TcpListener::from_std(listener)?),
conn: None,
})
}
pub fn bind_load_balanced(socket_addr: SocketAddr) -> std::io::Result<Self> {
let tcp = &mut CONNS.get_or_init(Default::default).lock().unwrap().tcp;
if let Some(conn) = tcp.get(&socket_addr) {
let listener = Some(conn.listener()?);
return Ok(Self {
listener,
conn: Some(conn.clone()),
});
}
let conn = Arc::new(TcpConnection::start(socket_addr)?);
let listener = Some(conn.listener()?);
tcp.insert(socket_addr, conn.clone());
Ok(Self {
listener,
conn: Some(conn),
})
}
pub async fn accept(
&self,
) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
let (tcp, addr) = self.listener.as_ref().unwrap().accept().await?;
Ok((tcp, addr))
}
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.listener.as_ref().unwrap().local_addr()
}
}
impl Drop for TcpListener {
    /// Releases this listener's share of a load-balanced connection and removes the
    /// connection from `CONNS` once no other listener uses it.
    ///
    /// Listeners created by `bind_direct` carry no shared connection and need no
    /// bookkeeping here.
    fn drop(&mut self) {
        let Some(conn) = self.conn.take() else {
            return;
        };

        // `conn` is only ever set by `bind_load_balanced`, which initializes `CONNS`
        // first. Bail out quietly rather than panic inside `drop` if that invariant
        // is ever broken.
        let Some(registry) = CONNS.get() else {
            return;
        };

        // Recover from a poisoned lock: panicking in `drop` while another panic is
        // unwinding would abort the process, and the map itself is never left in a
        // partially updated state.
        let mut registry = registry
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // A count of 2 means the only remaining owners are the registry entry and
        // this listener, so the shared socket is no longer needed.
        if Arc::strong_count(&conn) == 2 {
            registry.tcp.remove(&conn.key);
            debug_assert_eq!(Arc::strong_count(&conn), 1);
        }

        // Always give up our reference while the lock is still held. Releasing it
        // after unlocking lets two listeners dropped concurrently both observe a
        // count above 2, so neither removes the entry and the socket stays bound.
        drop(conn);
    }
}
/// Creates a non-blocking TCP socket, binds it to `socket_addr`, and starts listening
/// with a backlog of 128.
///
/// The socket's address family follows `socket_addr`. On non-Windows targets address
/// reuse is always enabled, and port reuse is additionally enabled when `reuse_port`
/// is set and `REUSE_PORT_LOAD_BALANCES` is true.
///
/// # Errors
///
/// Propagates the first I/O error raised while creating, configuring, binding, or
/// listening on the socket.
// `reuse_port` is only read inside the `cfg(not(windows))` section, so it is unused
// when compiling for Windows.
#[allow(unused_variables)]
fn bind_socket_and_listen(
    socket_addr: SocketAddr,
    reuse_port: bool,
) -> Result<std::net::TcpListener, std::io::Error> {
    let domain = match socket_addr {
        SocketAddr::V4(_) => Domain::IPV4,
        SocketAddr::V6(_) => Domain::IPV6,
    };
    let socket = socket2::Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;

    #[cfg(not(windows))]
    {
        if REUSE_PORT_LOAD_BALANCES && reuse_port {
            socket.set_reuse_port(true)?;
        }
        socket.set_reuse_address(true)?;
    }

    socket.set_nonblocking(true)?;
    socket.bind(&socket_addr.into())?;
    socket.listen(128)?;

    Ok(socket.into())
}