pingora_core/protocols/l4/
socket.rsuse crate::{Error, OrErr};
use log::warn;
#[cfg(unix)]
use nix::sys::socket::{getpeername, getsockname, SockaddrStorage};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr as StdSockAddr;
#[cfg(unix)]
use std::os::unix::net::SocketAddr as StdUnixSockAddr;
#[cfg(unix)]
use tokio::net::unix::SocketAddr as TokioUnixSockAddr;
/// A socket address that is either an IP (v4/v6) socket address or, on unix
/// platforms, a Unix domain socket (UDS) address.
#[derive(Debug, Clone)]
pub enum SocketAddr {
    /// An IP socket address (`std::net::SocketAddr`).
    Inet(StdSockAddr),
    /// A Unix domain socket address (`std::os::unix::net::SocketAddr`); only
    /// present when compiling for unix.
    #[cfg(unix)]
    Unix(StdUnixSockAddr),
}
impl SocketAddr {
    /// Borrows the inner IP socket address, or returns `None` for a Unix address.
    pub fn as_inet(&self) -> Option<&StdSockAddr> {
        match self {
            SocketAddr::Inet(inet) => Some(inet),
            #[cfg(unix)]
            SocketAddr::Unix(_) => None,
        }
    }

    /// Borrows the inner Unix domain socket address, or returns `None` for an IP address.
    #[cfg(unix)]
    pub fn as_unix(&self) -> Option<&StdUnixSockAddr> {
        match self {
            SocketAddr::Unix(uds) => Some(uds),
            SocketAddr::Inet(_) => None,
        }
    }

    /// Sets the port of an IP address. Does nothing for a Unix domain socket address.
    pub fn set_port(&mut self, port: u16) {
        match self {
            SocketAddr::Inet(inet) => inet.set_port(port),
            #[cfg(unix)]
            SocketAddr::Unix(_) => {}
        }
    }

    /// Converts a nix `SockaddrStorage` into a [`SocketAddr`].
    ///
    /// IPv4 and IPv6 storages become `Inet`. A Unix storage becomes `Unix` only
    /// when `UnixAddr::path` yields a filesystem path that `from_pathname`
    /// accepts. Every other case returns `None`.
    #[cfg(unix)]
    fn from_sockaddr_storage(sock: &SockaddrStorage) -> Option<SocketAddr> {
        if let Some(v4) = sock.as_sockaddr_in() {
            let std_v4 = std::net::SocketAddrV4::new(v4.ip().into(), v4.port());
            return Some(SocketAddr::Inet(StdSockAddr::V4(std_v4)));
        }
        if let Some(v6) = sock.as_sockaddr_in6() {
            let std_v6 =
                std::net::SocketAddrV6::new(v6.ip(), v6.port(), v6.flowinfo(), v6.scope_id());
            return Some(SocketAddr::Inet(StdSockAddr::V6(std_v6)));
        }
        // Not an IP address: fall back to a path-based Unix address, if any.
        let path = sock.as_unix_addr()?.path()?;
        StdUnixSockAddr::from_pathname(path)
            .ok()
            .map(SocketAddr::Unix)
    }

    /// Looks up an address for the socket `fd`.
    ///
    /// When `peer_addr` is true this queries the remote end (`getpeername`).
    /// Otherwise it queries the local end (`getsockname`). Returns `None` if the
    /// lookup fails or the result cannot be represented as a [`SocketAddr`].
    #[cfg(unix)]
    pub fn from_raw_fd(fd: std::os::unix::io::RawFd, peer_addr: bool) -> Option<SocketAddr> {
        let lookup: Result<SockaddrStorage, _> = if peer_addr {
            getpeername(fd)
        } else {
            getsockname(fd)
        };
        lookup
            .ok()
            .and_then(|storage| Self::from_sockaddr_storage(&storage))
    }

    /// Looks up an address for the Windows socket `sock`.
    ///
    /// When `is_peer_addr` is true this returns the peer address; otherwise it
    /// returns the local address. Returns `None` if the lookup fails.
    #[cfg(windows)]
    pub fn from_raw_socket(
        sock: std::os::windows::io::RawSocket,
        is_peer_addr: bool,
    ) -> Option<SocketAddr> {
        use crate::protocols::windows::{local_addr, peer_addr};
        let lookup = if is_peer_addr {
            peer_addr(sock)
        } else {
            local_addr(sock)
        };
        lookup.ok().map(|found| found.into())
    }
}
/// Formats an IP address in its standard `Display` form. A Unix domain socket
/// address is printed as its filesystem path, or with `Debug` formatting when
/// it has no path.
impl std::fmt::Display for SocketAddr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            SocketAddr::Inet(addr) => write!(f, "{addr}"),
            #[cfg(unix)]
            SocketAddr::Unix(addr) => match addr.as_pathname() {
                Some(path) => write!(f, "{}", path.display()),
                // No path to show (e.g. unnamed socket): fall back to Debug.
                None => write!(f, "{addr:?}"),
            },
        }
    }
}
impl Hash for SocketAddr {
    /// Hashes the IP socket address, or the filesystem path of a Unix address.
    ///
    /// # Panics
    ///
    /// Panics if `self` is a Unix domain socket address with no filesystem path.
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Self::Inet(inet) => inet.hash(state),
            #[cfg(unix)]
            Self::Unix(uds) => match uds.as_pathname() {
                Some(path) => path.hash(state),
                None => panic!("Unnamed and abstract UDS types not yet supported for hashing"),
            },
        }
    }
}
impl PartialEq for SocketAddr {
    /// Two addresses are equal only when they are the same variant and:
    /// - `Inet`: the IP socket addresses are equal;
    /// - `Unix`: both have a filesystem path and the paths are equal.
    ///
    /// A Unix address without a filesystem path never compares equal to
    /// anything, not even to itself.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Inet(lhs), Self::Inet(rhs)) => lhs == rhs,
            #[cfg(unix)]
            (Self::Unix(lhs), Self::Unix(rhs)) => match (lhs.as_pathname(), rhs.as_pathname()) {
                (Some(lhs_path), Some(rhs_path)) => lhs_path == rhs_path,
                // Path-less UDS addresses cannot be compared; treat them as unequal.
                _ => false,
            },
            // Different variants are never equal.
            #[cfg(unix)]
            _ => false,
        }
    }
}
impl PartialOrd for SocketAddr {
    /// Always `Some`; delegates to this type's total order (`Ord::cmp`).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for SocketAddr {
    /// Orders addresses as follows:
    /// - every `Inet` address sorts before every `Unix` address;
    /// - `Inet` addresses use `std::net::SocketAddr` ordering;
    /// - `Unix` addresses compare by their optional filesystem path, so
    ///   path-less addresses sort first.
    ///
    /// NOTE(review): two path-less Unix addresses compare `Equal` here, yet
    /// `PartialEq` for this type reports them unequal. `Ord` requires
    /// `a == b` iff `cmp(a, b) == Equal`, so avoid using path-less Unix
    /// addresses as keys in ordered collections.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self, other) {
            (Self::Inet(lhs), Self::Inet(rhs)) => lhs.cmp(rhs),
            #[cfg(unix)]
            (Self::Unix(lhs), Self::Unix(rhs)) => lhs.as_pathname().cmp(&rhs.as_pathname()),
            #[cfg(unix)]
            (Self::Inet(_), Self::Unix(_)) => Ordering::Less,
            #[cfg(unix)]
            (Self::Unix(_), Self::Inet(_)) => Ordering::Greater,
        }
    }
}

// Marker impl: equality is the hand-written `PartialEq` for this type.
impl Eq for SocketAddr {}
impl std::str::FromStr for SocketAddr {
    type Err = Box<Error>;

    /// Parses `s` into a [`SocketAddr`].
    ///
    /// Accepted forms:
    /// - `unix:<path>` becomes a Unix domain socket address for `<path>`.
    ///   Exactly one `unix:` prefix is removed.
    /// - anything `std::net::SocketAddr` parses (e.g. `127.0.0.1:80`) becomes
    ///   an `Inet` address.
    /// - any other string is treated as a raw Unix domain socket path. This
    ///   form logs a deprecation warning.
    ///
    /// # Errors
    ///
    /// Returns a `BindError` ("invalid UDS path") when
    /// `std::os::unix::net::SocketAddr::from_pathname` rejects the path.
    #[cfg(unix)]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `strip_prefix` removes exactly one leading "unix:". By contrast,
        // `trim_start_matches` strips repeats, turning "unix:unix:/a" into "/a".
        if let Some(path) = s.strip_prefix("unix:") {
            let uds_socket = StdUnixSockAddr::from_pathname(path)
                .or_err(crate::BindError, "invalid UDS path")?;
            return Ok(SocketAddr::Unix(uds_socket));
        }
        // `str::parse` is inherent, so no `FromStr` import is needed here.
        match s.parse::<StdSockAddr>() {
            Ok(addr) => Ok(SocketAddr::Inet(addr)),
            Err(_) => {
                // Legacy form: a bare path without the "unix:" prefix.
                let uds_socket = StdUnixSockAddr::from_pathname(s)
                    .or_err(crate::BindError, "invalid UDS path")?;
                warn!("Raw Unix domain socket path support will be deprecated, add 'unix:' prefix instead");
                Ok(SocketAddr::Unix(uds_socket))
            }
        }
    }

    /// Parses `s` as an IP socket address (`ip:port`).
    ///
    /// # Errors
    ///
    /// Returns a `BindError` ("invalid socket addr") when `s` is not a valid
    /// IP socket address.
    #[cfg(windows)]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let addr = s
            .parse::<StdSockAddr>()
            .or_err(crate::BindError, "invalid socket addr")?;
        Ok(SocketAddr::Inet(addr))
    }
}
impl std::net::ToSocketAddrs for SocketAddr {
    type Iter = std::iter::Once<StdSockAddr>;

    /// Yields the single IP socket address.
    ///
    /// # Errors
    ///
    /// A Unix domain socket address is rejected with an `ErrorKind::Other`
    /// I/O error.
    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
        self.as_inet()
            .copied()
            .map(std::iter::once)
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "UDS socket cannot be used as inet socket",
                )
            })
    }
}
impl From<StdSockAddr> for SocketAddr {
fn from(sockaddr: StdSockAddr) -> Self {
SocketAddr::Inet(sockaddr)
}
}
#[cfg(unix)]
impl From<StdUnixSockAddr> for SocketAddr {
fn from(sockaddr: StdUnixSockAddr) -> Self {
SocketAddr::Unix(sockaddr)
}
}
#[cfg(unix)]
impl TryFrom<TokioUnixSockAddr> for SocketAddr {
type Error = String;
fn try_from(value: TokioUnixSockAddr) -> Result<Self, Self::Error> {
if let Some(Ok(addr)) = value.as_pathname().map(StdUnixSockAddr::from_pathname) {
Ok(addr.into())
} else {
Err(format!("could not convert {value:?} to SocketAddr"))
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_ip() {
        let ip: SocketAddr = "127.0.0.1:80".parse().unwrap();
        assert!(ip.as_inet().is_some());
    }

    #[cfg(unix)]
    #[test]
    fn parse_uds() {
        let uds: SocketAddr = "/tmp/my.sock".parse().unwrap();
        assert!(uds.as_unix().is_some());
    }

    #[cfg(unix)]
    #[test]
    fn parse_uds_with_prefix() {
        let uds: SocketAddr = "unix:/tmp/my.sock".parse().unwrap();
        assert!(uds.as_unix().is_some());
    }

    // Display: inet uses the std `ip:port` form.
    #[test]
    fn display_inet() {
        let ip: SocketAddr = "127.0.0.1:80".parse().unwrap();
        assert_eq!(ip.to_string(), "127.0.0.1:80");
    }

    // Display: a path-based UDS prints just its path (no "unix:" prefix).
    #[cfg(unix)]
    #[test]
    fn display_uds() {
        let uds: SocketAddr = "unix:/tmp/my.sock".parse().unwrap();
        assert_eq!(uds.to_string(), "/tmp/my.sock");
    }

    #[test]
    fn set_port_updates_inet() {
        let mut ip: SocketAddr = "127.0.0.1:80".parse().unwrap();
        ip.set_port(8080);
        assert_eq!(ip.to_string(), "127.0.0.1:8080");
    }

    // Prefixed and raw forms of the same path compare equal; inet sorts before UDS.
    #[cfg(unix)]
    #[test]
    fn uds_equality_and_ordering() {
        let prefixed: SocketAddr = "unix:/tmp/my.sock".parse().unwrap();
        let raw: SocketAddr = "/tmp/my.sock".parse().unwrap();
        let ip: SocketAddr = "127.0.0.1:80".parse().unwrap();
        assert_eq!(prefixed, raw);
        assert_ne!(prefixed, ip);
        assert!(ip < prefixed);
    }

    #[test]
    fn to_socket_addrs_inet() {
        use std::net::ToSocketAddrs;
        let ip: SocketAddr = "127.0.0.1:80".parse().unwrap();
        let resolved: Vec<_> = ip.to_socket_addrs().unwrap().collect();
        assert_eq!(resolved, vec![*ip.as_inet().unwrap()]);
    }

    #[cfg(unix)]
    #[test]
    fn to_socket_addrs_rejects_uds() {
        use std::net::ToSocketAddrs;
        let uds: SocketAddr = "unix:/tmp/my.sock".parse().unwrap();
        assert!(uds.to_socket_addrs().is_err());
    }
}