pingora_core/protocols/
raw_connect.rsuse super::http::v1::client::HttpSession;
use super::http::v1::common::*;
use super::Stream;
use bytes::{BufMut, BytesMut};
use http::request::Parts as ReqHeader;
use http::Version;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_http::ResponseHeader;
use tokio::io::AsyncWriteExt;
pub async fn connect(stream: Stream, request_header: &ReqHeader) -> Result<(Stream, ProxyDigest)> {
let mut http = HttpSession::new(stream);
let to_wire = http_req_header_to_wire_auth_form(request_header);
http.underlying_stream
.write_all(to_wire.as_ref())
.await
.or_err(WriteError, "while writing request headers")?;
http.underlying_stream
.flush()
.await
.or_err(WriteError, "while flushing request headers")?;
let resp_header = http.read_resp_header_parts().await?;
Ok((
http.underlying_stream,
validate_connect_response(resp_header)?,
))
}
pub fn generate_connect_header<'a, H, S>(
host: &str,
port: u16,
headers: H,
) -> Result<Box<ReqHeader>>
where
S: AsRef<[u8]>,
H: Iterator<Item = (S, &'a Vec<u8>)>,
{
let authority = if host.parse::<std::net::Ipv6Addr>().is_ok() {
format!("[{host}]:{port}")
} else {
format!("{host}:{port}")
};
let req = http::request::Builder::new()
.version(http::Version::HTTP_11)
.method(http::method::Method::CONNECT)
.uri(format!("https://{authority}/")) .header(http::header::HOST, &authority);
let (mut req, _) = match req.body(()) {
Ok(r) => r.into_parts(),
Err(e) => {
return Err(e).or_err(InvalidHTTPHeader, "Invalid CONNECT request");
}
};
for (k, v) in headers {
let header_name = http::header::HeaderName::from_bytes(k.as_ref())
.or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
let header_value = http::header::HeaderValue::from_bytes(v.as_slice())
.or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
req.headers.insert(header_name, header_value);
}
Ok(Box::new(req))
}
#[derive(Debug)]
pub struct ProxyDigest {
pub response: Box<ResponseHeader>,
}
impl ProxyDigest {
pub fn new(response: Box<ResponseHeader>) -> Self {
ProxyDigest { response }
}
}
#[derive(Debug)]
pub struct ConnectProxyError {
pub response: Box<ResponseHeader>,
}
impl ConnectProxyError {
pub fn boxed_new(response: Box<ResponseHeader>) -> Box<Self> {
Box::new(ConnectProxyError { response })
}
}
impl std::fmt::Display for ConnectProxyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
const PROXY_STATUS: &str = "proxy-status";
let reason = self
.response
.headers
.get(PROXY_STATUS)
.and_then(|s| s.to_str().ok())
.unwrap_or("missing proxy-status header value");
write!(
f,
"Failed CONNECT Response: status {}, proxy-status {reason}",
&self.response.status
)
}
}
impl std::error::Error for ConnectProxyError {}
#[inline]
fn http_req_header_to_wire_auth_form(req: &ReqHeader) -> BytesMut {
let mut buf = BytesMut::with_capacity(512);
let method = req.method.as_str().as_bytes();
buf.put_slice(method);
buf.put_u8(b' ');
if let Some(path) = req.uri.authority() {
buf.put_slice(path.as_str().as_bytes());
}
buf.put_u8(b' ');
let version = match req.version {
Version::HTTP_09 => "HTTP/0.9",
Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1",
_ => "HTTP/0.9",
};
buf.put_slice(version.as_bytes());
buf.put_slice(CRLF);
let headers = &req.headers;
for (key, value) in headers.iter() {
buf.put_slice(key.as_ref());
buf.put_slice(HEADER_KV_DELIMITER);
buf.put_slice(value.as_ref());
buf.put_slice(CRLF);
}
buf.put_slice(CRLF);
buf
}
#[inline]
fn validate_connect_response(resp: Box<ResponseHeader>) -> Result<ProxyDigest> {
if !resp.status.is_success() {
return Error::e_because(
ConnectProxyFailure,
"None 2xx code",
ConnectProxyError::boxed_new(resp),
);
}
if resp.headers.get(http::header::TRANSFER_ENCODING).is_some() {
return Error::e_because(
ConnectProxyFailure,
"Invalid Transfer-Encoding presents",
ConnectProxyError::boxed_new(resp),
);
}
Ok(ProxyDigest::new(resp))
}
#[cfg(test)]
mod test_sync {
use super::*;
use std::collections::BTreeMap;
use tokio_test::io::Builder;
#[test]
fn test_generate_connect_header() {
let mut headers = BTreeMap::new();
headers.insert(String::from("foo"), b"bar".to_vec());
let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap();
assert_eq!(req.method, http::method::Method::CONNECT);
assert_eq!(req.uri.authority().unwrap(), "pingora.org:123");
assert_eq!(req.headers.get("Host").unwrap(), "pingora.org:123");
assert_eq!(req.headers.get("foo").unwrap(), "bar");
}
#[test]
fn test_generate_connect_header_ipv6() {
let mut headers = BTreeMap::new();
headers.insert(String::from("foo"), b"bar".to_vec());
let req = generate_connect_header("::1", 123, headers.iter()).unwrap();
assert_eq!(req.method, http::method::Method::CONNECT);
assert_eq!(req.uri.authority().unwrap(), "[::1]:123");
assert_eq!(req.headers.get("Host").unwrap(), "[::1]:123");
assert_eq!(req.headers.get("foo").unwrap(), "bar");
}
#[test]
fn test_request_to_wire_auth_form() {
let new_request = http::Request::builder()
.method("CONNECT")
.uri("https://pingora.org:123/")
.header("Foo", "Bar")
.body(())
.unwrap();
let (new_request, _) = new_request.into_parts();
let wire = http_req_header_to_wire_auth_form(&new_request);
assert_eq!(
&b"CONNECT pingora.org:123 HTTP/1.1\r\nfoo: Bar\r\n\r\n"[..],
&wire
);
}
#[test]
fn test_validate_connect_response() {
let resp = ResponseHeader::build(200, None).unwrap();
validate_connect_response(Box::new(resp)).unwrap();
let resp = ResponseHeader::build(404, None).unwrap();
assert!(validate_connect_response(Box::new(resp)).is_err());
let mut resp = ResponseHeader::build(200, None).unwrap();
resp.append_header("content-length", 0).unwrap();
assert!(validate_connect_response(Box::new(resp)).is_ok());
let mut resp = ResponseHeader::build(200, None).unwrap();
resp.append_header("transfer-encoding", 0).unwrap();
assert!(validate_connect_response(Box::new(resp)).is_err());
}
#[tokio::test]
async fn test_connect_write_request() {
let wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n";
let mock_io = Box::new(Builder::new().write(wire).build());
let headers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap();
assert!(connect(mock_io, &req).await.is_err());
let to_wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n";
let from_wire = b"HTTP/1.1 200 OK\r\n\r\n";
let mock_io = Box::new(Builder::new().write(to_wire).read(from_wire).build());
let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap();
let result = connect(mock_io, &req).await;
assert!(result.is_ok());
}
}