hickory_proto/rustls/
tls_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! DNS over TLS I/O stream implementation for Rustls
9
10use alloc::boxed::Box;
11use alloc::string::String;
12use alloc::sync::Arc;
13use core::future::Future;
14use core::pin::Pin;
15use std::io;
16use std::net::SocketAddr;
17
18use rustls::ClientConfig;
19use rustls::pki_types::ServerName;
20use tokio::net::TcpStream as TokioTcpStream;
21use tokio::{self, time::timeout};
22use tokio_rustls::TlsConnector;
23
24use crate::runtime::RuntimeProvider;
25use crate::runtime::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
26use crate::tcp::{DnsTcpStream, TcpStream};
27use crate::xfer::{BufDnsStreamHandle, CONNECT_TIMEOUT, StreamReceiver};
28
29/// Predefined type for abstracting the TlsClientStream with TokioTls
30pub type TokioTlsClientStream<S> = tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>;
31
32/// Predefined type for abstracting the TlsServerStream with TokioTls
33pub type TokioTlsServerStream = tokio_rustls::server::TlsStream<TokioTcpStream>;
34
35/// Predefined type for abstracting the base I/O TlsStream with TokioTls
36pub type TlsStream<S> = TcpStream<S>;
37
38/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
39///
40/// This is intended for use with a TlsListener and Incoming connections
41pub fn tls_from_stream<S: DnsTcpStream>(
42    stream: S,
43    peer_addr: SocketAddr,
44) -> (TlsStream<S>, BufDnsStreamHandle) {
45    let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
46    let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
47    (stream, message_sender)
48}
49
50/// Creates a new TlsStream to the specified name_server
51///
52/// [RFC 7858](https://tools.ietf.org/html/rfc7858), DNS over TLS, May 2016
53///
54/// ```text
55/// 3.2.  TLS Handshake and Authentication
56///
57///   Once the DNS client succeeds in connecting via TCP on the well-known
58///   port for DNS over TLS, it proceeds with the TLS handshake [RFC5246],
59///   following the best practices specified in [BCP195].
60///
61///   The client will then authenticate the server, if required.  This
62///   document does not propose new ideas for authentication.  Depending on
63///   the privacy profile in use (Section 4), the DNS client may choose not
64///   to require authentication of the server, or it may make use of a
65///   trusted Subject Public Key Info (SPKI) Fingerprint pin set.
66///
67///   After TLS negotiation completes, the connection will be encrypted and
68///   is now protected from eavesdropping.
69/// ```
70///
71/// # Arguments
72///
73/// * `name_server` - IP and Port for the remote DNS resolver
74/// * `bind_addr` - IP and port to connect from
75/// * `dns_name` - The DNS name associated with a certificate
76#[allow(clippy::type_complexity)]
77pub fn tls_connect<P: RuntimeProvider>(
78    name_server: SocketAddr,
79    dns_name: String,
80    client_config: Arc<ClientConfig>,
81    provider: P,
82) -> (
83    Pin<
84        Box<
85            dyn Future<
86                    Output = Result<
87                        TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<P::Tcp>>>,
88                        io::Error,
89                    >,
90                > + Send,
91        >,
92    >,
93    BufDnsStreamHandle,
94) {
95    tls_connect_with_bind_addr(name_server, None, dns_name, client_config, provider)
96}
97
98/// Creates a new TlsStream to the specified name_server connecting from a specific address.
99///
100/// # Arguments
101///
102/// * `name_server` - IP and Port for the remote DNS resolver
103/// * `bind_addr` - IP and port to connect from
104/// * `dns_name` - The DNS name associated with a certificate
105#[allow(clippy::type_complexity)]
106pub fn tls_connect_with_bind_addr<P: RuntimeProvider>(
107    name_server: SocketAddr,
108    bind_addr: Option<SocketAddr>,
109    dns_name: String,
110    client_config: Arc<ClientConfig>,
111    provider: P,
112) -> (
113    Pin<
114        Box<
115            dyn Future<
116                    Output = Result<
117                        TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<P::Tcp>>>,
118                        io::Error,
119                    >,
120                > + Send,
121        >,
122    >,
123    BufDnsStreamHandle,
124) {
125    let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
126    let early_data_enabled = client_config.enable_early_data;
127    let tls_connector = TlsConnector::from(client_config).early_data(early_data_enabled);
128
129    // This set of futures collapses the next tcp socket into a stream which can be used for
130    //  sending and receiving tcp packets.
131    let stream = Box::pin(connect_tls(
132        tls_connector,
133        name_server,
134        bind_addr,
135        dns_name,
136        outbound_messages,
137        provider,
138    ));
139
140    (stream, message_sender)
141}
142
143/// Creates a new TlsStream to the specified name_server connecting from a specific address.
144///
145/// # Arguments
146///
147/// * `name_server` - IP and Port for the remote DNS resolver
148/// * `bind_addr` - IP and port to connect from
149/// * `dns_name` - The DNS name associated with a certificate
150#[allow(clippy::type_complexity)]
151pub fn tls_connect_with_future<S, F>(
152    future: F,
153    name_server: SocketAddr,
154    dns_name: String,
155    client_config: Arc<ClientConfig>,
156) -> (
157    Pin<
158        Box<
159            dyn Future<
160                    Output = Result<
161                        TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>,
162                        io::Error,
163                    >,
164                > + Send,
165        >,
166    >,
167    BufDnsStreamHandle,
168)
169where
170    S: DnsTcpStream,
171    F: Future<Output = io::Result<S>> + Send + Unpin + 'static,
172{
173    let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
174    let early_data_enabled = client_config.enable_early_data;
175    let tls_connector = TlsConnector::from(client_config).early_data(early_data_enabled);
176
177    // This set of futures collapses the next tcp socket into a stream which can be used for
178    //  sending and receiving tcp packets.
179    let stream = Box::pin(connect_tls_with_future(
180        tls_connector,
181        future,
182        name_server,
183        dns_name,
184        outbound_messages,
185    ));
186
187    (stream, message_sender)
188}
189
190async fn connect_tls<P: RuntimeProvider>(
191    tls_connector: TlsConnector,
192    name_server: SocketAddr,
193    bind_addr: Option<SocketAddr>,
194    dns_name: String,
195    outbound_messages: StreamReceiver,
196    provider: P,
197) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<P::Tcp>>>> {
198    let tcp = provider.connect_tcp(name_server, bind_addr, None);
199    connect_tls_with_future(tls_connector, tcp, name_server, dns_name, outbound_messages).await
200}
201
202async fn connect_tls_with_future<S, F>(
203    tls_connector: TlsConnector,
204    future: F,
205    name_server: SocketAddr,
206    server_name: String,
207    outbound_messages: StreamReceiver,
208) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>>
209where
210    S: DnsTcpStream,
211    F: Future<Output = io::Result<S>> + Send + Unpin,
212{
213    let dns_name = match ServerName::try_from(server_name) {
214        Ok(name) => name,
215        Err(_) => return Err(io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name")),
216    };
217
218    let stream = AsyncIoStdAsTokio(future.await?);
219    let s = match timeout(CONNECT_TIMEOUT, tls_connector.connect(dns_name, stream)).await {
220        Ok(Ok(s)) => s,
221        Ok(Err(e)) => {
222            return Err(io::Error::new(
223                io::ErrorKind::ConnectionRefused,
224                format!("tls error: {e}"),
225            ));
226        }
227        Err(_) => {
228            return Err(io::Error::new(
229                io::ErrorKind::TimedOut,
230                format!("TLS handshake timed out after {CONNECT_TIMEOUT:?}"),
231            ));
232        }
233    };
234
235    Ok(TcpStream::from_stream_with_receiver(
236        AsyncIoTokioAsStd(s),
237        name_server,
238        outbound_messages,
239    ))
240}