1use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::{future::Future, marker::PhantomData};
12
13use futures_util::{future, TryFutureExt};
14use openssl::pkcs12::ParsedPkcs12_2;
15use openssl::pkey::{PKey, Private};
16use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMethod, SslOptions};
17use openssl::stack::Stack;
18use openssl::x509::store::X509StoreBuilder;
19use openssl::x509::X509;
20use tokio_openssl::{self, SslStream as TokioTlsStream};
21
22use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
23use crate::tcp::TcpStream;
24use crate::tcp::{Connect, DnsTcpStream};
25use crate::xfer::BufDnsStreamHandle;
26
27pub(crate) trait TlsIdentityExt {
28 fn identity(&mut self, pkcs12: &ParsedPkcs12_2) -> io::Result<()> {
29 self.identity_parts(
30 pkcs12.cert.as_ref(),
31 pkcs12.pkey.as_ref(),
32 pkcs12.ca.as_ref(),
33 )
34 }
35
36 fn identity_parts(
37 &mut self,
38 cert: Option<&X509>,
39 pkey: Option<&PKey<Private>>,
40 chain: Option<&Stack<X509>>,
41 ) -> io::Result<()>;
42}
43
44impl TlsIdentityExt for SslContextBuilder {
45 fn identity_parts(
46 &mut self,
47 cert: Option<&X509>,
48 pkey: Option<&PKey<Private>>,
49 chain: Option<&Stack<X509>>,
50 ) -> io::Result<()> {
51 if let Some(cert) = cert {
52 self.set_certificate(cert)?;
53 }
54 if let Some(pkey) = pkey {
55 self.set_private_key(pkey)?;
56 }
57 self.check_private_key()?;
58 if let Some(chain) = chain {
59 for cert in chain {
60 self.add_extra_chain_cert(cert.to_owned())?;
61 }
62 }
63 Ok(())
64 }
65}
66
/// A [`TcpStream`] whose I/O is a `tokio_openssl` TLS stream over transport `S`,
/// wrapped in `AsyncIoTokioAsStd`.
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
/// A [`TlsStream`] whose transport `S` is first wrapped in `AsyncIoStdAsTokio`
/// before TLS is layered on top.
pub(crate) type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;
70
71fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12_2>) -> io::Result<SslConnector> {
72 let mut tls = SslConnector::builder(SslMethod::tls())
73 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
74
75 {
77 let openssl_ctx_builder = &mut tls;
78
79 openssl_ctx_builder.set_options(
81 SslOptions::NO_SSLV2
82 | SslOptions::NO_SSLV3
83 | SslOptions::NO_TLSV1
84 | SslOptions::NO_TLSV1_1,
85 );
86
87 let mut store = X509StoreBuilder::new().map_err(|e| {
88 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
89 })?;
90
91 for cert in certs {
92 store.add_cert(cert).map_err(|e| {
93 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
94 })?;
95 }
96
97 openssl_ctx_builder
98 .set_verify_cert_store(store.build())
99 .map_err(|e| {
100 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
101 })?;
102
103 if let Some(pkcs12) = pkcs12 {
105 openssl_ctx_builder.identity(&pkcs12)?;
106 }
107 }
108 Ok(tls.build())
109}
110
111pub fn tls_stream_from_existing_tls_stream<S: DnsTcpStream>(
115 stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
116 peer_addr: SocketAddr,
117) -> (CompatTlsStream<S>, BufDnsStreamHandle) {
118 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
119 let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
120 (stream, message_sender)
121}
122
123async fn connect_tls<S, F>(
124 future: F,
125 tls_config: ConnectConfiguration,
126 dns_name: String,
127) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error>
128where
129 S: DnsTcpStream,
130 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
131{
132 let tcp = future
133 .await
134 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
135 let mut stream = tls_config
136 .into_ssl(&dns_name)
137 .and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp)))
138 .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {e}")))?;
139 Pin::new(&mut stream)
140 .connect()
141 .await
142 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
143 Ok(stream)
144}
145
146#[derive(Default)]
148pub struct TlsStreamBuilder<S> {
149 ca_chain: Vec<X509>,
150 identity: Option<ParsedPkcs12_2>,
151 bind_addr: Option<SocketAddr>,
152 marker: PhantomData<S>,
153}
154
155impl<S: DnsTcpStream> TlsStreamBuilder<S> {
156 pub fn new() -> Self {
158 Self {
159 ca_chain: vec![],
160 identity: None,
161 bind_addr: None,
162 marker: PhantomData,
163 }
164 }
165
166 pub fn add_ca(&mut self, ca: X509) {
170 self.ca_chain.push(ca);
171 }
172
173 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
175 self.bind_addr = Some(bind_addr);
176 }
177
178 #[allow(clippy::type_complexity)]
180 pub fn build_with_future<F>(
181 self,
182 future: F,
183 name_server: SocketAddr,
184 dns_name: String,
185 ) -> (
186 Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
187 BufDnsStreamHandle,
188 )
189 where
190 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
191 {
192 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
193 let tls_config = match new(self.ca_chain, self.identity) {
194 Ok(c) => c,
195 Err(e) => {
196 return (
197 Box::pin(future::err(e).map_err(|e| {
198 io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
199 })),
200 message_sender,
201 )
202 }
203 };
204
205 let tls_config = match tls_config.configure() {
206 Ok(c) => c,
207 Err(e) => {
208 return (
209 Box::pin(future::err(e).map_err(|e| {
210 io::Error::new(
211 io::ErrorKind::ConnectionRefused,
212 format!("tls config error: {e}"),
213 )
214 })),
215 message_sender,
216 )
217 }
218 };
219
220 let stream = Box::pin(connect_tls(future, tls_config, dns_name).map_ok(move |s| {
223 TcpStream::from_stream_with_receiver(
224 AsyncIoTokioAsStd(s),
225 name_server,
226 outbound_messages,
227 )
228 }));
229
230 (stream, message_sender)
231 }
232}
233
234impl<S: Connect> TlsStreamBuilder<S> {
235 #[allow(clippy::type_complexity)]
261 pub fn build(
262 self,
263 name_server: SocketAddr,
264 dns_name: String,
265 ) -> (
266 Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
267 BufDnsStreamHandle,
268 ) {
269 let future = S::connect_with_bind(name_server, self.bind_addr);
270 self.build_with_future(future, name_server, dns_name)
271 }
272}