1use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::{future::Future, marker::PhantomData};
12
13use futures_util::{future, TryFutureExt};
14use openssl::pkcs12::ParsedPkcs12_2;
15use openssl::pkey::{PKey, Private};
16use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMethod, SslOptions};
17use openssl::stack::Stack;
18use openssl::x509::store::X509StoreBuilder;
19use openssl::x509::X509;
20use tokio_openssl::{self, SslStream as TokioTlsStream};
21
22use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
23use crate::tcp::TcpStream;
24use crate::tcp::{Connect, DnsTcpStream};
25use crate::xfer::BufDnsStreamHandle;
26
27pub(crate) trait TlsIdentityExt {
28 fn identity(&mut self, pkcs12: &ParsedPkcs12_2) -> io::Result<()> {
29 self.identity_parts(
30 pkcs12.cert.as_ref(),
31 pkcs12.pkey.as_ref(),
32 pkcs12.ca.as_ref(),
33 )
34 }
35
36 fn identity_parts(
37 &mut self,
38 cert: Option<&X509>,
39 pkey: Option<&PKey<Private>>,
40 chain: Option<&Stack<X509>>,
41 ) -> io::Result<()>;
42}
43
44impl TlsIdentityExt for SslContextBuilder {
45 fn identity_parts(
46 &mut self,
47 cert: Option<&X509>,
48 pkey: Option<&PKey<Private>>,
49 chain: Option<&Stack<X509>>,
50 ) -> io::Result<()> {
51 if let Some(cert) = cert {
52 self.set_certificate(cert)?;
53 }
54 if let Some(pkey) = pkey {
55 self.set_private_key(pkey)?;
56 }
57 self.check_private_key()?;
58 if let Some(chain) = chain {
59 for cert in chain {
60 self.add_extra_chain_cert(cert.to_owned())?;
61 }
62 }
63 Ok(())
64 }
65}
66
/// A DNS stream carried over TLS: an OpenSSL [`TokioTlsStream`] over the
/// transport `S`, wrapped in `AsyncIoTokioAsStd` and framed as a DNS
/// [`TcpStream`].
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
/// A [`TlsStream`] whose transport `S` is first wrapped in `AsyncIoStdAsTokio`.
pub(crate) type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;
70
71fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12_2>) -> io::Result<SslConnector> {
72 let mut tls = SslConnector::builder(SslMethod::tls())
73 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
74
75 {
77 let openssl_ctx_builder = &mut tls;
78
79 openssl_ctx_builder.set_options(
81 SslOptions::NO_SSLV2
82 | SslOptions::NO_SSLV3
83 | SslOptions::NO_TLSV1
84 | SslOptions::NO_TLSV1_1,
85 );
86
87 let mut store = X509StoreBuilder::new().map_err(|e| {
88 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
89 })?;
90
91 for cert in certs {
92 store.add_cert(cert).map_err(|e| {
93 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
94 })?;
95 }
96
97 openssl_ctx_builder
98 .set_verify_cert_store(store.build())
99 .map_err(|e| {
100 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
101 })?;
102
103 if let Some(pkcs12) = pkcs12 {
105 openssl_ctx_builder.identity(&pkcs12)?;
106 }
107 }
108 Ok(tls.build())
109}
110
111pub fn tls_stream_from_existing_tls_stream<S: DnsTcpStream>(
115 stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
116 peer_addr: SocketAddr,
117) -> (CompatTlsStream<S>, BufDnsStreamHandle) {
118 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
119 let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
120 (stream, message_sender)
121}
122
123async fn connect_tls<S, F>(
124 future: F,
125 tls_config: ConnectConfiguration,
126 dns_name: String,
127) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error>
128where
129 S: DnsTcpStream,
130 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
131{
132 let tcp = future
133 .await
134 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
135 let mut stream = tls_config
136 .into_ssl(&dns_name)
137 .and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp)))
138 .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {e}")))?;
139 Pin::new(&mut stream)
140 .connect()
141 .await
142 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
143 Ok(stream)
144}
145
146#[derive(Default)]
148pub struct TlsStreamBuilder<S> {
149 ca_chain: Vec<X509>,
150 identity: Option<ParsedPkcs12_2>,
151 bind_addr: Option<SocketAddr>,
152 marker: PhantomData<S>,
153}
154
155impl<S: DnsTcpStream> TlsStreamBuilder<S> {
156 pub fn new() -> Self {
158 Self {
159 ca_chain: vec![],
160 identity: None,
161 bind_addr: None,
162 marker: PhantomData,
163 }
164 }
165
166 pub fn add_ca(&mut self, ca: X509) {
170 self.ca_chain.push(ca);
171 }
172
173 #[cfg(feature = "mtls")]
175 pub fn identity(&mut self, pkcs12: ParsedPkcs12) {
176 self.identity = Some(pkcs12);
177 }
178
179 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
181 self.bind_addr = Some(bind_addr);
182 }
183
184 #[allow(clippy::type_complexity)]
186 pub fn build_with_future<F>(
187 self,
188 future: F,
189 name_server: SocketAddr,
190 dns_name: String,
191 ) -> (
192 Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
193 BufDnsStreamHandle,
194 )
195 where
196 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
197 {
198 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
199 let tls_config = match new(self.ca_chain, self.identity) {
200 Ok(c) => c,
201 Err(e) => {
202 return (
203 Box::pin(future::err(e).map_err(|e| {
204 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
205 })),
206 message_sender,
207 )
208 }
209 };
210
211 let tls_config = match tls_config.configure() {
212 Ok(c) => c,
213 Err(e) => {
214 return (
215 Box::pin(future::err(e).map_err(|e| {
216 io::Error::new(
217 io::ErrorKind::ConnectionRefused,
218 format!("tls config error: {e}"),
219 )
220 })),
221 message_sender,
222 )
223 }
224 };
225
226 let stream = Box::pin(connect_tls(future, tls_config, dns_name).map_ok(move |s| {
229 TcpStream::from_stream_with_receiver(
230 AsyncIoTokioAsStd(s),
231 name_server,
232 outbound_messages,
233 )
234 }));
235
236 (stream, message_sender)
237 }
238}
239
240impl<S: Connect> TlsStreamBuilder<S> {
241 #[allow(clippy::type_complexity)]
267 pub fn build(
268 self,
269 name_server: SocketAddr,
270 dns_name: String,
271 ) -> (
272 Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
273 BufDnsStreamHandle,
274 ) {
275 let future = S::connect_with_bind(name_server, self.bind_addr);
276 self.build_with_future(future, name_server, dns_name)
277 }
278}