async_tungstenite/
lib.rs

1//! Async WebSockets.
2//!
//! This crate is based on the [tungstenite](https://crates.io/crates/tungstenite)
//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and couple it
//! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8//!
9//!  * `async-tls`: Enables the `async_tls` module, which provides integration
10//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
11//!    be used independent of any async runtime.
12//!  * `async-std-runtime`: Enables the `async_std` module, which provides
13//!    integration with the [async-std](https://async.rs) runtime.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
28//!    provides.
29//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
30//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
31//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
32//!    the [gio](https://www.gtk-rs.org) runtime.
33//!
34//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
35//! making the socket a stream of WebSocket messages coming in and going out.
36
37#![deny(
38    missing_docs,
39    unused_must_use,
40    unused_mut,
41    unused_imports,
42    unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51    feature = "async-tls",
52    feature = "async-native-tls",
53    feature = "tokio-native-tls",
54    feature = "tokio-rustls-manual-roots",
55    feature = "tokio-rustls-native-certs",
56    feature = "tokio-rustls-webpki-roots",
57    feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62    io::{Read, Write},
63    pin::Pin,
64    task::{ready, Context, Poll},
65};
66
67use compat::{cvt, AllowStd, ContextWaker};
68use futures_core::stream::{FusedStream, Stream};
69use futures_io::{AsyncRead, AsyncWrite};
70use log::*;
71
72#[cfg(feature = "handshake")]
73use tungstenite::{
74    client::IntoClientRequest,
75    handshake::{
76        client::{ClientHandshake, Response},
77        server::{Callback, NoCallback},
78        HandshakeError,
79    },
80};
81use tungstenite::{
82    error::Error as WsError,
83    protocol::{Message, Role, WebSocket, WebSocketConfig},
84};
85
86#[cfg(feature = "async-std-runtime")]
87pub mod async_std;
88#[cfg(feature = "async-tls")]
89pub mod async_tls;
90#[cfg(feature = "gio-runtime")]
91pub mod gio;
92#[cfg(feature = "tokio-runtime")]
93pub mod tokio;
94
95pub mod bytes;
96pub use bytes::ByteReader;
97#[cfg(feature = "futures-03-sink")]
98pub use bytes::ByteWriter;
99
100use tungstenite::protocol::CloseFrame;
101
/// Performs the client side of a WebSocket handshake over an existing stream.
///
/// For convenience, `request` may be a url string, a URL, or a `Request`.
/// Passing a `Request` allows the caller to add a WebSocket protocol or
/// other custom headers.
///
/// This forwards to `client_async_with_config()` without a custom
/// configuration. The returned future resolves to the `WebSocketStream<S>`
/// together with the handshake `Response` on success, or to an `Error` if
/// the handshake fails.
///
/// This is typically used by clients that have already established, for
/// example, a TCP connection to the remote server.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // `None`: no caller-supplied `WebSocketConfig`.
    let config = None;
    client_async_with_config(request, stream, config).await
}
125
/// The same as `client_async()`, but lets the caller specify a WebSocket
/// configuration. Please refer to `client_async()` for more details.
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is; any other
/// handshake error is rendered to a string and wrapped in an
/// `std::io::Error` of kind `Other`.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let pending = handshake::client_handshake(stream, move |allow_std| {
        // Convert the user-supplied value into a request first, so that a
        // conversion error is reported before the handshake is started.
        let request = request.into_client_request()?;
        ClientHandshake::start(allow_std, request, config)?.handshake()
    });

    match pending.await {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
151
/// Accepts a new WebSocket connection on the provided stream.
///
/// This forwards to `accept_hdr_async()` with `NoCallback` as the header
/// callback. The returned future resolves to a `WebSocketStream<S>` if the
/// handshake succeeds and to an `Error` otherwise.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client's WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async(stream, callback).await
}
170
/// The same as `accept_async()`, but lets the caller specify a WebSocket
/// configuration. Please refer to `accept_async()` for more details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
183
/// Accepts a new WebSocket connection on the provided stream, with a
/// header callback.
///
/// This does the same as `accept_async()` but takes an extra callback for
/// header processing. The callback receives the headers of the incoming
/// request and is able to add extra headers to the reply.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    // `None`: no caller-supplied `WebSocketConfig`.
    let config = None;
    accept_hdr_async_with_config(stream, callback, config).await
}
197
/// The same as `accept_hdr_async()`, but lets the caller specify a WebSocket
/// configuration. Please refer to `accept_hdr_async()` for more details.
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is; any other
/// handshake error is rendered to a string and wrapped in an
/// `std::io::Error` of kind `Other`.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let pending = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match pending.await {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
221
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully and both the server and the client are ready for receiving
/// and sending data. Messages from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink`. Check more information about
/// them in the `futures-rs` crate documentation or have a look at the examples
/// and unit tests for this crate.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite `WebSocket` driving the protocol, layered over the
    /// `AllowStd` adapter around the user's stream.
    inner: WebSocket<AllowStd<S>>,
    /// Set by `poll_close` once closing hit `WouldBlock`; later `poll_close`
    /// calls then flush instead of calling `close` again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    /// Set once reading returned an error; `poll_next` then yields `None`
    /// and `is_terminated` reports `true`.
    ended: bool,
    /// Tungstenite is probably ready to receive more data.
    ///
    /// `false` once start_send hits `WouldBlock` errors.
    /// `true` initially and after `flush`ing.
    ready: bool,
}
243
244impl<S> WebSocketStream<S> {
245    /// Convert a raw socket into a WebSocketStream without performing a
246    /// handshake.
247    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
248    where
249        S: AsyncRead + AsyncWrite + Unpin,
250    {
251        handshake::without_handshake(stream, move |allow_std| {
252            WebSocket::from_raw_socket(allow_std, role, config)
253        })
254        .await
255    }
256
257    /// Convert a raw socket into a WebSocketStream without performing a
258    /// handshake.
259    pub async fn from_partially_read(
260        stream: S,
261        part: Vec<u8>,
262        role: Role,
263        config: Option<WebSocketConfig>,
264    ) -> Self
265    where
266        S: AsyncRead + AsyncWrite + Unpin,
267    {
268        handshake::without_handshake(stream, move |allow_std| {
269            WebSocket::from_partially_read(allow_std, part, role, config)
270        })
271        .await
272    }
273
274    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
275        Self {
276            inner: ws,
277            #[cfg(feature = "futures-03-sink")]
278            closing: false,
279            ended: false,
280            ready: true,
281        }
282    }
283
284    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
285    where
286        S: Unpin,
287        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
288        AllowStd<S>: Read + Write,
289    {
290        #[cfg(feature = "verbose-logging")]
291        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
292        if let Some((kind, ctx)) = ctx {
293            self.inner.get_mut().set_waker(kind, ctx.waker());
294        }
295        f(&mut self.inner)
296    }
297
298    /// Returns a shared reference to the inner stream.
299    pub fn get_ref(&self) -> &S
300    where
301        S: AsyncRead + AsyncWrite + Unpin,
302    {
303        self.inner.get_ref().get_ref()
304    }
305
306    /// Returns a mutable reference to the inner stream.
307    pub fn get_mut(&mut self) -> &mut S
308    where
309        S: AsyncRead + AsyncWrite + Unpin,
310    {
311        self.inner.get_mut().get_mut()
312    }
313
314    /// Returns a reference to the configuration of the tungstenite stream.
315    pub fn get_config(&self) -> &WebSocketConfig {
316        self.inner.get_config()
317    }
318
319    /// Close the underlying web socket
320    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
321    where
322        S: AsyncRead + AsyncWrite + Unpin,
323    {
324        self.send(Message::Close(msg)).await
325    }
326}
327
328impl<T> Stream for WebSocketStream<T>
329where
330    T: AsyncRead + AsyncWrite + Unpin,
331{
332    type Item = Result<Message, WsError>;
333
334    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
335        #[cfg(feature = "verbose-logging")]
336        trace!("{}:{} Stream.poll_next", file!(), line!());
337
338        // The connection has been closed or a critical error has occurred.
339        // We have already returned the error to the user, the `Stream` is unusable,
340        // so we assume that the stream has been "fused".
341        if self.ended {
342            return Poll::Ready(None);
343        }
344
345        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
346            #[cfg(feature = "verbose-logging")]
347            trace!(
348                "{}:{} Stream.with_context poll_next -> read()",
349                file!(),
350                line!()
351            );
352            cvt(s.read())
353        })) {
354            Ok(v) => Poll::Ready(Some(Ok(v))),
355            Err(e) => {
356                self.ended = true;
357                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
358                    Poll::Ready(None)
359                } else {
360                    Poll::Ready(Some(Err(e)))
361                }
362            }
363        }
364    }
365}
366
impl<T> FusedStream for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Reports whether the stream has ended, i.e. whether `poll_next` has
    /// already seen a read error and will only yield `None` from now on.
    fn is_terminated(&self) -> bool {
        self.ended
    }
}
375
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    /// Ready immediately unless an earlier write hit `WouldBlock`, in which
    /// case a flush is attempted to clear the blockage first.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.ready {
            return Poll::Ready(Ok(()));
        }

        let flushed = ready!(
            (*self).with_context(Some((ContextWaker::Write, cx)), |ws| cvt(ws.flush()))
        );
        // The flush finished (successfully or not), so we are no longer blocked.
        self.ready = true;
        Poll::Ready(flushed)
    }

    /// Hands `item` to tungstenite. A `WouldBlock` is not reported as an
    /// error; it only marks the sink as not ready.
    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let written = (*self).with_context(None, |ws| ws.write(item));
        match written {
            Err(WsError::Io(io_err)) if io_err.kind() == std::io::ErrorKind::WouldBlock => {
                // The message was accepted and queued, so this is not an error,
                // but `poll_ready` will now start trying to flush the block.
                self.ready = false;
                Ok(())
            }
            finished => {
                self.ready = true;
                if let Err(e) = &finished {
                    debug!("websocket start_send error: {}", e);
                }
                finished
            }
        }
    }

    /// Flushes pending data. A `ConnectionClosed` result is treated as a
    /// completed flush rather than an error.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let flushed = ready!(
            (*self).with_context(Some((ContextWaker::Write, cx)), |ws| cvt(ws.flush()))
        );
        self.ready = true;
        Poll::Ready(match flushed {
            // WebSocket connection has just been closed. Flushing completed, not an error.
            Err(WsError::ConnectionClosed) => Ok(()),
            other => other,
        })
    }

    /// Closes the WebSocket. The first call issues `close(None)`; once that
    /// has hit `WouldBlock`, later calls flush instead.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.ready = true;

        let close_already_queued = self.closing;
        let res = (*self).with_context(Some((ContextWaker::Write, cx)), |ws| {
            if close_already_queued {
                // After queueing it, we call `flush` to drive the close handshake to completion.
                ws.flush()
            } else {
                ws.close(None)
            }
        });

        match res {
            Ok(()) | Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
            Err(WsError::Io(io_err)) if io_err.kind() == std::io::ErrorKind::WouldBlock => {
                trace!("WouldBlock");
                self.closing = true;
                Poll::Pending
            }
            Err(err) => {
                debug!("websocket close error: {}", err);
                Poll::Ready(Err(err))
            }
        }
    }
}
454
455impl<S> WebSocketStream<S> {
456    /// Simple send method to replace `futures_sink::Sink` (till v0.3).
457    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
458    where
459        S: AsyncRead + AsyncWrite + Unpin,
460    {
461        Send::new(self, msg).await
462    }
463}
464
/// Future behind `WebSocketStream::send`: writes one message, then flushes.
struct Send<'a, S> {
    /// The stream the message is sent on.
    ws: &'a mut WebSocketStream<S>,
    /// The message still to be written; `None` once it has been handed to
    /// tungstenite, so a later poll only flushes.
    msg: Option<Message>,
}
469
impl<'a, S> Send<'a, S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a send future for `msg` on `ws`; nothing is written until
    /// the future is polled.
    fn new(ws: &'a mut WebSocketStream<S>, msg: Message) -> Self {
        Self { ws, msg: Some(msg) }
    }
}
478
479impl<S> std::future::Future for Send<'_, S>
480where
481    S: AsyncRead + AsyncWrite + Unpin,
482{
483    type Output = Result<(), WsError>;
484
485    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
486        if self.msg.is_some() {
487            if !self.ws.ready {
488                // Currently blocked so try to flush the blockage away
489                let polled = self
490                    .ws
491                    .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
492                    .map(|r| {
493                        self.ws.ready = true;
494                        r
495                    });
496                ready!(polled)?
497            }
498
499            let msg = self.msg.take().expect("unreachable");
500            match self.ws.with_context(None, |s| s.write(msg)) {
501                Ok(_) => Ok(()),
502                Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
503                    // the message was accepted and queued so not an error
504                    //
505                    // set to false here for cancellation safety of *this* Future
506                    self.ws.ready = false;
507                    Ok(())
508                }
509                Err(e) => {
510                    debug!("websocket start_send error: {}", e);
511                    Err(e)
512                }
513            }?;
514        }
515
516        let polled = self
517            .ws
518            .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
519            .map(|r| {
520                self.ws.ready = true;
521                match r {
522                    // WebSocket connection has just been closed. Flushing completed, not an error.
523                    Err(WsError::ConnectionClosed) => Ok(()),
524                    other => other,
525                }
526            });
527        ready!(polled)?;
528
529        Poll::Ready(Ok(()))
530    }
531}
532
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get a domain from an URL.
///
/// Returns the host of the request URI as an owned string, with the
/// surrounding brackets of an IPv6 literal removed.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` if the URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // If host is an IPv6 address, it might be surrounded by brackets. These
    // brackets are *not* part of a valid IP, so they must be stripped out.
    //
    // Only a matching `[`...`]` pair is stripped. Anything else is returned
    // unchanged, so there is no slicing that could panic even if the host
    // were ever malformed.
    let bare = host
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(host);

    Ok(bare.to_owned())
}
565
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port from an URL.
///
/// An explicit port in the URI wins. Without one, the scheme decides:
/// `wss` maps to 443 and `ws` to 80. Any other case yields
/// `UrlError::UnsupportedUrlScheme`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();

    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
588
#[cfg(test)]
mod tests {
    // Gated on the same feature set as `crate::domain`, so the test only
    // exists when the function under test does.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        // A bracketed IPv6 literal must come back without its brackets.
        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    // Checks that request conversion rejects URIs with an unterminated `[`,
    // which is the assumption `crate::domain` relies on when stripping brackets.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}