tokio_tungstenite/
tls.rs

1//! Connection helper.
2use tokio::io::{AsyncRead, AsyncWrite};
3
4use tungstenite::{
5    client::uri_mode, error::Error, handshake::client::Response, protocol::WebSocketConfig,
6};
7
8use crate::{client_async_with_config, IntoClientRequest, WebSocketStream};
9
10pub use crate::stream::MaybeTlsStream;
11
/// Chooses how the underlying stream is secured when a connection is established.
///
/// Use `NativeTls` for the `native-tls` backend, `Rustls` for the `rustls` backend, or
/// `Plain` to skip TLS altogether. Each TLS variant only exists when the cargo feature
/// for its backend is enabled.
#[derive(Clone)]
#[non_exhaustive]
pub enum Connector {
    /// No TLS layer; the stream is used without encryption.
    Plain,
    /// TLS through `native-tls`, driven by the supplied connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// TLS through `rustls`, driven by the supplied shared client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
27
28mod encryption {
29    #[cfg(feature = "native-tls")]
30    pub mod native_tls {
31        use native_tls_crate::TlsConnector;
32        use tokio_native_tls::TlsConnector as TokioTlsConnector;
33
34        use tokio::io::{AsyncRead, AsyncWrite};
35
36        use tungstenite::{error::TlsError, stream::Mode, Error};
37
38        use crate::stream::MaybeTlsStream;
39
40        pub async fn wrap_stream<S>(
41            socket: S,
42            domain: String,
43            mode: Mode,
44            tls_connector: Option<TlsConnector>,
45        ) -> Result<MaybeTlsStream<S>, Error>
46        where
47            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
48        {
49            match mode {
50                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
51                Mode::Tls => {
52                    let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
53                    let connector = try_connector.map_err(TlsError::Native)?;
54                    let stream = TokioTlsConnector::from(connector);
55                    let connected = stream.connect(&domain, socket).await;
56                    match connected {
57                        Err(e) => Err(Error::Tls(e.into())),
58                        Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
59                    }
60                }
61            }
62        }
63    }
64
65    #[cfg(feature = "__rustls-tls")]
66    pub mod rustls {
67        pub use rustls::ClientConfig;
68        use rustls::RootCertStore;
69        use rustls_pki_types::ServerName;
70        use tokio_rustls::TlsConnector as TokioTlsConnector;
71
72        use std::{convert::TryFrom, sync::Arc};
73        use tokio::io::{AsyncRead, AsyncWrite};
74
75        use tungstenite::{error::TlsError, stream::Mode, Error};
76
77        use crate::stream::MaybeTlsStream;
78
79        pub async fn wrap_stream<S>(
80            socket: S,
81            domain: String,
82            mode: Mode,
83            tls_connector: Option<Arc<ClientConfig>>,
84        ) -> Result<MaybeTlsStream<S>, Error>
85        where
86            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
87        {
88            match mode {
89                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
90                Mode::Tls => {
91                    let config = match tls_connector {
92                        Some(config) => config,
93                        None => {
94                            #[allow(unused_mut)]
95                            let mut root_store = RootCertStore::empty();
96                            #[cfg(feature = "rustls-tls-native-roots")]
97                            {
98                                let rustls_native_certs::CertificateResult {
99                                    certs, errors, ..
100                                } = rustls_native_certs::load_native_certs();
101
102                                if !errors.is_empty() {
103                                    log::warn!(
104                                        "native root CA certificate loading errors: {errors:?}"
105                                    );
106                                }
107
108                                // Not finding any native root CA certificates is not fatal if the
109                                // "rustls-tls-webpki-roots" feature is enabled.
110                                #[cfg(not(feature = "rustls-tls-webpki-roots"))]
111                                if certs.is_empty() {
112                                    return Err(std::io::Error::new(std::io::ErrorKind::NotFound, format!("no native root CA certificates found (errors: {errors:?})")).into());
113                                }
114
115                                let total_number = certs.len();
116                                let (number_added, number_ignored) =
117                                    root_store.add_parsable_certificates(certs);
118                                log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
119                            }
120                            #[cfg(feature = "rustls-tls-webpki-roots")]
121                            {
122                                root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
123                            }
124
125                            Arc::new(
126                                ClientConfig::builder()
127                                    .with_root_certificates(root_store)
128                                    .with_no_client_auth(),
129                            )
130                        }
131                    };
132                    let domain = ServerName::try_from(domain.as_str())
133                        .map_err(|_| TlsError::InvalidDnsName)?
134                        .to_owned();
135                    let stream = TokioTlsConnector::from(config);
136                    let connected = stream.connect(domain, socket).await;
137
138                    match connected {
139                        Err(e) => Err(Error::Io(e)),
140                        Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
141                    }
142                }
143            }
144        }
145    }
146
147    pub mod plain {
148        use tokio::io::{AsyncRead, AsyncWrite};
149
150        use tungstenite::{
151            error::{Error, UrlError},
152            stream::Mode,
153        };
154
155        use crate::stream::MaybeTlsStream;
156
157        pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
158        where
159            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
160        {
161            match mode {
162                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
163                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
164            }
165        }
166    }
167}
168
/// Performs the WebSocket client handshake over `stream`, first wrapping it in TLS
/// when the request's URI calls for it.
///
/// This is shorthand for [`client_async_tls_with_config`] with no explicit WebSocket
/// configuration and no explicit connector.
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
    MaybeTlsStream<S>: Unpin,
{
    // Leave both choices to `client_async_tls_with_config`.
    let websocket_config: Option<WebSocketConfig> = None;
    let tls_connector: Option<Connector> = None;
    client_async_tls_with_config(request, stream, websocket_config, tls_connector).await
}
183
184/// The same as `client_async_tls()` but the one can specify a websocket configuration,
185/// and an optional connector. If no connector is specified, a default one will
186/// be created.
187///
188/// Please refer to `client_async_tls()` for more details.
189pub async fn client_async_tls_with_config<R, S>(
190    request: R,
191    stream: S,
192    config: Option<WebSocketConfig>,
193    connector: Option<Connector>,
194) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
195where
196    R: IntoClientRequest + Unpin,
197    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
198    MaybeTlsStream<S>: Unpin,
199{
200    let request = request.into_client_request()?;
201
202    #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
203    let domain = crate::domain(&request)?;
204
205    // Make sure we check domain and mode first. URL must be valid.
206    let mode = uri_mode(request.uri())?;
207
208    let stream = match connector {
209        Some(conn) => match conn {
210            #[cfg(feature = "native-tls")]
211            Connector::NativeTls(conn) => {
212                self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
213            }
214            #[cfg(feature = "__rustls-tls")]
215            Connector::Rustls(conn) => {
216                self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
217            }
218            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
219        },
220        None => {
221            #[cfg(feature = "native-tls")]
222            {
223                self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
224            }
225            #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
226            {
227                self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
228            }
229            #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
230            {
231                self::encryption::plain::wrap_stream(stream, mode).await
232            }
233        }
234    }?;
235
236    client_async_with_config(request, stream, config).await
237}