#![allow(non_camel_case_types)]
#[cfg(unix)]
use libc::socklen_t;
#[cfg(target_os = "linux")]
use libc::{c_int, c_ulonglong, c_void};
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use std::io::{self, ErrorKind};
use std::mem;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::time::Duration;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::net::{TcpSocket, TcpStream};
use crate::connectors::l4::BindTo;
/// Rust mirror of the kernel's `struct tcp_info`, filled in by [`get_tcp_info`] via
/// `getsockopt(IPPROTO_TCP, TCP_INFO)`.
///
/// This is `#[repr(C)]` and read directly from kernel memory, so the field order and
/// widths must match the C layout exactly; do not reorder, remove or resize fields.
/// On non-Linux platforms [`get_tcp_info`] returns an all-zero value.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct TCP_INFO {
pub tcpi_state: u8,
pub tcpi_ca_state: u8,
pub tcpi_retransmits: u8,
pub tcpi_probes: u8,
pub tcpi_backoff: u8,
pub tcpi_options: u8,
// Two 4-bit values packed into one byte (send / receive window scale, per the name).
// NOTE(review): confirm nibble order against <linux/tcp.h> before decoding.
pub tcpi_snd_wscale_4_rcv_wscale_4: u8,
// NOTE(review): in the C header this byte is a bitfield; presumably only the low bit is
// the app-limited flag — confirm before using it as a plain integer.
pub tcpi_delivery_rate_app_limited: u8,
pub tcpi_rto: u32,
pub tcpi_ato: u32,
pub tcpi_snd_mss: u32,
pub tcpi_rcv_mss: u32,
pub tcpi_unacked: u32,
pub tcpi_sacked: u32,
pub tcpi_lost: u32,
pub tcpi_retrans: u32,
pub tcpi_fackets: u32,
pub tcpi_last_data_sent: u32,
pub tcpi_last_ack_sent: u32,
pub tcpi_last_data_recv: u32,
pub tcpi_last_ack_recv: u32,
pub tcpi_pmtu: u32,
pub tcpi_rcv_ssthresh: u32,
pub tcpi_rtt: u32,
pub tcpi_rttvar: u32,
pub tcpi_snd_ssthresh: u32,
pub tcpi_snd_cwnd: u32,
pub tcpi_advmss: u32,
pub tcpi_reordering: u32,
pub tcpi_rcv_rtt: u32,
pub tcpi_rcv_space: u32,
pub tcpi_total_retrans: u32,
pub tcpi_pacing_rate: u64,
pub tcpi_max_pacing_rate: u64,
pub tcpi_bytes_acked: u64,
pub tcpi_bytes_received: u64,
pub tcpi_segs_out: u32,
pub tcpi_segs_in: u32,
pub tcpi_notsent_bytes: u32,
pub tcpi_min_rtt: u32,
pub tcpi_data_segs_in: u32,
pub tcpi_data_segs_out: u32,
pub tcpi_delivery_rate: u64,
pub tcpi_busy_time: u64,
pub tcpi_rwnd_limited: u64,
pub tcpi_sndbuf_limited: u64,
pub tcpi_delivered: u32,
pub tcpi_delivered_ce: u32,
pub tcpi_bytes_sent: u64,
pub tcpi_bytes_retrans: u64,
pub tcpi_dsack_dups: u32,
pub tcpi_reord_seen: u32,
pub tcpi_rcv_ooopack: u32,
pub tcpi_snd_wnd: u32,
pub tcpi_rcv_wnd: u32,
}
impl TCP_INFO {
    /// Returns a `TCP_INFO` with every field set to zero.
    ///
    /// # Safety
    ///
    /// Every field of `TCP_INFO` is a plain integer, and the all-zero bit pattern is valid
    /// for integers. The caller therefore has no additional invariant to uphold.
    pub unsafe fn new() -> Self {
        mem::zeroed::<TCP_INFO>()
    }

    /// Size of `TCP_INFO` in bytes, typed for the `optlen` argument of `getsockopt`.
    #[cfg(unix)]
    pub fn len() -> socklen_t {
        let bytes = mem::size_of::<TCP_INFO>();
        bytes as socklen_t
    }

    /// Size of `TCP_INFO` in bytes.
    #[cfg(windows)]
    pub fn len() -> usize {
        let bytes = mem::size_of::<TCP_INFO>();
        bytes
    }
}
/// Sets socket option `val` at protocol level `opt` on `sock` to the bytes of `payload`.
///
/// `payload` is handed to `setsockopt(2)` by pointer, together with `size_of::<T>()` as
/// the option length.
#[cfg(target_os = "linux")]
fn set_opt<T: Copy>(sock: c_int, opt: c_int, val: c_int, payload: T) -> io::Result<()> {
    let value_ptr = (&payload as *const T).cast::<c_void>();
    let value_len = mem::size_of::<T>() as socklen_t;
    // SAFETY: `value_ptr` points at the local `payload`, which outlives this call and is
    // exactly `value_len` bytes long.
    let rc = unsafe { libc::setsockopt(sock, opt, val, value_ptr, value_len) };
    cvt_linux_error(rc).map(|_| ())
}
/// Reads socket option `val` at protocol level `opt` from `sock` into `payload`.
///
/// `size` is in/out, as in `getsockopt(2)`. On entry it is the capacity of `payload` in
/// bytes, and on return it is the number of bytes the kernel wrote.
#[cfg(target_os = "linux")]
fn get_opt<T>(
    sock: c_int,
    opt: c_int,
    val: c_int,
    payload: &mut T,
    size: &mut socklen_t,
) -> io::Result<()> {
    let out_ptr = (payload as *mut T).cast::<c_void>();
    // SAFETY: `out_ptr` comes from a live `&mut T`. The caller sets `*size` to a length
    // that does not exceed `size_of::<T>()`, so the kernel cannot write past `payload`.
    let rc = unsafe { libc::getsockopt(sock, opt, val, out_ptr, size) };
    cvt_linux_error(rc).map(|_| ())
}
/// Reads socket option `val` at level `opt` from `sock` as a value of type `T`.
///
/// # Errors
/// Returns the OS error from `getsockopt`, or an `Other` error ("get_opt size mismatch")
/// if the kernel wrote a different number of bytes than `size_of::<T>()`.
#[cfg(target_os = "linux")]
fn get_opt_sized<T>(sock: c_int, opt: c_int, val: c_int) -> io::Result<T> {
    let want = mem::size_of::<T>() as socklen_t;
    let mut got = want;
    let mut slot = mem::MaybeUninit::<T>::zeroed();
    get_opt(sock, opt, val, &mut slot, &mut got)?;
    if got == want {
        // SAFETY: `slot` started zeroed and the kernel filled all `size_of::<T>()` bytes.
        // Callers in this file use only integer / C-struct types (c_int, c_ulonglong,
        // sockaddr_in, sockaddr_in6), for which those bytes form a valid value.
        Ok(unsafe { slot.assume_init() })
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "get_opt size mismatch",
        ))
    }
}
/// Converts a libc-style return code into an `io::Result`.
///
/// `-1` becomes the current `errno` (via `io::Error::last_os_error`). Any other value is
/// passed through unchanged as `Ok`.
#[cfg(target_os = "linux")]
fn cvt_linux_error(t: i32) -> io::Result<i32> {
    match t {
        -1 => Err(io::Error::last_os_error()),
        rc => Ok(rc),
    }
}
/// Turns `IP_BIND_ADDRESS_NO_PORT` on or off for `fd`.
#[cfg(target_os = "linux")]
fn ip_bind_addr_no_port(fd: RawFd, val: bool) -> io::Result<()> {
    let flag = c_int::from(val);
    set_opt(fd, libc::IPPROTO_IP, libc::IP_BIND_ADDRESS_NO_PORT, flag)
}

/// No-op on Unix platforms other than Linux; always succeeds.
#[cfg(all(unix, not(target_os = "linux")))]
fn ip_bind_addr_no_port(_fd: RawFd, _val: bool) -> io::Result<()> {
    Ok(())
}
/// Restricts the ephemeral source ports the kernel may choose for `fd` to `low..=high`.
///
/// An `ENOPROTOOPT` failure is ignored and reported as success. Presumably this is a
/// kernel that does not support the option.
#[cfg(target_os = "linux")]
fn ip_local_port_range(fd: RawFd, low: u16, high: u16) -> io::Result<()> {
    // Defined locally rather than taken from `libc`; matches IP_LOCAL_PORT_RANGE in
    // <linux/in.h>.
    const IP_LOCAL_PORT_RANGE: i32 = 51;
    // Lower bound in the low 16 bits, upper bound in the high 16 bits.
    let packed = u32::from(low) | (u32::from(high) << 16);
    match set_opt(fd, libc::IPPROTO_IP, IP_LOCAL_PORT_RANGE, packed as c_int) {
        Err(e) if e.raw_os_error() == Some(libc::ENOPROTOOPT) => Ok(()),
        other => other,
    }
}

/// No-op on Unix platforms other than Linux; always succeeds.
#[cfg(all(unix, not(target_os = "linux")))]
fn ip_local_port_range(_fd: RawFd, _low: u16, _high: u16) -> io::Result<()> {
    Ok(())
}

/// No-op on Windows; always succeeds.
#[cfg(windows)]
fn ip_local_port_range(_fd: RawSocket, _low: u16, _high: u16) -> io::Result<()> {
    Ok(())
}
/// Converts a non-negative option value to `c_int`, saturating at `c_int::MAX`.
///
/// A bare `as` cast keeps only the low 32 bits, so an oversized value could wrap into a
/// small, valid-looking number and be silently accepted. Saturating keeps it far above
/// the limits Linux accepts for the keepalive options. The kernel's own range check then
/// rejects it with an explicit error.
#[cfg(target_os = "linux")]
fn saturating_c_int(value: u64) -> c_int {
    value.min(c_int::MAX as u64) as c_int
}

/// Enables (`true`) or disables (`false`) `SO_KEEPALIVE` on `fd`.
#[cfg(target_os = "linux")]
fn set_so_keepalive(fd: RawFd, val: bool) -> io::Result<()> {
    set_opt(fd, libc::SOL_SOCKET, libc::SO_KEEPALIVE, val as c_int)
}

/// Sets `TCP_KEEPIDLE` on `fd` to `val`, in whole seconds.
///
/// Sub-second precision is dropped, so a duration under one second becomes `0`.
/// Durations beyond `c_int::MAX` seconds saturate instead of wrapping.
#[cfg(target_os = "linux")]
fn set_so_keepalive_idle(fd: RawFd, val: Duration) -> io::Result<()> {
    set_opt(
        fd,
        libc::IPPROTO_TCP,
        libc::TCP_KEEPIDLE,
        saturating_c_int(val.as_secs()),
    )
}

/// Sets `TCP_KEEPINTVL` on `fd` to `val`, in whole seconds.
///
/// Truncation and saturation behave as in `set_so_keepalive_idle`.
#[cfg(target_os = "linux")]
fn set_so_keepalive_interval(fd: RawFd, val: Duration) -> io::Result<()> {
    set_opt(
        fd,
        libc::IPPROTO_TCP,
        libc::TCP_KEEPINTVL,
        saturating_c_int(val.as_secs()),
    )
}

/// Sets `TCP_KEEPCNT` on `fd` to `val`, saturating at `c_int::MAX` instead of wrapping.
#[cfg(target_os = "linux")]
fn set_so_keepalive_count(fd: RawFd, val: usize) -> io::Result<()> {
    // `usize` is at most 64 bits on supported targets, so widening to u64 is lossless.
    set_opt(
        fd,
        libc::IPPROTO_TCP,
        libc::TCP_KEEPCNT,
        saturating_c_int(val as u64),
    )
}
#[cfg(target_os = "linux")]
fn set_keepalive(fd: RawFd, ka: &TcpKeepalive) -> io::Result<()> {
set_so_keepalive(fd, true)?;
set_so_keepalive_idle(fd, ka.idle)?;
set_so_keepalive_interval(fd, ka.interval)?;
set_so_keepalive_count(fd, ka.count)
}
#[cfg(all(unix, not(target_os = "linux")))]
fn set_keepalive(_fd: RawFd, _ka: &TcpKeepalive) -> io::Result<()> {
Ok(())
}
#[cfg(windows)]
fn set_keepalive(_sock: RawSocket, _ka: &TcpKeepalive) -> io::Result<()> {
Ok(())
}
#[cfg(target_os = "linux")]
pub fn get_tcp_info(fd: RawFd) -> io::Result<TCP_INFO> {
get_opt_sized(fd, libc::IPPROTO_TCP, libc::TCP_INFO)
}
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_tcp_info(_fd: RawFd) -> io::Result<TCP_INFO> {
Ok(unsafe { TCP_INFO::new() })
}
#[cfg(windows)]
pub fn get_tcp_info(_fd: RawSocket) -> io::Result<TCP_INFO> {
Ok(unsafe { TCP_INFO::new() })
}
/// Requests a receive buffer of `val` bytes (`SO_RCVBUF`) on `fd`.
///
/// Linux stores double the requested value; see the accompanying test.
///
/// # Errors
/// `ConnectError` ("failed to set SO_RCVBUF") if `setsockopt` fails.
#[cfg(target_os = "linux")]
pub fn set_recv_buf(fd: RawFd, val: usize) -> Result<()> {
    let requested = val as c_int;
    let outcome = set_opt(fd, libc::SOL_SOCKET, libc::SO_RCVBUF, requested);
    outcome.or_err(ConnectError, "failed to set SO_RCVBUF")
}

/// No-op on Unix platforms other than Linux; always succeeds.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_recv_buf(_fd: RawFd, _: usize) -> Result<()> {
    Ok(())
}

/// No-op on Windows; always succeeds.
#[cfg(windows)]
pub fn set_recv_buf(_sock: RawSocket, _: usize) -> Result<()> {
    Ok(())
}
/// Returns the current `SO_RCVBUF` value of `fd`, in bytes.
#[cfg(target_os = "linux")]
pub fn get_recv_buf(fd: RawFd) -> io::Result<usize> {
    let bytes: c_int = get_opt_sized(fd, libc::SOL_SOCKET, libc::SO_RCVBUF)?;
    Ok(bytes as usize)
}

/// Always returns `0` on Unix platforms other than Linux.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_recv_buf(_fd: RawFd) -> io::Result<usize> {
    Ok(0)
}

/// Always returns `0` on Windows.
#[cfg(windows)]
pub fn get_recv_buf(_sock: RawSocket) -> io::Result<usize> {
    Ok(0)
}
/// Enables `TCP_FASTOPEN_CONNECT` on the (not yet connected) socket `fd`.
///
/// # Errors
/// `ConnectError` ("failed to set TCP_FASTOPEN_CONNECT") if `setsockopt` fails.
#[cfg(target_os = "linux")]
pub fn set_tcp_fastopen_connect(fd: RawFd) -> Result<()> {
    const ENABLE: c_int = 1;
    set_opt(fd, libc::IPPROTO_TCP, libc::TCP_FASTOPEN_CONNECT, ENABLE)
        .or_err(ConnectError, "failed to set TCP_FASTOPEN_CONNECT")
}

/// No-op on Unix platforms other than Linux; always succeeds.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_tcp_fastopen_connect(_fd: RawFd) -> Result<()> {
    Ok(())
}

/// No-op on Windows; always succeeds.
#[cfg(windows)]
pub fn set_tcp_fastopen_connect(_sock: RawSocket) -> Result<()> {
    Ok(())
}
/// Sets the `TCP_FASTOPEN` option on `fd` to `backlog`.
///
/// # Errors
/// `ConnectError` ("failed to set TCP_FASTOPEN") if `setsockopt` fails.
#[cfg(target_os = "linux")]
pub fn set_tcp_fastopen_backlog(fd: RawFd, backlog: usize) -> Result<()> {
    let queue_len = backlog as c_int;
    let outcome = set_opt(fd, libc::IPPROTO_TCP, libc::TCP_FASTOPEN, queue_len);
    outcome.or_err(ConnectError, "failed to set TCP_FASTOPEN")
}

/// No-op on Unix platforms other than Linux; always succeeds.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_tcp_fastopen_backlog(_fd: RawFd, _backlog: usize) -> Result<()> {
    Ok(())
}

/// No-op on Windows; always succeeds.
#[cfg(windows)]
pub fn set_tcp_fastopen_backlog(_sock: RawSocket, _backlog: usize) -> Result<()> {
    Ok(())
}
/// Writes `value` into the IP traffic-class byte of `fd`. It uses `IPV6_TCLASS` for IPv6
/// sockets and `IP_TOS` otherwise.
///
/// NOTE(review): `value` is written verbatim. DSCP normally occupies the upper six bits
/// of that byte, so confirm whether callers pass an already-shifted value.
///
/// # Errors
/// `SocketError` if `fd` is not an IP socket or if `setsockopt` fails.
#[cfg(target_os = "linux")]
pub fn set_dscp(fd: RawFd, value: u8) -> Result<()> {
    use super::socket::SocketAddr;
    use pingora_error::OkOrErr;

    let local = SocketAddr::from_raw_fd(fd, false);
    let inet = local
        .as_ref()
        .and_then(|s| s.as_inet())
        .or_err(SocketError, "failed to set dscp, invalid IP socket")?;

    let (level, name, context) = if inet.is_ipv6() {
        (
            libc::IPPROTO_IPV6,
            libc::IPV6_TCLASS,
            "failed to set dscp (IPV6_TCLASS)",
        )
    } else {
        (libc::IPPROTO_IP, libc::IP_TOS, "failed to set dscp (IP_TOS)")
    };
    set_opt(fd, level, name, c_int::from(value)).or_err(SocketError, context)
}

/// No-op on Unix platforms other than Linux; always succeeds.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_dscp(_fd: RawFd, _value: u8) -> Result<()> {
    Ok(())
}

/// No-op on Windows; always succeeds.
#[cfg(windows)]
pub fn set_dscp(_sock: RawSocket, _value: u8) -> Result<()> {
    Ok(())
}
/// Returns the kernel's `SO_COOKIE` value for `fd`.
#[cfg(target_os = "linux")]
pub fn get_socket_cookie(fd: RawFd) -> io::Result<u64> {
    let cookie: c_ulonglong = get_opt_sized(fd, libc::SOL_SOCKET, libc::SO_COOKIE)?;
    Ok(cookie)
}

/// Always returns `0` on Unix platforms other than Linux.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_socket_cookie(_fd: RawFd) -> io::Result<u64> {
    Ok(0)
}
#[cfg(target_os = "linux")]
pub fn get_original_dest(fd: RawFd) -> Result<Option<SocketAddr>> {
use super::socket;
use pingora_error::OkOrErr;
use std::net::{SocketAddrV4, SocketAddrV6};
let sock = socket::SocketAddr::from_raw_fd(fd, false);
let addr = sock
.as_ref()
.and_then(|s| s.as_inet())
.or_err(SocketError, "failed get original dest, invalid IP socket")?;
let dest = if addr.is_ipv4() {
get_opt_sized::<libc::sockaddr_in>(fd, libc::SOL_IP, libc::SO_ORIGINAL_DST).map(|addr| {
SocketAddr::V4(SocketAddrV4::new(
u32::from_be(addr.sin_addr.s_addr).into(),
u16::from_be(addr.sin_port),
))
})
} else {
get_opt_sized::<libc::sockaddr_in6>(fd, libc::SOL_IPV6, libc::IP6T_SO_ORIGINAL_DST).map(
|addr| {
SocketAddr::V6(SocketAddrV6::new(
addr.sin6_addr.s6_addr.into(),
u16::from_be(addr.sin6_port),
addr.sin6_flowinfo,
addr.sin6_scope_id,
))
},
)
};
dest.or_err(SocketError, "failed to get original dest")
.map(Some)
}
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_original_dest(_fd: RawFd) -> Result<Option<SocketAddr>> {
Ok(None)
}
#[cfg(windows)]
pub fn get_original_dest(_sock: RawSocket) -> Result<Option<SocketAddr>> {
Ok(None)
}
pub(crate) async fn connect_with<F: FnOnce(&TcpSocket) -> Result<()> + Clone>(
addr: &SocketAddr,
bind_to: Option<&BindTo>,
set_socket: F,
) -> Result<TcpStream> {
if bind_to.as_ref().map_or(false, |b| b.will_fallback()) {
let connect_result = inner_connect_with(addr, bind_to, set_socket.clone()).await;
if let Err(e) = connect_result.as_ref() {
if matches!(e.etype(), BindError) {
let mut new_bind_to = BindTo::default();
new_bind_to.addr = bind_to.as_ref().and_then(|b| b.addr);
new_bind_to.set_port_range(None).unwrap();
return inner_connect_with(addr, Some(&new_bind_to), set_socket).await;
}
}
connect_result
} else {
inner_connect_with(addr, bind_to, set_socket).await
}
}
/// Creates a TCP socket matching `addr`'s address family and applies the bind options
/// from `bind_to`. It then runs `set_socket` and connects.
///
/// # Errors
/// - `SocketError` if the socket cannot be created or configured.
/// - `BindError` if binding to `bind_to.addr` fails.
/// - Whatever `set_socket` returns.
/// - The classification from [`wrap_os_connect_error`] if the connect fails.
async fn inner_connect_with<F: FnOnce(&TcpSocket) -> Result<()>>(
    addr: &SocketAddr,
    bind_to: Option<&BindTo>,
    set_socket: F,
) -> Result<TcpStream> {
    let socket = match addr {
        SocketAddr::V4(_) => TcpSocket::new_v4(),
        SocketAddr::V6(_) => TcpSocket::new_v6(),
    }
    .or_err(SocketError, "failed to create socket")?;

    #[cfg(unix)]
    {
        let fd = socket.as_raw_fd();
        // Defer source-port selection from bind() to connect() (IP_BIND_ADDRESS_NO_PORT).
        ip_bind_addr_no_port(fd, true).or_err(
            SocketError,
            "failed to set socket opts IP_BIND_ADDRESS_NO_PORT",
        )?;
        if let Some((low, high)) = bind_to.and_then(|b| b.port_range()) {
            ip_local_port_range(fd, low, high)
                .or_err(SocketError, "failed to set socket opts IP_LOCAL_PORT_RANGE")?;
        }
    }

    // Bind only after the port-range option (if any) has been set.
    if let Some(local) = bind_to.and_then(|b| b.addr) {
        socket
            .bind(local)
            .or_err_with(BindError, || format!("failed to bind to socket {}", local))?;
    }

    set_socket(&socket)?;
    socket
        .connect(*addr)
        .await
        .map_err(|e| wrap_os_connect_error(e, format!("Fail to connect to {}", addr)))
}
pub async fn connect(addr: &SocketAddr, bind_to: Option<&BindTo>) -> Result<TcpStream> {
connect_with(addr, bind_to, |_| Ok(())).await
}
/// Connects to the Unix domain socket at `path`.
///
/// # Errors
/// A classified connect error (see [`wrap_os_connect_error`]) mentioning `path`.
#[cfg(unix)]
pub async fn connect_uds(path: &std::path::Path) -> Result<UnixStream> {
    let attempt = UnixStream::connect(path).await;
    attempt.map_err(|err| {
        let context = format!("Fail to connect to {}", path.display());
        wrap_os_connect_error(err, context)
    })
}
fn wrap_os_connect_error(e: std::io::Error, context: String) -> Box<Error> {
match e.kind() {
ErrorKind::ConnectionRefused => Error::because(ConnectRefused, context, e),
ErrorKind::TimedOut => Error::because(ConnectTimedout, context, e),
ErrorKind::AddrNotAvailable => Error::because(BindError, context, e),
ErrorKind::PermissionDenied | ErrorKind::AddrInUse => {
Error::because(InternalError, context, e)
}
_ => match e.raw_os_error() {
Some(libc::ENETUNREACH | libc::EHOSTUNREACH) => {
Error::because(ConnectNoRoute, context, e)
}
_ => Error::because(ConnectError, context, e),
},
}
}
/// TCP keepalive settings applied by [`set_tcp_keepalive`].
#[derive(Clone, Debug)]
pub struct TcpKeepalive {
    /// Idle time before keepalive probing starts (`TCP_KEEPIDLE`, whole seconds on Linux).
    pub idle: Duration,
    /// Time between keepalive probes (`TCP_KEEPINTVL`, whole seconds on Linux).
    pub interval: Duration,
    /// Number of probes (`TCP_KEEPCNT` on Linux).
    pub count: usize,
}

impl std::fmt::Display for TcpKeepalive {
    /// Formats as `idle/interval/count`, with the durations in their `Debug` form,
    /// e.g. `5s/1.5s/3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            idle,
            interval,
            count,
        } = self;
        write!(f, "{:?}/{:?}/{}", idle, interval, count)
    }
}
/// Applies the keepalive settings `ka` to `stream`. This is a no-op where the platform
/// helper does nothing.
///
/// # Errors
/// `ConnectError` ("failed to set keepalive") if any underlying option fails.
pub fn set_tcp_keepalive(stream: &TcpStream, ka: &TcpKeepalive) -> Result<()> {
    #[cfg(unix)]
    let handle = stream.as_raw_fd();
    #[cfg(windows)]
    let handle = stream.as_raw_socket();

    let outcome = set_keepalive(handle, ka);
    outcome.or_err(ConnectError, "failed to set keepalive")
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_set_recv_buf() {
        use tokio::net::TcpSocket;

        const REQUESTED: usize = 102400;
        let socket = TcpSocket::new_v4().unwrap();

        #[cfg(unix)]
        set_recv_buf(socket.as_raw_fd(), REQUESTED).unwrap();
        #[cfg(windows)]
        set_recv_buf(socket.as_raw_socket(), REQUESTED).unwrap();

        // Linux reports double the requested SO_RCVBUF (see socket(7)).
        #[cfg(target_os = "linux")]
        assert_eq!(get_recv_buf(socket.as_raw_fd()).unwrap(), REQUESTED * 2);
    }

    /// Needs outbound network access to 1.1.1.1, hence ignored by default.
    #[cfg(target_os = "linux")]
    #[ignore]
    #[tokio::test]
    async fn test_set_fast_open() {
        use std::time::Instant;

        let target: SocketAddr = "1.1.1.1:80".parse().unwrap();

        // The first connection is a warm-up (presumably to obtain a TFO cookie); only
        // the second one is timed.
        connect_with(&target, None, |s| set_tcp_fastopen_connect(s.as_raw_fd()))
            .await
            .unwrap();

        let started = Instant::now();
        connect_with(&target, None, |s| set_tcp_fastopen_connect(s.as_raw_fd()))
            .await
            .unwrap();
        assert!(started.elapsed().as_millis() < 4);
    }
}