hyper_tls/
client.rs

1use hyper::{
2    rt::{Read, Write},
3    Uri,
4};
5use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioIo};
6use std::fmt;
7use std::future::Future;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio_native_tls::TlsConnector;
11use tower_service::Service;
12
13use crate::stream::MaybeHttpsStream;
14
/// Type-erased, thread-safe error used as the connector's error type.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
16
17/// A Connector for the `https` scheme.
18#[derive(Clone)]
19pub struct HttpsConnector<T> {
20    force_https: bool,
21    http: T,
22    tls: TlsConnector,
23}
24
25impl HttpsConnector<HttpConnector> {
26    /// Construct a new `HttpsConnector`.
27    ///
28    /// This uses hyper's default `HttpConnector`, and default `TlsConnector`.
29    /// If you wish to use something besides the defaults, use `From::from`.
30    ///
31    /// # Note
32    ///
33    /// By default this connector will use plain HTTP if the URL provided uses
34    /// the HTTP scheme (eg: <http://example.com/>).
35    ///
36    /// If you would like to force the use of HTTPS then call `https_only(true)`
37    /// on the returned connector.
38    ///
39    /// # Panics
40    ///
41    /// This will panic if the underlying TLS context could not be created.
42    ///
43    /// To handle that error yourself, you can use the `HttpsConnector::from`
44    /// constructor after trying to make a `TlsConnector`.
45    #[must_use]
46    pub fn new() -> Self {
47        native_tls::TlsConnector::new().map_or_else(
48            |e| panic!("HttpsConnector::new() failure: {}", e),
49            |tls| HttpsConnector::new_(tls.into()),
50        )
51    }
52
53    fn new_(tls: TlsConnector) -> Self {
54        let mut http = HttpConnector::new();
55        http.enforce_http(false);
56        HttpsConnector::from((http, tls))
57    }
58}
59
60impl<T: Default> Default for HttpsConnector<T> {
61    fn default() -> Self {
62        Self::new_with_connector(Default::default())
63    }
64}
65
66impl<T> HttpsConnector<T> {
67    /// Force the use of HTTPS when connecting.
68    ///
69    /// If a URL is not `https` when connecting, an error is returned.
70    pub fn https_only(&mut self, enable: bool) {
71        self.force_https = enable;
72    }
73
74    /// With connector constructor
75    ///
76    /// # Panics
77    ///
78    /// This will panic if the underlying TLS context could not be created.
79    ///
80    /// To handle that error yourself, you can use the `HttpsConnector::from`
81    /// constructor after trying to make a `TlsConnector`.
82    pub fn new_with_connector(http: T) -> Self {
83        native_tls::TlsConnector::new().map_or_else(
84            |e| {
85                panic!(
86                    "HttpsConnector::new_with_connector(<connector>) failure: {}",
87                    e
88                )
89            },
90            |tls| HttpsConnector::from((http, tls.into())),
91        )
92    }
93}
94
95impl<T> From<(T, TlsConnector)> for HttpsConnector<T> {
96    fn from(args: (T, TlsConnector)) -> HttpsConnector<T> {
97        HttpsConnector {
98            force_https: false,
99            http: args.0,
100            tls: args.1,
101        }
102    }
103}
104
105impl<T: fmt::Debug> fmt::Debug for HttpsConnector<T> {
106    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107        f.debug_struct("HttpsConnector")
108            .field("force_https", &self.force_https)
109            .field("http", &self.http)
110            .finish_non_exhaustive()
111    }
112}
113
114impl<T> Service<Uri> for HttpsConnector<T>
115where
116    T: Service<Uri>,
117    T::Response: Read + Write + Send + Unpin,
118    T::Future: Send + 'static,
119    T::Error: Into<BoxError>,
120{
121    type Response = MaybeHttpsStream<T::Response>;
122    type Error = BoxError;
123    type Future = HttpsConnecting<T::Response>;
124
125    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126        match self.http.poll_ready(cx) {
127            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
128            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
129            Poll::Pending => Poll::Pending,
130        }
131    }
132
133    fn call(&mut self, dst: Uri) -> Self::Future {
134        let is_https = dst.scheme_str() == Some("https");
135        // Early abort if HTTPS is forced but can't be used
136        if !is_https && self.force_https {
137            return err(ForceHttpsButUriNotHttps.into());
138        }
139
140        let host = dst
141            .host()
142            .unwrap_or("")
143            .trim_matches(|c| c == '[' || c == ']')
144            .to_owned();
145        let connecting = self.http.call(dst);
146
147        let tls_connector = self.tls.clone();
148
149        let fut = async move {
150            let tcp = connecting.await.map_err(Into::into)?;
151
152            let maybe = if is_https {
153                let stream = TokioIo::new(tcp);
154
155                let tls = TokioIo::new(tls_connector.connect(&host, stream).await?);
156                MaybeHttpsStream::Https(tls)
157            } else {
158                MaybeHttpsStream::Http(tcp)
159            };
160            Ok(maybe)
161        };
162        HttpsConnecting(Box::pin(fut))
163    }
164}
165
166fn err<T>(e: BoxError) -> HttpsConnecting<T> {
167    HttpsConnecting(Box::pin(async { Err(e) }))
168}
169
170type BoxedFut<T> = Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T>, BoxError>> + Send>>;
171
172/// A Future representing work to connect to a URL, and a TLS handshake.
173pub struct HttpsConnecting<T>(BoxedFut<T>);
174
175impl<T: Read + Write + Unpin> Future for HttpsConnecting<T> {
176    type Output = Result<MaybeHttpsStream<T>, BoxError>;
177
178    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        Pin::new(&mut self.0).poll(cx)
180    }
181}
182
183impl<T> fmt::Debug for HttpsConnecting<T> {
184    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
185        f.pad("HttpsConnecting")
186    }
187}
188
189// ===== Custom Errors =====
190
/// Returned when `https_only` is enabled but the destination URI is not
/// `https`.
#[derive(Debug)]
struct ForceHttpsButUriNotHttps;

impl fmt::Display for ForceHttpsButUriNotHttps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "https required but URI was not https";
        f.write_str(MESSAGE)
    }
}

/// Leaf error: carries no underlying `source`.
impl std::error::Error for ForceHttpsButUriNotHttps {}