// Enabling both TLS back-ends is rejected at compile time: each feature imports its own
// `HttpsConnector` type and defines its own `SocksConnector::with_tls`, so turning on both
// would produce conflicting definitions. Emit a clear message instead of confusing errors.
#[cfg(all(feature = "tls", feature = "rustls"))]
compile_error!(
    "`tls` and `rustls` features are mutually exclusive. You should enable only one of them"
);
36
37use async_socks5::AddrKind;
38use http::uri::Scheme;
39use hyper::{
40 rt::{Read, Write},
41 Uri,
42};
43#[cfg(feature = "rustls")]
44use hyper_rustls::HttpsConnector;
45#[cfg(feature = "tls")]
46use hyper_tls::HttpsConnector;
47use hyper_util::rt::TokioIo;
48use std::{
49 future::Future,
50 io,
51 pin::Pin,
52 task::{ready, Context, Poll},
53};
54use tokio::io::BufStream;
55use tower_service::Service;
56
57pub use async_socks5::Auth;
58
59#[cfg(feature = "tls")]
60pub use hyper_tls::native_tls::Error as TlsError;
61
62#[derive(Debug, thiserror::Error)]
63pub enum Error {
64 #[error("{0}")]
65 Socks(
66 #[from]
67 #[source]
68 async_socks5::Error,
69 ),
70 #[error("{0}")]
71 Io(
72 #[from]
73 #[source]
74 io::Error,
75 ),
76 #[error("{0}")]
77 Connector(
78 #[from]
79 #[source]
80 BoxedError,
81 ),
82 #[error("Missing host")]
83 MissingHost,
84}
85
86pub type SocksFuture<R> = Pin<Box<dyn Future<Output = Result<R, Error>> + Send>>;
90
/// Type-erased, thread-safe error; used to carry failures from the inner connector.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
92
93#[derive(Debug, Clone, PartialEq, Eq, Hash)]
95pub struct SocksConnector<C> {
96 pub proxy_addr: Uri,
97 pub auth: Option<Auth>,
98 pub connector: C,
99}
100
101impl<C> SocksConnector<C> {
102 #[cfg(feature = "tls")]
104 pub fn with_tls(self) -> Result<HttpsConnector<Self>, TlsError> {
105 let args = (self, hyper_tls::native_tls::TlsConnector::new()?.into());
106 Ok(HttpsConnector::from(args))
107 }
108
109 #[cfg(feature = "rustls")]
111 pub fn with_tls(self) -> Result<HttpsConnector<Self>, io::Error> {
112 let mut root_store = rusttls::RootCertStore::empty();
113 for cert in rustls_native_certs::load_native_certs()? {
114 root_store
115 .add(cert)
116 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
117 }
118 Ok(self.with_rustls_root_cert_store(root_store))
119 }
120
121 #[cfg(feature = "rustls")]
123 pub fn with_rustls_root_cert_store(
124 self,
125 root_store: rusttls::RootCertStore,
126 ) -> HttpsConnector<Self> {
127 use rusttls::ClientConfig;
128 use std::sync::Arc;
129
130 let config = ClientConfig::builder()
131 .with_root_certificates(root_store)
132 .with_no_client_auth();
133
134 let config = Arc::new(config);
135
136 let args = (self, config);
137 HttpsConnector::from(args)
138 }
139}
140
141impl<C> SocksConnector<C>
142where
143 C: Service<Uri>,
144 C::Response: Read + Write + Send + Unpin,
145 C::Error: Into<BoxedError>,
146{
147 async fn call_async(mut self, target_addr: Uri) -> Result<C::Response, Error> {
148 let host = target_addr
149 .host()
150 .map(str::to_string)
151 .ok_or(Error::MissingHost)?;
152 let port =
153 target_addr
154 .port_u16()
155 .unwrap_or(if target_addr.scheme() == Some(&Scheme::HTTPS) {
156 443
157 } else {
158 80
159 });
160 let target_addr = AddrKind::Domain(host, port);
161
162 let stream = self
163 .connector
164 .call(self.proxy_addr)
165 .await
166 .map_err(Into::<BoxedError>::into)?;
167 let mut buf_stream = BufStream::new(TokioIo::new(stream)); let _ = async_socks5::connect(&mut buf_stream, target_addr, self.auth).await?;
169 Ok(buf_stream.into_inner().into_inner())
170 }
171}
172
173impl<C> Service<Uri> for SocksConnector<C>
174where
175 C: Service<Uri> + Clone + Send + 'static,
176 C::Response: Read + Write + Send + Unpin,
177 C::Error: Into<BoxedError>,
178 C::Future: Send,
179{
180 type Response = C::Response;
181 type Error = Error;
182 type Future = SocksFuture<C::Response>;
183
184 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
185 ready!(self.connector.poll_ready(cx)).map_err(Into::<BoxedError>::into)?;
186 Poll::Ready(Ok(()))
187 }
188
189 fn call(&mut self, req: Uri) -> Self::Future {
190 let this = self.clone();
191 Box::pin(async move { this.call_async(req).await })
192 }
193}
194
// Live integration tests. They expect a SOCKS5 proxy at `PROXY_ADDR` (which, for the
// `*_auth` cases, must accept `PROXY_USERNAME` / `PROXY_PASSWORD`) and outbound network
// access to `google.com`.
//
// `Tester::test` calls `SocksConnector::with_tls`, which only exists when the `tls` or
// `rustls` feature is enabled. The module is therefore compiled only in that case;
// otherwise `cargo test` without a TLS feature would fail to build.
#[cfg(all(test, any(feature = "tls", feature = "rustls")))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper_util::{
        client::legacy::{connect::HttpConnector, Client},
        rt::TokioExecutor,
    };

    const PROXY_ADDR: &str = "socks5://127.0.0.1:1080";
    const PROXY_USERNAME: &str = "hyper";
    const PROXY_PASSWORD: &str = "proxy";
    const HTTP_ADDR: &str = "http://google.com";
    const HTTPS_ADDR: &str = "https://google.com";

    /// Describes one request made through the proxy.
    struct Tester {
        uri: Uri,
        auth: Option<Auth>,
        /// When set, pick the connector that does NOT match the URI scheme.
        swap_connector: bool,
    }

    impl Tester {
        fn uri(uri: Uri) -> Self {
            Self {
                uri,
                auth: None,
                swap_connector: false,
            }
        }

        fn http() -> Self {
            Self::uri(Uri::from_static(HTTP_ADDR))
        }

        fn https() -> Self {
            Self::uri(Uri::from_static(HTTPS_ADDR))
        }

        fn with_auth(mut self) -> Self {
            self.auth = Some(Auth {
                username: PROXY_USERNAME.to_string(),
                password: PROXY_PASSWORD.to_string(),
            });
            self
        }

        fn swap_connector(mut self) -> Self {
            self.swap_connector = true;
            self
        }

        /// Performs a GET to `self.uri` through the proxy and panics on failure.
        async fn test(self) {
            let mut connector = HttpConnector::new();
            // The inner connector must be allowed to dial the non-`http` proxy URI.
            connector.enforce_http(false);
            let socks = SocksConnector {
                proxy_addr: Uri::from_static(PROXY_ADDR),
                auth: self.auth,
                connector,
            };

            // Plain SOCKS connector for `http`, TLS-wrapped connector otherwise;
            // `swap_connector` inverts the choice to exercise the mismatched pairings.
            let fut = if (self.uri.scheme() == Some(&Scheme::HTTP)) ^ self.swap_connector {
                Client::builder(TokioExecutor::new())
                    .build::<_, Empty<Bytes>>(socks)
                    .get(self.uri)
            } else {
                Client::builder(TokioExecutor::new())
                    .build::<_, Empty<Bytes>>(socks.with_tls().unwrap())
                    .get(self.uri)
            };
            let _ = fut.await.unwrap();
        }
    }

    #[tokio::test]
    async fn http_no_auth() {
        Tester::http().test().await
    }

    #[tokio::test]
    async fn https_no_auth() {
        Tester::https().test().await
    }

    #[tokio::test]
    async fn http_auth() {
        Tester::http().with_auth().test().await
    }

    #[tokio::test]
    async fn https_auth() {
        Tester::https().with_auth().test().await
    }

    #[tokio::test]
    async fn http_no_auth_swap() {
        Tester::http().swap_connector().test().await
    }

    #[should_panic = "IncompleteMessage"]
    #[tokio::test]
    async fn https_no_auth_swap() {
        Tester::https().swap_connector().test().await
    }

    #[tokio::test]
    async fn http_auth_swap() {
        Tester::http().with_auth().swap_connector().test().await
    }

    #[should_panic = "IncompleteMessage"]
    #[tokio::test]
    async fn https_auth_swap() {
        Tester::https().with_auth().swap_connector().test().await
    }
}