trust_dns_proto/openssl/
tls_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::io;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::{future::Future, marker::PhantomData};
12
13use futures_util::{future, TryFutureExt};
14use openssl::pkcs12::ParsedPkcs12_2;
15use openssl::pkey::{PKey, Private};
16use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMethod, SslOptions};
17use openssl::stack::Stack;
18use openssl::x509::store::X509StoreBuilder;
19use openssl::x509::X509;
20use tokio_openssl::{self, SslStream as TokioTlsStream};
21
22use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
23use crate::tcp::TcpStream;
24use crate::tcp::{Connect, DnsTcpStream};
25use crate::xfer::BufDnsStreamHandle;
26
27pub(crate) trait TlsIdentityExt {
28    fn identity(&mut self, pkcs12: &ParsedPkcs12_2) -> io::Result<()> {
29        self.identity_parts(
30            pkcs12.cert.as_ref(),
31            pkcs12.pkey.as_ref(),
32            pkcs12.ca.as_ref(),
33        )
34    }
35
36    fn identity_parts(
37        &mut self,
38        cert: Option<&X509>,
39        pkey: Option<&PKey<Private>>,
40        chain: Option<&Stack<X509>>,
41    ) -> io::Result<()>;
42}
43
44impl TlsIdentityExt for SslContextBuilder {
45    fn identity_parts(
46        &mut self,
47        cert: Option<&X509>,
48        pkey: Option<&PKey<Private>>,
49        chain: Option<&Stack<X509>>,
50    ) -> io::Result<()> {
51        if let Some(cert) = cert {
52            self.set_certificate(cert)?;
53        }
54        if let Some(pkey) = pkey {
55            self.set_private_key(pkey)?;
56        }
57        self.check_private_key()?;
58        if let Some(chain) = chain {
59            for cert in chain {
60                self.add_extra_chain_cert(cert.to_owned())?;
61            }
62        }
63        Ok(())
64    }
65}
66
67/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
68pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
69pub(crate) type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;
70
71fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12_2>) -> io::Result<SslConnector> {
72    let mut tls = SslConnector::builder(SslMethod::tls())
73        .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
74
75    // mutable reference block
76    {
77        let openssl_ctx_builder = &mut tls;
78
79        // only want to support current TLS versions, 1.2 or future
80        openssl_ctx_builder.set_options(
81            SslOptions::NO_SSLV2
82                | SslOptions::NO_SSLV3
83                | SslOptions::NO_TLSV1
84                | SslOptions::NO_TLSV1_1,
85        );
86
87        let mut store = X509StoreBuilder::new().map_err(|e| {
88            io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
89        })?;
90
91        for cert in certs {
92            store.add_cert(cert).map_err(|e| {
93                io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
94            })?;
95        }
96
97        openssl_ctx_builder
98            .set_verify_cert_store(store.build())
99            .map_err(|e| {
100                io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
101            })?;
102
103        // if there was a pkcs12 associated, we'll add it to the identity
104        if let Some(pkcs12) = pkcs12 {
105            openssl_ctx_builder.identity(&pkcs12)?;
106        }
107    }
108    Ok(tls.build())
109}
110
111/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
112///
113/// This is intended for use with a TlsListener and Incoming connections
114pub fn tls_stream_from_existing_tls_stream<S: DnsTcpStream>(
115    stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
116    peer_addr: SocketAddr,
117) -> (CompatTlsStream<S>, BufDnsStreamHandle) {
118    let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
119    let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
120    (stream, message_sender)
121}
122
123async fn connect_tls<S, F>(
124    future: F,
125    tls_config: ConnectConfiguration,
126    dns_name: String,
127) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error>
128where
129    S: DnsTcpStream,
130    F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
131{
132    let tcp = future
133        .await
134        .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
135    let mut stream = tls_config
136        .into_ssl(&dns_name)
137        .and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp)))
138        .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {e}")))?;
139    Pin::new(&mut stream)
140        .connect()
141        .await
142        .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}")))?;
143    Ok(stream)
144}
145
146/// A builder for the TlsStream
147#[derive(Default)]
148pub struct TlsStreamBuilder<S> {
149    ca_chain: Vec<X509>,
150    identity: Option<ParsedPkcs12_2>,
151    bind_addr: Option<SocketAddr>,
152    marker: PhantomData<S>,
153}
154
155impl<S: DnsTcpStream> TlsStreamBuilder<S> {
156    /// A builder for associating trust information to the `TlsStream`.
157    pub fn new() -> Self {
158        Self {
159            ca_chain: vec![],
160            identity: None,
161            bind_addr: None,
162            marker: PhantomData,
163        }
164    }
165
166    /// Add a custom trusted peer certificate or certificate authority.
167    ///
168    /// If this is the 'client' then the 'server' must have it associated as it's `identity`, or have had the `identity` signed by this
169    pub fn add_ca(&mut self, ca: X509) {
170        self.ca_chain.push(ca);
171    }
172
173    /// Client side identity for client auth in TLS (aka mutual TLS auth)
174    #[cfg(feature = "mtls")]
175    pub fn identity(&mut self, pkcs12: ParsedPkcs12) {
176        self.identity = Some(pkcs12);
177    }
178
179    /// Sets the address to connect from.
180    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
181        self.bind_addr = Some(bind_addr);
182    }
183
184    /// Similar to `build`, but with prebuilt tcp stream
185    #[allow(clippy::type_complexity)]
186    pub fn build_with_future<F>(
187        self,
188        future: F,
189        name_server: SocketAddr,
190        dns_name: String,
191    ) -> (
192        Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
193        BufDnsStreamHandle,
194    )
195    where
196        F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
197    {
198        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
199        let tls_config = match new(self.ca_chain, self.identity) {
200            Ok(c) => c,
201            Err(e) => {
202                return (
203                    Box::pin(future::err(e).map_err(|e| {
204                        io::Error::new(io::ErrorKind::ConnectionRefused, format!("tls error: {e}"))
205                    })),
206                    message_sender,
207                )
208            }
209        };
210
211        let tls_config = match tls_config.configure() {
212            Ok(c) => c,
213            Err(e) => {
214                return (
215                    Box::pin(future::err(e).map_err(|e| {
216                        io::Error::new(
217                            io::ErrorKind::ConnectionRefused,
218                            format!("tls config error: {e}"),
219                        )
220                    })),
221                    message_sender,
222                )
223            }
224        };
225
226        // This set of futures collapses the next tcp socket into a stream which can be used for
227        //  sending and receiving tcp packets.
228        let stream = Box::pin(connect_tls(future, tls_config, dns_name).map_ok(move |s| {
229            TcpStream::from_stream_with_receiver(
230                AsyncIoTokioAsStd(s),
231                name_server,
232                outbound_messages,
233            )
234        }));
235
236        (stream, message_sender)
237    }
238}
239
240impl<S: Connect> TlsStreamBuilder<S> {
241    /// Creates a new TlsStream to the specified name_server
242    ///
243    /// [RFC 7858](https://tools.ietf.org/html/rfc7858), DNS over TLS, May 2016
244    ///
245    /// ```text
246    /// 3.2.  TLS Handshake and Authentication
247    ///
248    ///   Once the DNS client succeeds in connecting via TCP on the well-known
249    ///   port for DNS over TLS, it proceeds with the TLS handshake [RFC5246],
250    ///   following the best practices specified in [BCP195].
251    ///
252    ///   The client will then authenticate the server, if required.  This
253    ///   document does not propose new ideas for authentication.  Depending on
254    ///   the privacy profile in use (Section 4), the DNS client may choose not
255    ///   to require authentication of the server, or it may make use of a
256    ///   trusted Subject Public Key Info (SPKI) Fingerprint pin set.
257    ///
258    ///   After TLS negotiation completes, the connection will be encrypted and
259    ///   is now protected from eavesdropping.
260    /// ```
261    ///
262    /// # Arguments
263    ///
264    /// * `name_server` - IP and Port for the remote DNS resolver
265    /// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate
266    #[allow(clippy::type_complexity)]
267    pub fn build(
268        self,
269        name_server: SocketAddr,
270        dns_name: String,
271    ) -> (
272        Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
273        BufDnsStreamHandle,
274    ) {
275        let future = S::connect_with_bind(name_server, self.bind_addr);
276        self.build_with_future(future, name_server, dns_name)
277    }
278}