1use std::future::Future;
11use std::io;
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use rustls::ClientConfig;
17use tokio;
18use tokio::net::TcpStream as TokioTcpStream;
19use tokio_rustls::TlsConnector;
20
21use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
22use crate::tcp::Connect;
23use crate::tcp::{DnsTcpStream, TcpStream};
24use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
25
/// Client-side tokio-rustls TLS stream running over the transport `S`.
///
/// `S` is wrapped in [`AsyncIoStdAsTokio`] before being handed to tokio-rustls
/// (presumably adapting it to tokio's I/O traits — confirm against `crate::iocompat`).
pub type TokioTlsClientStream<S> = tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>;

/// Server-side tokio-rustls TLS stream over a tokio `TcpStream`.
pub type TokioTlsServerStream = tokio_rustls::server::TlsStream<TokioTcpStream>;

/// DNS-over-TLS stream type.
///
/// This is the same [`TcpStream`] type used for plain TCP; only the underlying `S` differs.
/// For connections created by [`tls_connect`], `S` is a TLS client stream.
pub type TlsStream<S> = TcpStream<S>;
34
35pub fn tls_from_stream<S: DnsTcpStream>(
39 stream: S,
40 peer_addr: SocketAddr,
41) -> (TlsStream<S>, BufDnsStreamHandle) {
42 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
43 let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
44 (stream, message_sender)
45}
46
47#[allow(clippy::type_complexity)]
74pub fn tls_connect<S: Connect>(
75 name_server: SocketAddr,
76 dns_name: String,
77 client_config: Arc<ClientConfig>,
78) -> (
79 Pin<
80 Box<
81 dyn Future<
82 Output = Result<
83 TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>,
84 io::Error,
85 >,
86 > + Send,
87 >,
88 >,
89 BufDnsStreamHandle,
90) {
91 tls_connect_with_bind_addr(name_server, None, dns_name, client_config)
92}
93
94#[allow(clippy::type_complexity)]
102pub fn tls_connect_with_bind_addr<S: Connect>(
103 name_server: SocketAddr,
104 bind_addr: Option<SocketAddr>,
105 dns_name: String,
106 client_config: Arc<ClientConfig>,
107) -> (
108 Pin<
109 Box<
110 dyn Future<
111 Output = Result<
112 TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>,
113 io::Error,
114 >,
115 > + Send,
116 >,
117 >,
118 BufDnsStreamHandle,
119) {
120 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
121 let early_data_enabled = client_config.enable_early_data;
122 let tls_connector = TlsConnector::from(client_config).early_data(early_data_enabled);
123
124 let stream = Box::pin(connect_tls(
127 tls_connector,
128 name_server,
129 bind_addr,
130 dns_name,
131 outbound_messages,
132 ));
133
134 (stream, message_sender)
135}
136
137#[allow(clippy::type_complexity)]
145pub fn tls_connect_with_future<S, F>(
146 future: F,
147 name_server: SocketAddr,
148 dns_name: String,
149 client_config: Arc<ClientConfig>,
150) -> (
151 Pin<
152 Box<
153 dyn Future<
154 Output = Result<
155 TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>,
156 io::Error,
157 >,
158 > + Send,
159 >,
160 >,
161 BufDnsStreamHandle,
162)
163where
164 S: DnsTcpStream,
165 F: Future<Output = io::Result<S>> + Send + Unpin + 'static,
166{
167 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
168 let early_data_enabled = client_config.enable_early_data;
169 let tls_connector = TlsConnector::from(client_config).early_data(early_data_enabled);
170
171 let stream = Box::pin(connect_tls_with_future(
174 tls_connector,
175 future,
176 name_server,
177 dns_name,
178 outbound_messages,
179 ));
180
181 (stream, message_sender)
182}
183
184async fn connect_tls<S: Connect>(
185 tls_connector: TlsConnector,
186 name_server: SocketAddr,
187 bind_addr: Option<SocketAddr>,
188 dns_name: String,
189 outbound_messages: StreamReceiver,
190) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>> {
191 let tcp = S::connect_with_bind(name_server, bind_addr);
192 connect_tls_with_future(tls_connector, tcp, name_server, dns_name, outbound_messages).await
193}
194
195async fn connect_tls_with_future<S, F>(
196 tls_connector: TlsConnector,
197 future: F,
198 name_server: SocketAddr,
199 dns_name: String,
200 outbound_messages: StreamReceiver,
201) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>>
202where
203 S: DnsTcpStream,
204 F: Future<Output = io::Result<S>> + Send + Unpin,
205{
206 let dns_name = match dns_name.as_str().try_into() {
207 Ok(name) => name,
208 Err(_) => return Err(io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name")),
209 };
210
211 let stream = future.await?;
212 let s = tls_connector
213 .connect(dns_name, AsyncIoStdAsTokio(stream))
214 .await
215 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
216
217 Ok(TcpStream::from_stream_with_receiver(
218 AsyncIoTokioAsStd(s),
219 name_server,
220 outbound_messages,
221 ))
222}