1use tokio::io::{AsyncRead, AsyncWrite};
3
4use tungstenite::{
5 client::uri_mode, error::Error, handshake::client::Response, protocol::WebSocketConfig,
6};
7
8use crate::{client_async_with_config, IntoClientRequest, WebSocketStream};
9
10pub use crate::stream::MaybeTlsStream;
11
/// Selects how the underlying stream is wrapped before the WebSocket
/// handshake: without TLS (`Plain`), or with one of the TLS backends that the
/// crate was compiled with (`native-tls` / `rustls`).
///
/// The enum is `#[non_exhaustive]` because the available variants depend on
/// enabled cargo features.
#[non_exhaustive]
#[derive(Clone)]
pub enum Connector {
    /// Do not use TLS. A request whose URI is in TLS mode fails with
    /// `UrlError::TlsFeatureNotEnabled` when this connector is chosen.
    Plain,
    /// Use the given `native-tls` connector for TLS-mode URIs.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// Use the given shared `rustls` client configuration for TLS-mode URIs.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
27
28mod encryption {
29 #[cfg(feature = "native-tls")]
30 pub mod native_tls {
31 use native_tls_crate::TlsConnector;
32 use tokio_native_tls::TlsConnector as TokioTlsConnector;
33
34 use tokio::io::{AsyncRead, AsyncWrite};
35
36 use tungstenite::{error::TlsError, stream::Mode, Error};
37
38 use crate::stream::MaybeTlsStream;
39
40 pub async fn wrap_stream<S>(
41 socket: S,
42 domain: String,
43 mode: Mode,
44 tls_connector: Option<TlsConnector>,
45 ) -> Result<MaybeTlsStream<S>, Error>
46 where
47 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
48 {
49 match mode {
50 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
51 Mode::Tls => {
52 let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
53 let connector = try_connector.map_err(TlsError::Native)?;
54 let stream = TokioTlsConnector::from(connector);
55 let connected = stream.connect(&domain, socket).await;
56 match connected {
57 Err(e) => Err(Error::Tls(e.into())),
58 Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
59 }
60 }
61 }
62 }
63 }
64
65 #[cfg(feature = "__rustls-tls")]
66 pub mod rustls {
67 pub use rustls::ClientConfig;
68 use rustls::RootCertStore;
69 use rustls_pki_types::ServerName;
70 use tokio_rustls::TlsConnector as TokioTlsConnector;
71
72 use std::{convert::TryFrom, sync::Arc};
73 use tokio::io::{AsyncRead, AsyncWrite};
74
75 use tungstenite::{error::TlsError, stream::Mode, Error};
76
77 use crate::stream::MaybeTlsStream;
78
79 pub async fn wrap_stream<S>(
80 socket: S,
81 domain: String,
82 mode: Mode,
83 tls_connector: Option<Arc<ClientConfig>>,
84 ) -> Result<MaybeTlsStream<S>, Error>
85 where
86 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
87 {
88 match mode {
89 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
90 Mode::Tls => {
91 let config = match tls_connector {
92 Some(config) => config,
93 None => {
94 #[allow(unused_mut)]
95 let mut root_store = RootCertStore::empty();
96 #[cfg(feature = "rustls-tls-native-roots")]
97 {
98 let rustls_native_certs::CertificateResult {
99 certs, errors, ..
100 } = rustls_native_certs::load_native_certs();
101
102 if !errors.is_empty() {
103 log::warn!(
104 "native root CA certificate loading errors: {errors:?}"
105 );
106 }
107
108 #[cfg(not(feature = "rustls-tls-webpki-roots"))]
111 if certs.is_empty() {
112 return Err(std::io::Error::new(std::io::ErrorKind::NotFound, format!("no native root CA certificates found (errors: {errors:?})")).into());
113 }
114
115 let total_number = certs.len();
116 let (number_added, number_ignored) =
117 root_store.add_parsable_certificates(certs);
118 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
119 }
120 #[cfg(feature = "rustls-tls-webpki-roots")]
121 {
122 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
123 }
124
125 Arc::new(
126 ClientConfig::builder()
127 .with_root_certificates(root_store)
128 .with_no_client_auth(),
129 )
130 }
131 };
132 let domain = ServerName::try_from(domain.as_str())
133 .map_err(|_| TlsError::InvalidDnsName)?
134 .to_owned();
135 let stream = TokioTlsConnector::from(config);
136 let connected = stream.connect(domain, socket).await;
137
138 match connected {
139 Err(e) => Err(Error::Io(e)),
140 Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
141 }
142 }
143 }
144 }
145 }
146
147 pub mod plain {
148 use tokio::io::{AsyncRead, AsyncWrite};
149
150 use tungstenite::{
151 error::{Error, UrlError},
152 stream::Mode,
153 };
154
155 use crate::stream::MaybeTlsStream;
156
157 pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
158 where
159 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
160 {
161 match mode {
162 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
163 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
164 }
165 }
166 }
167}
168
/// Performs the WebSocket client handshake over `stream`, wrapping it in TLS
/// first when the request URI calls for it.
///
/// This is `client_async_tls_with_config` with no WebSocket configuration
/// and no explicit connector. The default TLS backend for the enabled
/// features is therefore used.
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
    MaybeTlsStream<S>: Unpin,
{
    let (websocket_config, connector): (Option<WebSocketConfig>, Option<Connector>) =
        (None, None);
    client_async_tls_with_config(request, stream, websocket_config, connector).await
}
183
184pub async fn client_async_tls_with_config<R, S>(
190 request: R,
191 stream: S,
192 config: Option<WebSocketConfig>,
193 connector: Option<Connector>,
194) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
195where
196 R: IntoClientRequest + Unpin,
197 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
198 MaybeTlsStream<S>: Unpin,
199{
200 let request = request.into_client_request()?;
201
202 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
203 let domain = crate::domain(&request)?;
204
205 let mode = uri_mode(request.uri())?;
207
208 let stream = match connector {
209 Some(conn) => match conn {
210 #[cfg(feature = "native-tls")]
211 Connector::NativeTls(conn) => {
212 self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
213 }
214 #[cfg(feature = "__rustls-tls")]
215 Connector::Rustls(conn) => {
216 self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
217 }
218 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
219 },
220 None => {
221 #[cfg(feature = "native-tls")]
222 {
223 self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
224 }
225 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
226 {
227 self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
228 }
229 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
230 {
231 self::encryption::plain::wrap_stream(stream, mode).await
232 }
233 }
234 }?;
235
236 client_async_with_config(request, stream, config).await
237}