async_tungstenite/
tokio.rs

1//! `tokio` integration.
2use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::handshake::server::{Callback, NoCallback};
5use tungstenite::protocol::WebSocketConfig;
6use tungstenite::Error;
7
8use tokio::net::TcpStream;
9
10use super::{domain, port, WebSocketStream};
11
12use futures_io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "tokio-native-tls")]
15#[path = "tokio/native_tls.rs"]
16mod tls;
17
18#[cfg(all(
19    any(
20        feature = "tokio-rustls-manual-roots",
21        feature = "tokio-rustls-native-certs",
22        feature = "tokio-rustls-webpki-roots"
23    ),
24    not(feature = "tokio-native-tls")
25))]
26#[path = "tokio/rustls.rs"]
27mod tls;
28
29#[cfg(all(
30    feature = "tokio-openssl",
31    not(any(
32        feature = "tokio-native-tls",
33        feature = "tokio-rustls-manual-roots",
34        feature = "tokio-rustls-native-certs",
35        feature = "tokio-rustls-webpki-roots"
36    ))
37))]
38#[path = "tokio/openssl.rs"]
39mod tls;
40
41#[cfg(all(
42    feature = "async-tls",
43    not(any(
44        feature = "tokio-native-tls",
45        feature = "tokio-rustls-manual-roots",
46        feature = "tokio-rustls-native-certs",
47        feature = "tokio-rustls-webpki-roots",
48        feature = "tokio-openssl"
49    ))
50))]
51#[path = "tokio/async_tls.rs"]
52mod tls;
53
54#[cfg(not(any(
55    feature = "tokio-native-tls",
56    feature = "tokio-rustls-manual-roots",
57    feature = "tokio-rustls-native-certs",
58    feature = "tokio-rustls-webpki-roots",
59    feature = "tokio-openssl",
60    feature = "async-tls"
61)))]
62#[path = "tokio/dummy_tls.rs"]
63mod tls;
64
65#[cfg(any(
66    feature = "tokio-native-tls",
67    feature = "tokio-rustls-manual-roots",
68    feature = "tokio-rustls-native-certs",
69    feature = "tokio-rustls-webpki-roots",
70    feature = "tokio-openssl",
71    feature = "async-tls",
72))]
73pub use self::tls::client_async_tls_with_connector_and_config;
74#[cfg(any(
75    feature = "tokio-native-tls",
76    feature = "tokio-rustls-manual-roots",
77    feature = "tokio-rustls-native-certs",
78    feature = "tokio-rustls-webpki-roots",
79    feature = "tokio-openssl",
80    feature = "async-tls"
81))]
82use self::tls::{AutoStream, Connector};
83
84#[cfg(not(any(
85    feature = "tokio-native-tls",
86    feature = "tokio-rustls-manual-roots",
87    feature = "tokio-rustls-native-certs",
88    feature = "tokio-rustls-webpki-roots",
89    feature = "tokio-openssl",
90    feature = "async-tls"
91)))]
92pub use self::tls::client_async_tls_with_connector_and_config;
93#[cfg(not(any(
94    feature = "tokio-native-tls",
95    feature = "tokio-rustls-manual-roots",
96    feature = "tokio-rustls-native-certs",
97    feature = "tokio-rustls-webpki-roots",
98    feature = "tokio-openssl",
99    feature = "async-tls"
100)))]
101use self::tls::AutoStream;
102
103/// Creates a WebSocket handshake from a request and a stream.
104/// For convenience, the user may call this with a url string, a URL,
105/// or a `Request`. Calling with `Request` allows the user to add
106/// a WebSocket protocol or other custom headers.
107///
108/// Internally, this custom creates a handshake representation and returns
109/// a future representing the resolution of the WebSocket handshake. The
110/// returned future will resolve to either `WebSocketStream<S>` or `Error`
111/// depending on whether the handshake is successful.
112///
113/// This is typically used for clients who have already established, for
114/// example, a TCP connection to the remote server.
115pub async fn client_async<'a, R, S>(
116    request: R,
117    stream: S,
118) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
119where
120    R: IntoClientRequest + Unpin,
121    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
122{
123    client_async_with_config(request, stream, None).await
124}
125
126/// The same as `client_async()` but the one can specify a websocket configuration.
127/// Please refer to `client_async()` for more details.
128pub async fn client_async_with_config<'a, R, S>(
129    request: R,
130    stream: S,
131    config: Option<WebSocketConfig>,
132) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
133where
134    R: IntoClientRequest + Unpin,
135    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
136{
137    crate::client_async_with_config(request, TokioAdapter::new(stream), config).await
138}
139
140/// Accepts a new WebSocket connection with the provided stream.
141///
142/// This function will internally call `server::accept` to create a
143/// handshake representation and returns a future representing the
144/// resolution of the WebSocket handshake. The returned future will resolve
145/// to either `WebSocketStream<S>` or `Error` depending if it's successful
146/// or not.
147///
148/// This is typically used after a socket has been accepted from a
149/// `TcpListener`. That socket is then passed to this function to perform
150/// the server half of the accepting a client's websocket connection.
151pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
152where
153    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
154{
155    accept_hdr_async(stream, NoCallback).await
156}
157
158/// The same as `accept_async()` but the one can specify a websocket configuration.
159/// Please refer to `accept_async()` for more details.
160pub async fn accept_async_with_config<S>(
161    stream: S,
162    config: Option<WebSocketConfig>,
163) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
164where
165    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
166{
167    accept_hdr_async_with_config(stream, NoCallback, config).await
168}
169
170/// Accepts a new WebSocket connection with the provided stream.
171///
172/// This function does the same as `accept_async()` but accepts an extra callback
173/// for header processing. The callback receives headers of the incoming
174/// requests and is able to add extra headers to the reply.
175pub async fn accept_hdr_async<S, C>(
176    stream: S,
177    callback: C,
178) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
179where
180    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
181    C: Callback + Unpin,
182{
183    accept_hdr_async_with_config(stream, callback, None).await
184}
185
186/// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
187/// Please refer to `accept_hdr_async()` for more details.
188pub async fn accept_hdr_async_with_config<S, C>(
189    stream: S,
190    callback: C,
191    config: Option<WebSocketConfig>,
192) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
193where
194    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
195    C: Callback + Unpin,
196{
197    crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await
198}
199
200/// Type alias for the stream type of the `client_async()` functions.
201pub type ClientStream<S> = AutoStream<S>;
202
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
///
/// Shorthand for `client_async_tls_with_connector_and_config` with neither an explicit
/// TLS connector nor an explicit WebSocket configuration.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")), // No roots will be available
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
223
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
///
/// No explicit TLS connector is passed; `config` is forwarded unchanged.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")), // No roots will be available
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
246
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
///
/// `connector` is forwarded unchanged; no explicit WebSocket configuration is passed.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
269
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
///
/// OpenSSL-backend variant: the stream must additionally be `Debug + Send + Sync`.
/// Neither an explicit TLS connector nor an explicit WebSocket configuration is passed.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
298
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
///
/// OpenSSL-backend variant: the stream must additionally be `Debug + Send + Sync`.
/// No explicit TLS connector is passed; `config` is forwarded unchanged.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
329
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// connector.
///
/// OpenSSL-backend variant: the stream must additionally be `Debug + Send + Sync`.
/// `connector` is forwarded unchanged; no explicit WebSocket configuration is passed.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
360
361/// Type alias for the stream type of the `connect_async()` functions.
362pub type ConnectStream = ClientStream<TcpStream>;
363
364/// Connect to a given URL.
365///
366/// Accepts any request that implements [`IntoClientRequest`], which is often just `&str`, but can
367/// be a variety of types such as `httparse::Request` or [`tungstenite::http::Request`] for more
368/// complex uses.
369///
370/// ```no_run
371/// # use tungstenite::client::IntoClientRequest;
372///
373/// # async fn test() {
374/// use tungstenite::http::{Method, Request};
375/// use async_tungstenite::tokio::connect_async;
376///
377/// let mut request = "wss://api.example.com".into_client_request().unwrap();
378/// request.headers_mut().insert("api-key", "42".parse().unwrap());
379///
380/// let (stream, response) = connect_async(request).await.unwrap();
381/// # }
382/// ```
383pub async fn connect_async<R>(
384    request: R,
385) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
386where
387    R: IntoClientRequest + Unpin,
388{
389    connect_async_with_config(request, None).await
390}
391
392/// Connect to a given URL with a given WebSocket configuration.
393pub async fn connect_async_with_config<R>(
394    request: R,
395    config: Option<WebSocketConfig>,
396) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
397where
398    R: IntoClientRequest + Unpin,
399{
400    let request: Request = request.into_client_request()?;
401
402    let domain = domain(&request)?;
403    let port = port(&request)?;
404
405    let try_socket = TcpStream::connect((domain.as_str(), port)).await;
406    let socket = try_socket.map_err(Error::Io)?;
407    client_async_tls_with_connector_and_config(request, socket, None, config).await
408}
409
410#[cfg(any(
411    feature = "async-tls",
412    feature = "tokio-native-tls",
413    feature = "tokio-rustls-manual-roots",
414    feature = "tokio-rustls-native-certs",
415    feature = "tokio-rustls-webpki-roots",
416    feature = "tokio-openssl"
417))]
418/// Connect to a given URL using the provided TLS connector.
419pub async fn connect_async_with_tls_connector<R>(
420    request: R,
421    connector: Option<Connector>,
422) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
423where
424    R: IntoClientRequest + Unpin,
425{
426    connect_async_with_tls_connector_and_config(request, connector, None).await
427}
428
429#[cfg(any(
430    feature = "async-tls",
431    feature = "tokio-native-tls",
432    feature = "tokio-rustls-manual-roots",
433    feature = "tokio-rustls-native-certs",
434    feature = "tokio-rustls-webpki-roots",
435    feature = "tokio-openssl"
436))]
437/// Connect to a given URL using the provided TLS connector.
438pub async fn connect_async_with_tls_connector_and_config<R>(
439    request: R,
440    connector: Option<Connector>,
441    config: Option<WebSocketConfig>,
442) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
443where
444    R: IntoClientRequest + Unpin,
445{
446    let request: Request = request.into_client_request()?;
447
448    let domain = domain(&request)?;
449    let port = port(&request)?;
450
451    let try_socket = TcpStream::connect((domain.as_str(), port)).await;
452    let socket = try_socket.map_err(Error::Io)?;
453    client_async_tls_with_connector_and_config(request, socket, connector, config).await
454}
455
456use std::pin::Pin;
457use std::task::{Context, Poll};
458
459pin_project_lite::pin_project! {
460    /// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide
461    /// the variants from the `futures` crate and the other way around.
462    #[derive(Debug, Clone)]
463    pub struct TokioAdapter<T> {
464        #[pin]
465        inner: T,
466    }
467}
468
469impl<T> TokioAdapter<T> {
470    /// Creates a new `TokioAdapter` wrapping the provided value.
471    pub fn new(inner: T) -> Self {
472        Self { inner }
473    }
474
475    /// Consumes this `TokioAdapter`, returning the underlying value.
476    pub fn into_inner(self) -> T {
477        self.inner
478    }
479
480    /// Get a reference to the underlying value.
481    pub fn get_ref(&self) -> &T {
482        &self.inner
483    }
484
485    /// Get a mutable reference to the underlying value.
486    pub fn get_mut(&mut self) -> &mut T {
487        &mut self.inner
488    }
489}
490
491impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
492    fn poll_read(
493        self: Pin<&mut Self>,
494        cx: &mut Context<'_>,
495        buf: &mut [u8],
496    ) -> Poll<std::io::Result<usize>> {
497        let mut buf = tokio::io::ReadBuf::new(buf);
498        match self.project().inner.poll_read(cx, &mut buf)? {
499            Poll::Pending => Poll::Pending,
500            Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())),
501        }
502    }
503}
504
505impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
506    fn poll_write(
507        self: Pin<&mut Self>,
508        cx: &mut Context<'_>,
509        buf: &[u8],
510    ) -> Poll<Result<usize, std::io::Error>> {
511        self.project().inner.poll_write(cx, buf)
512    }
513
514    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
515        self.project().inner.poll_flush(cx)
516    }
517
518    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
519        self.project().inner.poll_shutdown(cx)
520    }
521}
522
523impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
524    fn poll_read(
525        self: Pin<&mut Self>,
526        cx: &mut Context<'_>,
527        buf: &mut tokio::io::ReadBuf<'_>,
528    ) -> Poll<std::io::Result<()>> {
529        let slice = buf.initialize_unfilled();
530        let n = match self.project().inner.poll_read(cx, slice)? {
531            Poll::Pending => return Poll::Pending,
532            Poll::Ready(n) => n,
533        };
534        buf.advance(n);
535        Poll::Ready(Ok(()))
536    }
537}
538
539impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
540    fn poll_write(
541        self: Pin<&mut Self>,
542        cx: &mut Context<'_>,
543        buf: &[u8],
544    ) -> Poll<Result<usize, std::io::Error>> {
545        self.project().inner.poll_write(cx, buf)
546    }
547
548    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
549        self.project().inner.poll_flush(cx)
550    }
551
552    fn poll_shutdown(
553        self: Pin<&mut Self>,
554        cx: &mut Context<'_>,
555    ) -> Poll<Result<(), std::io::Error>> {
556        self.project().inner.poll_close(cx)
557    }
558}