async_tungstenite/tokio/native_tls.rs

1use real_tokio_native_tls::TlsConnector as AsyncTlsConnector;
2use real_tokio_native_tls::TlsStream;
3
4use tungstenite::client::{uri_mode, IntoClientRequest};
5use tungstenite::handshake::client::Request;
6use tungstenite::stream::Mode;
7use tungstenite::Error;
8
9use crate::stream::Stream as StreamSwitcher;
10use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
11
12use super::TokioAdapter;
13
14/// A stream that might be protected with TLS.
15pub type MaybeTlsStream<S> = StreamSwitcher<TokioAdapter<S>, TokioAdapter<TlsStream<S>>>;
16
17pub type AutoStream<S> = MaybeTlsStream<S>;
18
19pub type Connector = AsyncTlsConnector;
20
21async fn wrap_stream<S>(
22    socket: S,
23    domain: String,
24    connector: Option<Connector>,
25    mode: Mode,
26) -> Result<AutoStream<S>, Error>
27where
28    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
29{
30    match mode {
31        Mode::Plain => Ok(StreamSwitcher::Plain(TokioAdapter::new(socket))),
32        Mode::Tls => {
33            let stream = {
34                let connector = if let Some(connector) = connector {
35                    connector
36                } else {
37                    let connector = real_native_tls::TlsConnector::builder()
38                        .build()
39                        .map_err(|err| Error::Tls(err.into()))?;
40                    AsyncTlsConnector::from(connector)
41                };
42                connector
43                    .connect(&domain, socket)
44                    .await
45                    .map_err(|err| Error::Tls(err.into()))?
46            };
47            Ok(StreamSwitcher::Tls(TokioAdapter::new(stream)))
48        }
49    }
50}
51
52/// Creates a WebSocket handshake from a request and a stream,
53/// upgrading the stream to TLS if required and using the given
54/// connector and WebSocket configuration.
55pub async fn client_async_tls_with_connector_and_config<R, S>(
56    request: R,
57    stream: S,
58    connector: Option<AsyncTlsConnector>,
59    config: Option<WebSocketConfig>,
60) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
61where
62    R: IntoClientRequest + Unpin,
63    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
64    AutoStream<S>: Unpin,
65{
66    let request: Request = request.into_client_request()?;
67
68    let domain = domain(&request)?;
69
70    // Make sure we check domain and mode first. URL must be valid.
71    let mode = uri_mode(request.uri())?;
72
73    let stream = wrap_stream(stream, domain, connector, mode).await?;
74    client_async_with_config(request, stream, config).await
75}