1use alloc::boxed::Box;
11use alloc::string::String;
12use alloc::sync::Arc;
13use core::future::Future;
14use core::pin::Pin;
15use std::io;
16use std::net::SocketAddr;
17
18use rustls::ClientConfig;
19use rustls::pki_types::ServerName;
20use tokio::net::TcpStream as TokioTcpStream;
21use tokio::{self, time::timeout};
22use tokio_rustls::TlsConnector;
23
24use crate::runtime::RuntimeProvider;
25use crate::runtime::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
26use crate::tcp::{DnsTcpStream, TcpStream};
27use crate::xfer::{BufDnsStreamHandle, CONNECT_TIMEOUT, StreamReceiver};
28
/// Client-side `tokio_rustls` TLS stream layered over an `S` that has been
/// adapted to the Tokio I/O traits through [`AsyncIoStdAsTokio`].
pub type TokioTlsClientStream<S> = tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>;
31
/// Server-side `tokio_rustls` TLS stream layered directly over a Tokio TCP stream.
pub type TokioTlsServerStream = tokio_rustls::server::TlsStream<TokioTcpStream>;
34
/// DNS stream type produced by the functions in this module.
///
/// This is only an alias for the crate's [`TcpStream`]; the TLS layer lives in
/// the `S` type parameter.
pub type TlsStream<S> = TcpStream<S>;
37
38pub fn tls_from_stream<S: DnsTcpStream>(
42 stream: S,
43 peer_addr: SocketAddr,
44) -> (TlsStream<S>, BufDnsStreamHandle) {
45 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
46 let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
47 (stream, message_sender)
48}
49
50#[allow(clippy::type_complexity)]
77pub fn tls_connect<P: RuntimeProvider>(
78 name_server: SocketAddr,
79 dns_name: String,
80 client_config: Arc<ClientConfig>,
81 provider: P,
82) -> (
83 Pin<
84 Box<
85 dyn Future<
86 Output = Result<
87 TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<P::Tcp>>>,
88 io::Error,
89 >,
90 > + Send,
91 >,
92 >,
93 BufDnsStreamHandle,
94) {
95 tls_connect_with_bind_addr(name_server, None, dns_name, client_config, provider)
96}
97
98#[allow(clippy::type_complexity)]
106pub fn tls_connect_with_bind_addr<P: RuntimeProvider>(
107 name_server: SocketAddr,
108 bind_addr: Option<SocketAddr>,
109 dns_name: String,
110 client_config: Arc<ClientConfig>,
111 provider: P,
112) -> (
113 Pin<
114 Box<
115 dyn Future<
116 Output = Result<
117 TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<P::Tcp>>>,
118 io::Error,
119 >,
120 > + Send,
121 >,
122 >,
123 BufDnsStreamHandle,
124) {
125 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
126 let early_data_enabled = client_config.enable_early_data;
127 let tls_connector = TlsConnector::from(client_config).early_data(early_data_enabled);
128
129 let stream = Box::pin(connect_tls(
132 tls_connector,
133 name_server,
134 bind_addr,
135 dns_name,
136 outbound_messages,
137 provider,
138 ));
139
140 (stream, message_sender)
141}
142
143#[allow(clippy::type_complexity)]
151pub fn tls_connect_with_future<S, F>(
152 future: F,
153 name_server: SocketAddr,
154 dns_name: String,
155 client_config: Arc<ClientConfig>,
156) -> (
157 Pin<
158 Box<
159 dyn Future<
160 Output = Result<
161 TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>,
162 io::Error,
163 >,
164 > + Send,
165 >,
166 >,
167 BufDnsStreamHandle,
168)
169where
170 S: DnsTcpStream,
171 F: Future<Output = io::Result<S>> + Send + Unpin + 'static,
172{
173 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
174 let early_data_enabled = client_config.enable_early_data;
175 let tls_connector = TlsConnector::from(client_config).early_data(early_data_enabled);
176
177 let stream = Box::pin(connect_tls_with_future(
180 tls_connector,
181 future,
182 name_server,
183 dns_name,
184 outbound_messages,
185 ));
186
187 (stream, message_sender)
188}
189
190async fn connect_tls<P: RuntimeProvider>(
191 tls_connector: TlsConnector,
192 name_server: SocketAddr,
193 bind_addr: Option<SocketAddr>,
194 dns_name: String,
195 outbound_messages: StreamReceiver,
196 provider: P,
197) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<P::Tcp>>>> {
198 let tcp = provider.connect_tcp(name_server, bind_addr, None);
199 connect_tls_with_future(tls_connector, tcp, name_server, dns_name, outbound_messages).await
200}
201
202async fn connect_tls_with_future<S, F>(
203 tls_connector: TlsConnector,
204 future: F,
205 name_server: SocketAddr,
206 server_name: String,
207 outbound_messages: StreamReceiver,
208) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>>
209where
210 S: DnsTcpStream,
211 F: Future<Output = io::Result<S>> + Send + Unpin,
212{
213 let dns_name = match ServerName::try_from(server_name) {
214 Ok(name) => name,
215 Err(_) => return Err(io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name")),
216 };
217
218 let stream = AsyncIoStdAsTokio(future.await?);
219 let s = match timeout(CONNECT_TIMEOUT, tls_connector.connect(dns_name, stream)).await {
220 Ok(Ok(s)) => s,
221 Ok(Err(e)) => {
222 return Err(io::Error::new(
223 io::ErrorKind::ConnectionRefused,
224 format!("tls error: {e}"),
225 ));
226 }
227 Err(_) => {
228 return Err(io::Error::new(
229 io::ErrorKind::TimedOut,
230 format!("TLS handshake timed out after {CONNECT_TIMEOUT:?}"),
231 ));
232 }
233 };
234
235 Ok(TcpStream::from_stream_with_receiver(
236 AsyncIoTokioAsStd(s),
237 name_server,
238 outbound_messages,
239 ))
240}