1use hyper::{
2 rt::{Read, Write},
3 Uri,
4};
5use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioIo};
6use std::fmt;
7use std::future::Future;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio_native_tls::TlsConnector;
11use tower_service::Service;
12
13use crate::stream::MaybeHttpsStream;
14
/// Boxed, thread-safe error type used as the error of the connector's
/// `Service` impl and of the `HttpsConnecting` future.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
16
/// A connector that delegates to an inner connector `T` and, for `https`
/// destinations, performs a TLS handshake on top of the stream `T` returns.
#[derive(Clone)]
pub struct HttpsConnector<T> {
    /// When `true`, destinations whose scheme is not `https` are rejected
    /// instead of being connected without TLS.
    force_https: bool,
    /// Inner connector called for every destination; its stream is used
    /// as-is for non-`https` URIs and wrapped in TLS for `https` URIs.
    http: T,
    /// TLS connector used to wrap the inner stream for `https` URIs.
    tls: TlsConnector,
}
24
25impl HttpsConnector<HttpConnector> {
26 #[must_use]
46 pub fn new() -> Self {
47 native_tls::TlsConnector::new().map_or_else(
48 |e| panic!("HttpsConnector::new() failure: {}", e),
49 |tls| HttpsConnector::new_(tls.into()),
50 )
51 }
52
53 fn new_(tls: TlsConnector) -> Self {
54 let mut http = HttpConnector::new();
55 http.enforce_http(false);
56 HttpsConnector::from((http, tls))
57 }
58}
59
60impl<T: Default> Default for HttpsConnector<T> {
61 fn default() -> Self {
62 Self::new_with_connector(Default::default())
63 }
64}
65
66impl<T> HttpsConnector<T> {
67 pub fn https_only(&mut self, enable: bool) {
71 self.force_https = enable;
72 }
73
74 pub fn new_with_connector(http: T) -> Self {
83 native_tls::TlsConnector::new().map_or_else(
84 |e| {
85 panic!(
86 "HttpsConnector::new_with_connector(<connector>) failure: {}",
87 e
88 )
89 },
90 |tls| HttpsConnector::from((http, tls.into())),
91 )
92 }
93}
94
95impl<T> From<(T, TlsConnector)> for HttpsConnector<T> {
96 fn from(args: (T, TlsConnector)) -> HttpsConnector<T> {
97 HttpsConnector {
98 force_https: false,
99 http: args.0,
100 tls: args.1,
101 }
102 }
103}
104
105impl<T: fmt::Debug> fmt::Debug for HttpsConnector<T> {
106 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107 f.debug_struct("HttpsConnector")
108 .field("force_https", &self.force_https)
109 .field("http", &self.http)
110 .finish_non_exhaustive()
111 }
112}
113
114impl<T> Service<Uri> for HttpsConnector<T>
115where
116 T: Service<Uri>,
117 T::Response: Read + Write + Send + Unpin,
118 T::Future: Send + 'static,
119 T::Error: Into<BoxError>,
120{
121 type Response = MaybeHttpsStream<T::Response>;
122 type Error = BoxError;
123 type Future = HttpsConnecting<T::Response>;
124
125 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126 match self.http.poll_ready(cx) {
127 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
128 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
129 Poll::Pending => Poll::Pending,
130 }
131 }
132
133 fn call(&mut self, dst: Uri) -> Self::Future {
134 let is_https = dst.scheme_str() == Some("https");
135 if !is_https && self.force_https {
137 return err(ForceHttpsButUriNotHttps.into());
138 }
139
140 let host = dst
141 .host()
142 .unwrap_or("")
143 .trim_matches(|c| c == '[' || c == ']')
144 .to_owned();
145 let connecting = self.http.call(dst);
146
147 let tls_connector = self.tls.clone();
148
149 let fut = async move {
150 let tcp = connecting.await.map_err(Into::into)?;
151
152 let maybe = if is_https {
153 let stream = TokioIo::new(tcp);
154
155 let tls = TokioIo::new(tls_connector.connect(&host, stream).await?);
156 MaybeHttpsStream::Https(tls)
157 } else {
158 MaybeHttpsStream::Http(tcp)
159 };
160 Ok(maybe)
161 };
162 HttpsConnecting(Box::pin(fut))
163 }
164}
165
166fn err<T>(e: BoxError) -> HttpsConnecting<T> {
167 HttpsConnecting(Box::pin(async { Err(e) }))
168}
169
/// Heap-allocated, `Send` future resolving to either a plain or a TLS-wrapped
/// stream, or a boxed error.
type BoxedFut<T> = Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T>, BoxError>> + Send>>;
171
172pub struct HttpsConnecting<T>(BoxedFut<T>);
174
175impl<T: Read + Write + Unpin> Future for HttpsConnecting<T> {
176 type Output = Result<MaybeHttpsStream<T>, BoxError>;
177
178 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 Pin::new(&mut self.0).poll(cx)
180 }
181}
182
183impl<T> fmt::Debug for HttpsConnecting<T> {
184 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
185 f.pad("HttpsConnecting")
186 }
187}
188
/// Error produced when `https_only` is enabled but the destination URI's
/// scheme is not `https`.
#[derive(Debug)]
struct ForceHttpsButUriNotHttps;

impl fmt::Display for ForceHttpsButUriNotHttps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "https required but URI was not https")
    }
}

impl std::error::Error for ForceHttpsButUriNotHttps {}