1use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::protocol::WebSocketConfig;
5use tungstenite::Error;
6
7use async_std::net::TcpStream;
8
9use super::{domain, port, WebSocketStream};
10
11#[cfg(feature = "async-native-tls")]
12use futures_io::{AsyncRead, AsyncWrite};
13
#[cfg(feature = "async-native-tls")]
pub(crate) mod async_native_tls {
    //! TLS backend built on `async-native-tls`.

    use async_native_tls::TlsConnector as AsyncTlsConnector;
    use async_native_tls::TlsStream;
    use real_async_native_tls as async_native_tls;

    use tungstenite::client::uri_mode;
    use tungstenite::handshake::client::Request;
    use tungstenite::stream::Mode;
    use tungstenite::Error;

    use futures_io::{AsyncRead, AsyncWrite};

    use crate::stream::Stream as StreamSwitcher;
    use crate::{
        client_async_with_config, domain, IntoClientRequest, Response, WebSocketConfig,
        WebSocketStream,
    };

    /// A stream that is either the plain `S` or `S` wrapped in a native-tls session.
    pub type MaybeTlsStream<S> = StreamSwitcher<S, TlsStream<S>>;

    /// The stream type this backend produces for a caller-supplied stream `S`.
    pub type AutoStream<S> = MaybeTlsStream<S>;

    /// The TLS connector type accepted by this backend.
    pub type Connector = AsyncTlsConnector;

    /// Wraps `socket` according to `mode`.
    ///
    /// * `Mode::Plain`: the socket is returned unchanged inside `StreamSwitcher::Plain`.
    /// * `Mode::Tls`: a TLS session for `domain` is negotiated over the socket. It uses
    ///   `connector` when given, or a freshly created `AsyncTlsConnector` otherwise.
    ///
    /// # Errors
    ///
    /// A failed TLS connection is reported as `Error::Tls`.
    async fn wrap_stream<S>(
        socket: S,
        domain: String,
        connector: Option<Connector>,
        mode: Mode,
    ) -> Result<AutoStream<S>, Error>
    where
        S: 'static + AsyncRead + AsyncWrite + Unpin,
    {
        match mode {
            Mode::Plain => Ok(StreamSwitcher::Plain(socket)),
            Mode::Tls => {
                // Only build a default connector when the caller did not supply one.
                let connector = connector.unwrap_or_else(AsyncTlsConnector::new);
                let stream = connector
                    .connect(&domain, socket)
                    .await
                    .map_err(|err| Error::Tls(err.into()))?;
                Ok(StreamSwitcher::Tls(stream))
            }
        }
    }

    /// Performs the WebSocket client handshake for `request` over `stream`.
    ///
    /// The stream is first wrapped in TLS when the request URI's mode (as reported by
    /// `uri_mode`) is `Mode::Tls`. An explicit `connector` and WebSocket `config` may be
    /// supplied. Passing `None` for the connector makes this backend create a default
    /// `AsyncTlsConnector`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of these steps fails:
    /// * the request cannot be converted;
    /// * `domain`/`uri_mode` reject the request;
    /// * the TLS connection fails;
    /// * the WebSocket handshake fails.
    pub async fn client_async_tls_with_connector_and_config<R, S>(
        request: R,
        stream: S,
        connector: Option<AsyncTlsConnector>,
        config: Option<WebSocketConfig>,
    ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
    where
        R: IntoClientRequest + Unpin,
        S: 'static + AsyncRead + AsyncWrite + Unpin,
        AutoStream<S>: Unpin,
    {
        let request: Request = request.into_client_request()?;

        let domain = domain(&request)?;

        let mode = uri_mode(request.uri())?;

        let stream = wrap_stream(stream, domain, connector, mode).await?;
        client_async_with_config(request, stream, config).await
    }
}
93
94#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
95pub(crate) mod dummy_tls {
96 use futures_io::{AsyncRead, AsyncWrite};
97
98 use tungstenite::client::{uri_mode, IntoClientRequest};
99 use tungstenite::handshake::client::Request;
100 use tungstenite::stream::Mode;
101 use tungstenite::Error;
102
103 use crate::{client_async_with_config, domain, Response, WebSocketConfig, WebSocketStream};
104
105 pub type AutoStream<S> = S;
106 type Connector = ();
107
108 async fn wrap_stream<S>(
109 socket: S,
110 _domain: String,
111 _connector: Option<()>,
112 mode: Mode,
113 ) -> Result<AutoStream<S>, Error>
114 where
115 S: 'static + AsyncRead + AsyncWrite + Unpin,
116 {
117 match mode {
118 Mode::Plain => Ok(socket),
119 Mode::Tls => Err(Error::Url(
120 tungstenite::error::UrlError::TlsFeatureNotEnabled,
121 )),
122 }
123 }
124
125 pub async fn client_async_tls_with_connector_and_config<R, S>(
129 request: R,
130 stream: S,
131 connector: Option<Connector>,
132 config: Option<WebSocketConfig>,
133 ) -> Result<(WebSocketStream<AutoStream<S>>, Response), Error>
134 where
135 R: IntoClientRequest + Unpin,
136 S: 'static + AsyncRead + AsyncWrite + Unpin,
137 AutoStream<S>: Unpin,
138 {
139 let request: Request = request.into_client_request()?;
140
141 let domain = domain(&request)?;
142
143 let mode = uri_mode(request.uri())?;
145
146 let stream = wrap_stream(stream, domain, connector, mode).await?;
147 client_async_with_config(request, stream, config).await
148 }
149}
150
151#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
152pub use self::dummy_tls::client_async_tls_with_connector_and_config;
153#[cfg(not(any(feature = "async-tls", feature = "async-native-tls")))]
154use self::dummy_tls::AutoStream;
155
156#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
157pub use crate::async_tls::client_async_tls_with_connector_and_config;
158#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
159use crate::async_tls::AutoStream;
160#[cfg(all(feature = "async-tls", not(feature = "async-native-tls")))]
161type Connector = real_async_tls::TlsConnector;
162
163#[cfg(feature = "async-native-tls")]
164pub use self::async_native_tls::client_async_tls_with_connector_and_config;
165#[cfg(feature = "async-native-tls")]
166use self::async_native_tls::{AutoStream, Connector};
167
168pub type ClientStream<S> = AutoStream<S>;
170
/// Performs the WebSocket client handshake for `request` over `stream`, wrapping the
/// stream in TLS when the request URI requires it.
///
/// No connector and no WebSocket configuration are supplied (both are passed as `None`).
#[cfg(feature = "async-native-tls")]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
185
/// Like [`client_async_tls`], but with an explicit WebSocket `config`.
///
/// The TLS connector is left unspecified (`None`).
#[cfg(feature = "async-native-tls")]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let default_connector = None;
    client_async_tls_with_connector_and_config(request, stream, default_connector, config).await
}
202
/// Like [`client_async_tls`], but with an explicit TLS `connector`.
///
/// No WebSocket configuration is supplied (`None`).
#[cfg(feature = "async-native-tls")]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, no_config).await
}
219
220pub type ConnectStream = ClientStream<TcpStream>;
222
223pub async fn connect_async<R>(
243 request: R,
244) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
245where
246 R: IntoClientRequest + Unpin,
247{
248 connect_async_with_config(request, None).await
249}
250
251pub async fn connect_async_with_config<R>(
253 request: R,
254 config: Option<WebSocketConfig>,
255) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
256where
257 R: IntoClientRequest + Unpin,
258{
259 let request: Request = request.into_client_request()?;
260
261 let domain = domain(&request)?;
262 let port = port(&request)?;
263
264 let try_socket = TcpStream::connect((domain.as_str(), port)).await;
265 let socket = try_socket.map_err(Error::Io)?;
266 client_async_tls_with_connector_and_config(request, socket, None, config).await
267}
268
/// Like [`connect_async`], but with an explicit TLS `connector`.
///
/// No WebSocket configuration is supplied (`None`).
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let config = None;
    connect_async_with_tls_connector_and_config(request, connector, config).await
}
280
/// Like [`connect_async`], but with both an explicit TLS `connector` and an
/// explicit WebSocket `config`.
///
/// I/O errors from the TCP connect are reported as `Error::Io`.
#[cfg(any(feature = "async-tls", feature = "async-native-tls"))]
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    let host = domain(&request)?;
    let tcp_port = port(&request)?;

    let socket = TcpStream::connect((host.as_str(), tcp_port))
        .await
        .map_err(Error::Io)?;
    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}