// Both TLS backends define `SocksConnector::with_tls` and each imports its own
// `HttpsConnector`, so enabling both at once would produce conflicting definitions.
// Fail early with an explicit message instead of a confusing duplicate-item error.
#[cfg(all(feature = "tls", feature = "rustls"))]
compile_error!(
"`tls` and `rustls` features are mutually exclusive. You should enable only one of them"
);
use async_socks5::AddrKind;
use http::uri::Scheme;
use hyper::{
rt::{Read, Write},
Uri,
};
#[cfg(feature = "rustls")]
use hyper_rustls::HttpsConnector;
#[cfg(feature = "tls")]
use hyper_tls::HttpsConnector;
use hyper_util::rt::TokioIo;
use std::{
future::Future,
io,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::io::BufStream;
use tower_service::Service;
pub use async_socks5::Auth;
#[cfg(feature = "tls")]
pub use hyper_tls::native_tls::Error as TlsError;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("{0}")]
Socks(
#[from]
#[source]
async_socks5::Error,
),
#[error("{0}")]
Io(
#[from]
#[source]
io::Error,
),
#[error("{0}")]
Connector(
#[from]
#[source]
BoxedError,
),
#[error("Missing host")]
MissingHost,
}
pub type SocksFuture<R> = Pin<Box<dyn Future<Output = Result<R, Error>> + Send>>;
/// Type-erased, thread-safe error, used to carry errors from the inner connector.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send>;
/// A connector [`Service`] that reaches a target [`Uri`] by tunnelling through a
/// SOCKS5 proxy.
///
/// The inner `connector` opens the raw connection to `proxy_addr`. The SOCKS5
/// handshake for the requested target is then performed over that connection.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SocksConnector<C> {
    /// Address of the SOCKS5 proxy; handed unchanged to `connector`.
    pub proxy_addr: Uri,
    /// Optional username/password credentials presented to the proxy.
    pub auth: Option<Auth>,
    /// Connector used to reach the proxy itself (for example an `HttpConnector`).
    pub connector: C,
}
impl<C> SocksConnector<C> {
    /// Wraps this connector in a native-tls [`HttpsConnector`].
    ///
    /// # Errors
    ///
    /// Returns the native-tls error if the platform TLS connector cannot be created.
    #[cfg(feature = "tls")]
    pub fn with_tls(self) -> Result<HttpsConnector<Self>, TlsError> {
        let native = hyper_tls::native_tls::TlsConnector::new()?;
        Ok(HttpsConnector::from((self, native.into())))
    }

    /// Wraps this connector in a rustls [`HttpsConnector`] trusting the platform's
    /// native root certificates.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the native certificates cannot be loaded, or an
    /// `InvalidData` I/O error if any loaded certificate is rejected by the root store.
    #[cfg(feature = "rustls")]
    pub fn with_tls(self) -> Result<HttpsConnector<Self>, io::Error> {
        let mut roots = rusttls::RootCertStore::empty();
        rustls_native_certs::load_native_certs()?
            .into_iter()
            .try_for_each(|cert| {
                roots
                    .add(cert)
                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
            })?;
        Ok(self.with_rustls_root_cert_store(roots))
    }

    /// Wraps this connector in a rustls [`HttpsConnector`] that trusts exactly the
    /// given root store, with no client authentication.
    #[cfg(feature = "rustls")]
    pub fn with_rustls_root_cert_store(
        self,
        root_store: rusttls::RootCertStore,
    ) -> HttpsConnector<Self> {
        let client_config = rusttls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        HttpsConnector::from((self, std::sync::Arc::new(client_config)))
    }
}
impl<C> SocksConnector<C>
where
C: Service<Uri>,
C::Response: Read + Write + Send + Unpin,
C::Error: Into<BoxedError>,
{
async fn call_async(mut self, target_addr: Uri) -> Result<C::Response, Error> {
let host = target_addr
.host()
.map(str::to_string)
.ok_or(Error::MissingHost)?;
let port =
target_addr
.port_u16()
.unwrap_or(if target_addr.scheme() == Some(&Scheme::HTTPS) {
443
} else {
80
});
let target_addr = AddrKind::Domain(host, port);
let stream = self
.connector
.call(self.proxy_addr)
.await
.map_err(Into::<BoxedError>::into)?;
let mut buf_stream = BufStream::new(TokioIo::new(stream)); let _ = async_socks5::connect(&mut buf_stream, target_addr, self.auth).await?;
Ok(buf_stream.into_inner().into_inner())
}
}
impl<C> Service<Uri> for SocksConnector<C>
where
    C: Service<Uri> + Clone + Send + 'static,
    C::Response: Read + Write + Send + Unpin,
    C::Error: Into<BoxedError>,
    C::Future: Send,
{
    type Response = C::Response;
    type Error = Error;
    type Future = SocksFuture<C::Response>;

    /// Reports ready once the inner connector (used to reach the proxy) is ready.
    /// An inner-connector error is returned as [`Error::Connector`].
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.connector.poll_ready(cx)).map_err(Into::<BoxedError>::into)?;
        Poll::Ready(Ok(()))
    }

    /// Starts connecting to `req` through the SOCKS5 proxy.
    fn call(&mut self, req: Uri) -> Self::Future {
        // `poll_ready` was driven on `self.connector`, not on any clone of it. Move
        // that ready instance into the future and leave a fresh clone in `self`, so the
        // inner connector's readiness contract holds even for connectors that apply
        // back-pressure. The fresh clone must be polled ready again before the next call.
        let fresh = self.clone();
        let ready_connector = std::mem::replace(self, fresh);
        Box::pin(ready_connector.call_async(req))
    }
}
// These tests call `SocksConnector::with_tls`, which only exists when the `tls` or
// `rustls` feature is enabled. Gate the module on those features so that
// `cargo test --no-default-features` still compiles.
//
// The tests connect through a SOCKS5 proxy at `PROXY_ADDR` (the `*_auth` cases use
// `PROXY_USERNAME`/`PROXY_PASSWORD`) and fetch google.com. They therefore need that
// proxy running locally and outbound network access.
#[cfg(all(test, any(feature = "tls", feature = "rustls")))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper_util::{
        client::legacy::{connect::HttpConnector, Client},
        rt::TokioExecutor,
    };

    const PROXY_ADDR: &str = "socks5://127.0.0.1:1080";
    const PROXY_USERNAME: &str = "hyper";
    const PROXY_PASSWORD: &str = "proxy";
    const HTTP_ADDR: &str = "http://google.com";
    const HTTPS_ADDR: &str = "https://google.com";

    /// Describes one request made through the proxy.
    struct Tester {
        /// Target fetched through the proxy.
        uri: Uri,
        /// Credentials presented to the proxy, if any.
        auth: Option<Auth>,
        /// When set, deliberately pairs the target scheme with the wrong connector:
        /// plain for `https`, TLS for `http`.
        swap_connector: bool,
    }

    impl Tester {
        fn uri(uri: Uri) -> Tester {
            Self {
                uri,
                auth: None,
                swap_connector: false,
            }
        }

        fn http() -> Self {
            Self::uri(Uri::from_static(HTTP_ADDR))
        }

        fn https() -> Self {
            Self::uri(Uri::from_static(HTTPS_ADDR))
        }

        fn with_auth(mut self) -> Self {
            self.auth = Some(Auth {
                username: PROXY_USERNAME.to_string(),
                password: PROXY_PASSWORD.to_string(),
            });
            self
        }

        fn swap_connector(mut self) -> Self {
            self.swap_connector = true;
            self
        }

        /// Issues a GET for `self.uri` through the proxy and panics on failure.
        async fn test(self) {
            let mut connector = HttpConnector::new();
            // The proxy URI uses the `socks5` scheme, which `HttpConnector` would
            // otherwise reject.
            connector.enforce_http(false);
            let socks = SocksConnector {
                proxy_addr: Uri::from_static(PROXY_ADDR),
                auth: self.auth,
                connector,
            };

            // Plain connector for `http` targets and TLS for everything else,
            // inverted when `swap_connector` is set.
            let fut = if (self.uri.scheme() == Some(&Scheme::HTTP)) ^ self.swap_connector {
                Client::builder(TokioExecutor::new())
                    .build::<_, Empty<Bytes>>(socks)
                    .get(self.uri)
            } else {
                Client::builder(TokioExecutor::new())
                    .build::<_, Empty<Bytes>>(socks.with_tls().unwrap())
                    .get(self.uri)
            };
            let _ = fut.await.unwrap();
        }
    }

    #[tokio::test]
    async fn http_no_auth() {
        Tester::http().test().await
    }

    #[tokio::test]
    async fn https_no_auth() {
        Tester::https().test().await
    }

    #[tokio::test]
    async fn http_auth() {
        Tester::http().with_auth().test().await
    }

    #[tokio::test]
    async fn https_auth() {
        Tester::https().with_auth().test().await
    }

    #[tokio::test]
    async fn http_no_auth_swap() {
        Tester::http().swap_connector().test().await
    }

    // Plain HTTP sent to an `https` target is expected to fail.
    #[should_panic = "IncompleteMessage"]
    #[tokio::test]
    async fn https_no_auth_swap() {
        Tester::https().swap_connector().test().await
    }

    #[tokio::test]
    async fn http_auth_swap() {
        Tester::http().with_auth().swap_connector().test().await
    }

    // Plain HTTP sent to an `https` target is expected to fail.
    #[should_panic = "IncompleteMessage"]
    #[tokio::test]
    async fn https_auth_swap() {
        Tester::https().with_auth().swap_connector().test().await
    }
}