1use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::handshake::server::{Callback, NoCallback};
5use tungstenite::protocol::WebSocketConfig;
6use tungstenite::Error;
7
8use tokio::net::TcpStream;
9
10use super::{domain, port, WebSocketStream};
11
12use futures_io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "tokio-native-tls")]
15#[path = "tokio/native_tls.rs"]
16mod tls;
17
18#[cfg(all(
19 any(
20 feature = "tokio-rustls-manual-roots",
21 feature = "tokio-rustls-native-certs",
22 feature = "tokio-rustls-webpki-roots"
23 ),
24 not(feature = "tokio-native-tls")
25))]
26#[path = "tokio/rustls.rs"]
27mod tls;
28
29#[cfg(all(
30 feature = "tokio-openssl",
31 not(any(
32 feature = "tokio-native-tls",
33 feature = "tokio-rustls-manual-roots",
34 feature = "tokio-rustls-native-certs",
35 feature = "tokio-rustls-webpki-roots"
36 ))
37))]
38#[path = "tokio/openssl.rs"]
39mod tls;
40
41#[cfg(all(
42 feature = "async-tls",
43 not(any(
44 feature = "tokio-native-tls",
45 feature = "tokio-rustls-manual-roots",
46 feature = "tokio-rustls-native-certs",
47 feature = "tokio-rustls-webpki-roots",
48 feature = "tokio-openssl"
49 ))
50))]
51#[path = "tokio/async_tls.rs"]
52mod tls;
53
54#[cfg(not(any(
55 feature = "tokio-native-tls",
56 feature = "tokio-rustls-manual-roots",
57 feature = "tokio-rustls-native-certs",
58 feature = "tokio-rustls-webpki-roots",
59 feature = "tokio-openssl",
60 feature = "async-tls"
61)))]
62#[path = "tokio/dummy_tls.rs"]
63mod tls;
64
65#[cfg(any(
66 feature = "tokio-native-tls",
67 feature = "tokio-rustls-manual-roots",
68 feature = "tokio-rustls-native-certs",
69 feature = "tokio-rustls-webpki-roots",
70 feature = "tokio-openssl",
71 feature = "async-tls",
72))]
73pub use self::tls::client_async_tls_with_connector_and_config;
74#[cfg(any(
75 feature = "tokio-native-tls",
76 feature = "tokio-rustls-manual-roots",
77 feature = "tokio-rustls-native-certs",
78 feature = "tokio-rustls-webpki-roots",
79 feature = "tokio-openssl",
80 feature = "async-tls"
81))]
82use self::tls::{AutoStream, Connector};
83
84#[cfg(not(any(
85 feature = "tokio-native-tls",
86 feature = "tokio-rustls-manual-roots",
87 feature = "tokio-rustls-native-certs",
88 feature = "tokio-rustls-webpki-roots",
89 feature = "tokio-openssl",
90 feature = "async-tls"
91)))]
92pub use self::tls::client_async_tls_with_connector_and_config;
93#[cfg(not(any(
94 feature = "tokio-native-tls",
95 feature = "tokio-rustls-manual-roots",
96 feature = "tokio-rustls-native-certs",
97 feature = "tokio-rustls-webpki-roots",
98 feature = "tokio-openssl",
99 feature = "async-tls"
100)))]
101use self::tls::AutoStream;
102
103pub async fn client_async<'a, R, S>(
116 request: R,
117 stream: S,
118) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
119where
120 R: IntoClientRequest + Unpin,
121 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
122{
123 client_async_with_config(request, stream, None).await
124}
125
126pub async fn client_async_with_config<'a, R, S>(
129 request: R,
130 stream: S,
131 config: Option<WebSocketConfig>,
132) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
133where
134 R: IntoClientRequest + Unpin,
135 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
136{
137 crate::client_async_with_config(request, TokioAdapter::new(stream), config).await
138}
139
140pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
152where
153 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
154{
155 accept_hdr_async(stream, NoCallback).await
156}
157
158pub async fn accept_async_with_config<S>(
161 stream: S,
162 config: Option<WebSocketConfig>,
163) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
164where
165 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
166{
167 accept_hdr_async_with_config(stream, NoCallback, config).await
168}
169
170pub async fn accept_hdr_async<S, C>(
176 stream: S,
177 callback: C,
178) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
179where
180 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
181 C: Callback + Unpin,
182{
183 accept_hdr_async_with_config(stream, callback, None).await
184}
185
186pub async fn accept_hdr_async_with_config<S, C>(
189 stream: S,
190 callback: C,
191 config: Option<WebSocketConfig>,
192) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
193where
194 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
195 C: Callback + Unpin,
196{
197 crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await
198}
199
/// Stream type produced by the TLS-capable client helpers in this module.
///
/// An alias for `tls::AutoStream<S>`, where `tls` is whichever backend module
/// the enabled cargo features select (native-tls, rustls, openssl, async-tls,
/// or the dummy fallback when none is enabled).
pub type ClientStream<S> = AutoStream<S>;
202
/// Runs the client handshake over `stream`, delegating to
/// [`client_async_tls_with_connector_and_config`] with no explicit connector
/// and no explicit WebSocket configuration.
///
/// Whether TLS is actually applied is decided inside that function.
/// Presumably this depends on the request scheme; confirm in the backend.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: Unpin + IntoClientRequest,
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
223
/// Like [`client_async_tls`], but with an explicit WebSocket configuration.
///
/// No custom TLS connector is supplied; `None` is passed on to
/// [`client_async_tls_with_connector_and_config`].
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: Unpin + IntoClientRequest,
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
246
/// Like [`client_async_tls`], but with an optional caller-supplied TLS
/// [`Connector`] and the default (unset) WebSocket configuration.
///
/// This variant is also available with `tokio-rustls-manual-roots`, where a
/// caller-provided connector is presumably the way to supply trust roots.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: Unpin + IntoClientRequest,
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
    AutoStream<S>: Unpin,
{
    let config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
269
/// OpenSSL-backend variant of `client_async_tls`.
///
/// Delegates to [`client_async_tls_with_connector_and_config`] with no
/// explicit connector and no explicit WebSocket configuration. In this
/// configuration the stream must additionally be `Debug + Send + Sync`.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: Unpin + IntoClientRequest,
    S: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + std::fmt::Debug
        + Send
        + Sync
        + Unpin
        + 'static,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
298
/// OpenSSL-backend variant of `client_async_tls_with_config`: an explicit
/// WebSocket configuration, no custom TLS connector.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: Unpin + IntoClientRequest,
    S: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + std::fmt::Debug
        + Send
        + Sync
        + Unpin
        + 'static,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
329
/// OpenSSL-backend variant of `client_async_tls_with_connector`: an optional
/// caller-supplied [`Connector`], default (unset) WebSocket configuration.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: Unpin + IntoClientRequest,
    S: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + std::fmt::Debug
        + Send
        + Sync
        + Unpin
        + 'static,
    AutoStream<S>: Unpin,
{
    let config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
360
/// Stream type returned by the `connect_async*` functions: a [`ClientStream`]
/// layered over a tokio [`TcpStream`].
pub type ConnectStream = ClientStream<TcpStream>;
363
364pub async fn connect_async<R>(
384 request: R,
385) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
386where
387 R: IntoClientRequest + Unpin,
388{
389 connect_async_with_config(request, None).await
390}
391
392pub async fn connect_async_with_config<R>(
394 request: R,
395 config: Option<WebSocketConfig>,
396) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
397where
398 R: IntoClientRequest + Unpin,
399{
400 let request: Request = request.into_client_request()?;
401
402 let domain = domain(&request)?;
403 let port = port(&request)?;
404
405 let try_socket = TcpStream::connect((domain.as_str(), port)).await;
406 let socket = try_socket.map_err(Error::Io)?;
407 client_async_tls_with_connector_and_config(request, socket, None, config).await
408}
409
/// Like [`connect_async`], but with an optional caller-supplied TLS
/// [`Connector`] and the default (unset) WebSocket configuration.
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: Unpin + IntoClientRequest,
{
    let config: Option<WebSocketConfig> = None;
    connect_async_with_tls_connector_and_config(request, connector, config).await
}
428
/// Opens a TCP connection to the host and port named by `request`, then runs
/// the handshake via [`client_async_tls_with_connector_and_config`] with the
/// given optional TLS `connector` and WebSocket `config`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `request` cannot be converted into a client request.
/// - Its host or port cannot be determined.
/// - The TCP connection fails ([`Error::Io`]).
/// - The subsequent handshake fails.
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    let domain = domain(&request)?;
    let port = port(&request)?;

    // `domain()` may hand back an IPv6 literal still wrapped in brackets.
    // This is how `http::Uri::host` reports it, e.g. `[::1]`.
    // `ToSocketAddrs` for `(&str, u16)` cannot parse that form, so strip one
    // matching bracket pair before connecting. Hostnames and IPv4 addresses
    // never contain brackets and pass through unchanged. The request itself
    // (used for TLS/SNI and the handshake) is forwarded untouched.
    let host = domain
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(domain.as_str());

    let socket = TcpStream::connect((host, port)).await.map_err(Error::Io)?;
    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
455
456use std::pin::Pin;
457use std::task::{Context, Poll};
458
pin_project_lite::pin_project! {
    /// Adapter between tokio's I/O traits and the `futures-io` I/O traits.
    ///
    /// Wrapping a value that implements tokio's `AsyncRead`/`AsyncWrite` yields
    /// an implementation of the `futures-io` `AsyncRead`/`AsyncWrite` traits,
    /// and wrapping a `futures-io` reader/writer yields the tokio traits (see
    /// the trait impls for this type). The wrapped value is structurally
    /// pinned so the adapter can project `Pin<&mut Self>` onto it.
    #[derive(Debug, Clone)]
    pub struct TokioAdapter<T> {
        #[pin]
        inner: T,
    }
}
468
469impl<T> TokioAdapter<T> {
470 pub fn new(inner: T) -> Self {
472 Self { inner }
473 }
474
475 pub fn into_inner(self) -> T {
477 self.inner
478 }
479
480 pub fn get_ref(&self) -> &T {
482 &self.inner
483 }
484
485 pub fn get_mut(&mut self) -> &mut T {
487 &mut self.inner
488 }
489}
490
491impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
492 fn poll_read(
493 self: Pin<&mut Self>,
494 cx: &mut Context<'_>,
495 buf: &mut [u8],
496 ) -> Poll<std::io::Result<usize>> {
497 let mut buf = tokio::io::ReadBuf::new(buf);
498 match self.project().inner.poll_read(cx, &mut buf)? {
499 Poll::Pending => Poll::Pending,
500 Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())),
501 }
502 }
503}
504
505impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
506 fn poll_write(
507 self: Pin<&mut Self>,
508 cx: &mut Context<'_>,
509 buf: &[u8],
510 ) -> Poll<Result<usize, std::io::Error>> {
511 self.project().inner.poll_write(cx, buf)
512 }
513
514 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
515 self.project().inner.poll_flush(cx)
516 }
517
518 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
519 self.project().inner.poll_shutdown(cx)
520 }
521}
522
523impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
524 fn poll_read(
525 self: Pin<&mut Self>,
526 cx: &mut Context<'_>,
527 buf: &mut tokio::io::ReadBuf<'_>,
528 ) -> Poll<std::io::Result<()>> {
529 let slice = buf.initialize_unfilled();
530 let n = match self.project().inner.poll_read(cx, slice)? {
531 Poll::Pending => return Poll::Pending,
532 Poll::Ready(n) => n,
533 };
534 buf.advance(n);
535 Poll::Ready(Ok(()))
536 }
537}
538
539impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
540 fn poll_write(
541 self: Pin<&mut Self>,
542 cx: &mut Context<'_>,
543 buf: &[u8],
544 ) -> Poll<Result<usize, std::io::Error>> {
545 self.project().inner.poll_write(cx, buf)
546 }
547
548 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
549 self.project().inner.poll_flush(cx)
550 }
551
552 fn poll_shutdown(
553 self: Pin<&mut Self>,
554 cx: &mut Context<'_>,
555 ) -> Poll<Result<(), std::io::Error>> {
556 self.project().inner.poll_close(cx)
557 }
558}