use async_trait::async_trait;
use log::debug;
use pingora_error::{Context, Error, ErrorType::*, OrErr, Result};
use rand::seq::SliceRandom;
use std::net::SocketAddr as InetSocketAddr;
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
#[cfg(unix)]
use crate::protocols::l4::ext::connect_uds;
use crate::protocols::l4::ext::{
connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect,
};
use crate::protocols::l4::socket::SocketAddr;
use crate::protocols::l4::stream::Stream;
use crate::protocols::{GetSocketDigest, SocketDigest};
use crate::upstreams::peer::Peer;
/// A custom L4 connector that can replace the built-in TCP / Unix-socket dialing.
///
/// When a peer's options set `custom_l4`, this module's `connect` function calls this trait
/// instead of dialing the peer address itself. The returned [`Stream`] still receives the
/// tracer, keepalive, nodelay and socket-digest setup that `connect` applies to every new
/// stream.
#[async_trait]
pub trait Connect: std::fmt::Debug {
    /// Opens a connection to `addr` and returns it as a [`Stream`].
    async fn connect(&self, addr: &SocketAddr) -> Result<Stream>;
}
#[derive(Clone, Debug, Default)]
pub struct BindTo {
pub addr: Option<InetSocketAddr>,
port_range: Option<(u16, u16)>,
fallback: bool,
}
impl BindTo {
pub fn set_port_range(&mut self, range: Option<(u16, u16)>) -> Result<()> {
if range.is_none() && self.port_range.is_none() {
return Ok(());
}
match range {
None | Some((0, 0)) => self.port_range = Some((0, 0)),
Some((low, high)) if low > 0 && low < high => {
self.port_range = Some((low, high));
}
_ => return Error::e_explain(SocketError, "invalid port range: {range}"),
}
Ok(())
}
pub fn set_fallback(&mut self, fallback: bool) {
self.fallback = fallback
}
pub fn port_range(&self) -> Option<(u16, u16)> {
self.port_range
}
pub fn will_fallback(&self) -> bool {
self.fallback && self.port_range.is_some()
}
}
pub(crate) async fn connect<P>(peer: &P, bind_to: Option<BindTo>) -> Result<Stream>
where
P: Peer + Send + Sync,
{
if peer.get_proxy().is_some() {
return proxy_connect(peer)
.await
.err_context(|| format!("Fail to establish CONNECT proxy: {}", peer));
}
let peer_addr = peer.address();
let mut stream: Stream =
if let Some(custom_l4) = peer.get_peer_options().and_then(|o| o.custom_l4.as_ref()) {
custom_l4.connect(peer_addr).await?
} else {
match peer_addr {
SocketAddr::Inet(addr) => {
let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| {
#[cfg(unix)]
let raw = socket.as_raw_fd();
#[cfg(windows)]
let raw = socket.as_raw_socket();
if peer.tcp_fast_open() {
set_tcp_fastopen_connect(raw)?;
}
if let Some(recv_buf) = peer.tcp_recv_buf() {
debug!("Setting recv buf size");
set_recv_buf(raw, recv_buf)?;
}
if let Some(dscp) = peer.dscp() {
debug!("Setting dscp");
set_dscp(raw, dscp)?;
}
Ok(())
});
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
}
}
}
}
#[cfg(unix)]
SocketAddr::Unix(addr) => {
let connect_future = connect_uds(
addr.as_pathname()
.expect("non-pathname unix sockets not supported as peer"),
);
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
}
}
}
}
}?
};
let tracer = peer.get_tracer();
if let Some(t) = tracer {
t.0.on_connected();
stream.tracer = Some(t);
}
if let Some(ka) = peer.tcp_keepalive() {
stream.set_keepalive(ka)?;
}
stream.set_nodelay()?;
#[cfg(unix)]
let digest = SocketDigest::from_raw_fd(stream.as_raw_fd());
#[cfg(windows)]
let digest = SocketDigest::from_raw_socket(stream.as_raw_socket());
digest
.peer_addr
.set(Some(peer_addr.clone()))
.expect("newly created OnceCell must be empty");
stream.set_socket_digest(digest);
Ok(stream)
}
pub(crate) fn bind_to_random<P: Peer>(
peer: &P,
v4_list: &[InetSocketAddr],
v6_list: &[InetSocketAddr],
) -> Option<BindTo> {
fn bind_to_ips(ips: &[InetSocketAddr]) -> Option<InetSocketAddr> {
match ips.len() {
0 => None,
1 => Some(ips[0]),
_ => {
ips.choose(&mut rand::thread_rng()).copied()
}
}
}
let mut bind_to = peer.get_peer_options().and_then(|o| o.bind_to.clone());
if bind_to.as_ref().map(|b| b.addr).is_some() {
return bind_to;
}
let addr = match peer.address() {
SocketAddr::Inet(sockaddr) => match sockaddr {
InetSocketAddr::V4(_) => bind_to_ips(v4_list),
InetSocketAddr::V6(_) => bind_to_ips(v6_list),
},
#[cfg(unix)]
SocketAddr::Unix(_) => None,
};
if addr.is_some() {
if let Some(bind_to) = bind_to.as_mut() {
bind_to.addr = addr;
} else {
bind_to = Some(BindTo {
addr,
..Default::default()
});
}
}
bind_to
}
use crate::protocols::raw_connect;
#[cfg(unix)]
/// Connects to `peer` by tunnelling through its configured HTTP CONNECT proxy.
///
/// The proxy is reached over the Unix domain socket at `proxy.next_hop`. A CONNECT request for
/// `proxy.host:proxy.port` is then sent. It carries the proxy's own headers followed by any
/// `extra_proxy_headers` from the peer options. The CONNECT exchange is bounded by the peer's
/// connection timeout, if one is set. Errors from that exchange have `decide_reuse(false)`
/// applied to their retry decision.
///
/// # Panics
///
/// Panics if `peer` has no proxy; callers must check `peer.get_proxy()` first.
async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> {
    let proxy = peer
        .get_proxy()
        .expect("proxy_connect() requires a peer with a proxy configured");
    // Proxy headers first, then per-peer extras. A peer without options contributes no extras
    // instead of panicking.
    let mut headers = proxy.headers.iter().chain(
        peer.get_peer_options()
            .into_iter()
            .flat_map(|o| o.extra_proxy_headers.iter()),
    );

    let stream: Box<Stream> = Box::new(
        connect_uds(&proxy.next_hop)
            .await
            .or_err_with(ConnectError, || {
                format!("CONNECT proxy connect() error to {:?}", &proxy.next_hop)
            })?
            .into(),
    );

    let req_header = raw_connect::generate_connect_header(&proxy.host, proxy.port, &mut headers)?;
    let fut = raw_connect::connect(stream, &req_header);
    let (mut stream, digest) = match peer.connection_timeout() {
        Some(t) => pingora_timeout::timeout(t, fut)
            .await
            .explain_err(ConnectTimedout, |_| "establishing CONNECT proxy")?,
        None => fut.await,
    }
    .map_err(|mut e| {
        e.retry.decide_reuse(false);
        e
    })?;
    debug!("CONNECT proxy established: {:?}", proxy);
    stream.set_proxy_digest(digest);

    // The tunnel was built from a boxed l4 `Stream` above. The downcast assumes
    // `raw_connect::connect` hands that same stream back — TODO confirm if that changes.
    let stream = stream
        .into_any()
        .downcast::<Stream>()
        .expect("CONNECT proxy stream must wrap the l4 Stream it was given");
    Ok(*stream)
}
#[cfg(windows)]
/// CONNECT proxies are reached over Unix domain sockets, which this platform build does not
/// support. Returns an error instead of panicking, so a misconfigured peer fails only its own
/// connection attempt.
async fn proxy_connect<P: Peer>(_peer: &P) -> Result<Stream> {
    Error::e_explain(InternalError, "peer proxy not supported on windows")
}
#[cfg(test)]
mod tests {
    use super::*;
    // The proxy/UDS helpers below are unix-only; gate their imports so non-unix builds of the
    // test module do not warn about unused imports.
    #[cfg(unix)]
    use crate::upstreams::peer::Proxy;
    use crate::upstreams::peer::{BasicPeer, HttpPeer};
    #[cfg(unix)]
    use std::collections::BTreeMap;
    #[cfg(unix)]
    use std::path::PathBuf;
    #[cfg(unix)]
    use tokio::io::AsyncWriteExt;
    #[cfg(unix)]
    use tokio::net::UnixListener;

    #[tokio::test]
    async fn test_conn_error_refused() {
        let peer = BasicPeer::new("127.0.0.1:79"); // hopefully port 79 is not used
        let new_session = connect(&peer, None).await;
        assert_eq!(new_session.unwrap_err().etype(), &ConnectRefused)
    }

    // TODO broken on arm64
    #[ignore]
    #[tokio::test]
    async fn test_conn_error_no_route() {
        let peer = BasicPeer::new("[::3]:79"); // no route
        let new_session = connect(&peer, None).await;
        assert_eq!(new_session.unwrap_err().etype(), &ConnectNoRoute)
    }

    #[tokio::test]
    async fn test_conn_error_addr_not_avail() {
        let peer = HttpPeer::new("127.0.0.1:121".to_string(), false, "".to_string());
        let addr = "192.0.2.2:0".parse().ok();
        let bind_to = BindTo {
            addr,
            ..Default::default()
        };
        let new_session = connect(&peer, Some(bind_to)).await;
        assert_eq!(new_session.unwrap_err().etype(), &InternalError)
    }

    #[tokio::test]
    async fn test_conn_error_other() {
        let peer = HttpPeer::new("240.0.0.1:80".to_string(), false, "".to_string()); // non-routable IP
        let addr = "127.0.0.1:0".parse().ok();
        let bind_to = BindTo {
            addr,
            ..Default::default()
        };
        let new_session = connect(&peer, Some(bind_to)).await;
        let error = new_session.unwrap_err();
        // XXX: some systems will allow the socket to bind and connect without error, only to timeout
        assert!(error.etype() == &ConnectError || error.etype() == &ConnectTimedout)
    }

    #[tokio::test]
    async fn test_conn_timeout() {
        // 192.0.2.1 is effectively a blackhole
        let mut peer = BasicPeer::new("192.0.2.1:79");
        peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); //1ms
        let new_session = connect(&peer, None).await;
        assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout)
    }

    #[tokio::test]
    async fn test_custom_connect() {
        #[derive(Debug)]
        struct MyL4;
        #[async_trait]
        impl Connect for MyL4 {
            async fn connect(&self, _addr: &SocketAddr) -> Result<Stream> {
                tokio::net::TcpStream::connect("1.1.1.1:80")
                    .await
                    .map(|s| s.into())
                    .or_fail()
            }
        }
        // :79 shouldn't be able to be connected to
        let mut peer = BasicPeer::new("1.1.1.1:79");
        peer.options.custom_l4 = Some(std::sync::Arc::new(MyL4 {}));

        let new_session = connect(&peer, None).await;

        // but MyL4 connects to :80 instead
        assert!(new_session.is_ok());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_connect_proxy_fail() {
        let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
        let mut path = PathBuf::new();
        path.push("/tmp/123");
        peer.proxy = Some(Proxy {
            next_hop: path.into(),
            host: "1.1.1.1".into(),
            port: 80,
            headers: BTreeMap::new(),
        });
        let new_session = connect(&peer, None).await;
        let e = new_session.unwrap_err();
        assert_eq!(e.etype(), &ConnectError);
        assert!(!e.retry());
    }

    #[cfg(unix)]
    const MOCK_UDS_PATH: &str = "/tmp/test_unix_connect_proxy.sock";

    // one-off mock server
    #[cfg(unix)]
    async fn mock_connect_server() {
        let _ = std::fs::remove_file(MOCK_UDS_PATH);
        let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap();
        if let Ok((mut stream, _addr)) = listener.accept().await {
            stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
            // wait a bit so that the client can read
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let _ = std::fs::remove_file(MOCK_UDS_PATH);
    }

    // Uses the unix-only `MOCK_UDS_PATH` / `mock_connect_server`, so it must be unix-only too.
    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_connect_proxy_work() {
        tokio::spawn(async {
            mock_connect_server().await;
        });
        // wait for the server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
        let mut path = PathBuf::new();
        path.push(MOCK_UDS_PATH);
        peer.proxy = Some(Proxy {
            next_hop: path.into(),
            host: "1.1.1.1".into(),
            port: 80,
            headers: BTreeMap::new(),
        });
        let new_session = connect(&peer, None).await;
        assert!(new_session.is_ok());
    }

    #[cfg(unix)]
    const MOCK_BAD_UDS_PATH: &str = "/tmp/test_unix_bad_connect_proxy.sock";

    // one-off mock bad proxy
    // closes connection upon accepting
    #[cfg(unix)]
    async fn mock_connect_bad_server() {
        let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
        let listener = UnixListener::bind(MOCK_BAD_UDS_PATH).unwrap();
        if let Ok((mut stream, _addr)) = listener.accept().await {
            stream.shutdown().await.unwrap();
            // wait a bit so that the client can read
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_connect_proxy_conn_closed() {
        tokio::spawn(async {
            mock_connect_bad_server().await;
        });
        // wait for the server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
        let mut path = PathBuf::new();
        path.push(MOCK_BAD_UDS_PATH);
        peer.proxy = Some(Proxy {
            next_hop: path.into(),
            host: "1.1.1.1".into(),
            port: 80,
            headers: BTreeMap::new(),
        });
        let new_session = connect(&peer, None).await;
        let err = new_session.unwrap_err();
        assert_eq!(err.etype(), &ConnectionClosed);
        assert!(!err.retry());
    }

    #[cfg(target_os = "linux")]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_bind_to_port_range_on_connect() {
        fn get_ip_local_port_range() -> (u16, u16) {
            let path = "/proc/sys/net/ipv4/ip_local_port_range";
            let file = std::fs::read_to_string(path).unwrap();
            let mut parts = file.split_whitespace();
            (
                parts.next().unwrap().parse().unwrap(),
                parts.next().unwrap().parse().unwrap(),
            )
        }

        // one-off mock server
        async fn mock_inet_connect_server() {
            use tokio::net::TcpListener;
            let listener = TcpListener::bind("127.0.0.1:10020").await.unwrap();
            if let Ok((mut stream, _addr)) = listener.accept().await {
                stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
                // wait a bit so that the client can read
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            }
        }

        fn in_port_range(session: Stream, lower: u16, upper: u16) -> bool {
            let digest = session.get_socket_digest();
            let local_addr = digest
                .as_ref()
                .and_then(|s| s.local_addr())
                .unwrap()
                .as_inet()
                .unwrap();

            // assumes only inet sockets
            local_addr.port() >= lower && local_addr.port() <= upper
        }

        tokio::spawn(async {
            mock_inet_connect_server().await;
        });
        // wait for the server to start
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // need to read /proc to find the local ephemeral port range
        let (low, _) = get_ip_local_port_range();
        let high = low + 1;

        let peer = HttpPeer::new("127.0.0.1:10020".to_string(), false, "".to_string());
        let mut bind_to = BindTo {
            addr: "127.0.0.1:0".parse().ok(),
            ..Default::default()
        };
        bind_to.set_port_range(Some((low, high))).unwrap();

        // shared port range
        let session1 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(in_port_range(session1, low, high));

        // shared port range
        let session2 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(in_port_range(session2, low, high));

        // shared port range
        let session3 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(in_port_range(session3, low, high));

        // the port range is exhausted
        let err = connect(&peer, Some(bind_to.clone())).await.unwrap_err();
        assert_eq!(err.etype(), &InternalError);

        // use the fallback
        bind_to.set_fallback(true);
        let session4 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(!in_port_range(session4, low, high));

        // no bind ip, only port range
        let low = low + 2;
        let high = low + 1;
        let mut bind_to = BindTo::default();
        bind_to.set_port_range(Some((low, high))).unwrap();
        let session5 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(in_port_range(session5, low, high));
    }

    #[test]
    fn test_bind_to_port_ranges() {
        let addr = "127.0.0.1:0".parse().ok();
        let mut bind_to = BindTo {
            addr,
            ..Default::default()
        };

        // None because the previous value was None
        bind_to.set_port_range(None).unwrap();
        assert!(bind_to.port_range.is_none());

        // zero handling
        bind_to.set_port_range(Some((0, 0))).unwrap();
        assert_eq!(bind_to.port_range, Some((0, 0)));

        // zero because the previous value was Some
        bind_to.set_port_range(None).unwrap();
        assert_eq!(bind_to.port_range, Some((0, 0)));

        // low > high is error
        assert!(bind_to.set_port_range(Some((2000, 1000))).is_err());

        // low < high success
        bind_to.set_port_range(Some((1000, 2000))).unwrap();
        assert_eq!(bind_to.port_range, Some((1000, 2000)));
    }
}