tokio_tungstenite/
lib.rs

1//! Async WebSocket usage.
2//!
//! This library is an implementation of WebSocket handshakes and streams. It
//! is based on the `tungstenite` crate, which implements all required WebSocket
//! protocol logic, so this crate essentially just adds tokio support / tokio
//! integration on top of it.
7//!
8//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
9//! so the socket is just a stream of messages coming in and going out.
10
11#![deny(missing_docs, unused_must_use, unused_mut, unused_imports, unused_import_braces)]
12
13pub use tungstenite;
14
15mod compat;
16#[cfg(feature = "connect")]
17mod connect;
18mod handshake;
19#[cfg(feature = "stream")]
20mod stream;
21#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
22mod tls;
23
24use std::io::{Read, Write};
25
26use compat::{cvt, AllowStd, ContextWaker};
27use futures_util::{
28    sink::{Sink, SinkExt},
29    stream::{FusedStream, Stream},
30};
31use log::*;
32use std::{
33    pin::Pin,
34    task::{Context, Poll},
35};
36use tokio::io::{AsyncRead, AsyncWrite};
37
38#[cfg(feature = "handshake")]
39use tungstenite::{
40    client::IntoClientRequest,
41    handshake::{
42        client::{ClientHandshake, Response},
43        server::{Callback, NoCallback},
44        HandshakeError,
45    },
46};
47use tungstenite::{
48    error::Error as WsError,
49    protocol::{Message, Role, WebSocket, WebSocketConfig},
50};
51
52#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
53pub use tls::Connector;
54#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
55pub use tls::{client_async_tls, client_async_tls_with_config};
56
57#[cfg(feature = "connect")]
58pub use connect::{connect_async, connect_async_with_config};
59
60#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "connect"))]
61pub use connect::connect_async_tls_with_config;
62
63#[cfg(feature = "stream")]
64pub use stream::MaybeTlsStream;
65
66use tungstenite::protocol::CloseFrame;
67
/// Creates a WebSocket handshake from a request and a stream.
///
/// For convenience, the user may call this with a url string, a URL,
/// or a `Request`. Calling with `Request` allows the user to add
/// a WebSocket protocol or other custom headers.
///
/// Internally, this function creates a handshake representation and returns
/// a future representing the resolution of the WebSocket handshake. The
/// returned future will resolve to either `WebSocketStream<S>` (together with
/// the server's handshake `Response`) or `Error`, depending on whether the
/// handshake is successful.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server.
///
/// This is equivalent to calling [`client_async_with_config`] with `None` as
/// the configuration; see that function for how handshake errors are mapped.
#[cfg(feature = "handshake")]
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    client_async_with_config(request, stream, None).await
}
91
/// The same as `client_async()` but one can specify a websocket configuration.
/// Please refer to `client_async()` for more details.
///
/// # Errors
///
/// A handshake that fails with `HandshakeError::Failure` yields the wrapped
/// [`WsError`] unchanged. Any other handshake error is converted into
/// `WsError::Io` with kind [`std::io::ErrorKind::Other`], carrying the
/// original error's message. Errors from converting `request` or from starting
/// the handshake are propagated through the same path.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let f = handshake::client_handshake(stream, move |allow_std| {
        let request = request.into_client_request()?;
        let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
        cli_handshake.handshake()
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        // Non-`Failure` variants (presumably an interrupted handshake) carry no
        // `WsError`, so they are flattened into an opaque I/O error.
        e => WsError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
    })
}
114
/// Accepts a new WebSocket connection with the provided stream.
///
/// This performs the server side of the WebSocket handshake on `stream`
/// without any header-inspection callback and without a custom configuration.
/// The returned future resolves to either a `WebSocketStream<S>` or an `Error`,
/// depending on whether the handshake succeeds.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client's websocket connection.
///
/// Equivalent to `accept_hdr_async_with_config(stream, NoCallback, None)`.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
133
/// Like `accept_async()`, but lets the caller supply a websocket configuration.
///
/// See `accept_async()` for the general behaviour; passing `None` as `config`
/// makes this identical to `accept_async()`.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No header inspection is requested, so tungstenite's `NoCallback` is used.
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
146
/// Accepts a new WebSocket connection with the provided stream, running a
/// header callback during the handshake.
///
/// This behaves like `accept_async()` but takes an extra `callback` for header
/// processing. The callback receives the headers of the incoming request and
/// is able to add extra headers to the reply. No custom websocket
/// configuration is applied.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
160
/// Like `accept_hdr_async()`, but lets the caller supply a websocket configuration.
///
/// See `accept_hdr_async()` for details on the header callback.
///
/// # Errors
///
/// A handshake that fails with `HandshakeError::Failure` yields the wrapped
/// [`WsError`] unchanged. Any other handshake error is converted into
/// `WsError::Io` with kind [`std::io::ErrorKind::Other`], carrying the
/// original error's message.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let handshake_result = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    handshake_result.map_err(|handshake_err| match handshake_err {
        // A definitive failure already carries a `WsError`; pass it through.
        HandshakeError::Failure(ws_err) => ws_err,
        // Anything else has no `WsError` payload, so flatten it into an I/O error.
        other => {
            let message = other.to_string();
            WsError::Io(std::io::Error::new(std::io::ErrorKind::Other, message))
        }
    })
}
181
182/// A wrapper around an underlying raw stream which implements the WebSocket
183/// protocol.
184///
185/// A `WebSocketStream<S>` represents a handshake that has been completed
186/// successfully and both the server and the client are ready for receiving
187/// and sending data. Message from a `WebSocketStream<S>` are accessible
188/// through the respective `Stream` and `Sink`. Check more information about
189/// them in `futures-rs` crate documentation or have a look on the examples
190/// and unit tests for this crate.
191#[derive(Debug)]
192pub struct WebSocketStream<S> {
193    inner: WebSocket<AllowStd<S>>,
194    closing: bool,
195    ended: bool,
196    /// Tungstenite is probably ready to receive more data.
197    ///
198    /// `false` once start_send hits `WouldBlock` errors.
199    /// `true` initially and after `flush`ing.
200    ready: bool,
201}
202
203impl<S> WebSocketStream<S> {
204    /// Convert a raw socket into a WebSocketStream without performing a
205    /// handshake.
206    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
207    where
208        S: AsyncRead + AsyncWrite + Unpin,
209    {
210        handshake::without_handshake(stream, move |allow_std| {
211            WebSocket::from_raw_socket(allow_std, role, config)
212        })
213        .await
214    }
215
216    /// Convert a raw socket into a WebSocketStream without performing a
217    /// handshake.
218    pub async fn from_partially_read(
219        stream: S,
220        part: Vec<u8>,
221        role: Role,
222        config: Option<WebSocketConfig>,
223    ) -> Self
224    where
225        S: AsyncRead + AsyncWrite + Unpin,
226    {
227        handshake::without_handshake(stream, move |allow_std| {
228            WebSocket::from_partially_read(allow_std, part, role, config)
229        })
230        .await
231    }
232
233    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
234        Self { inner: ws, closing: false, ended: false, ready: true }
235    }
236
237    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
238    where
239        S: Unpin,
240        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
241        AllowStd<S>: Read + Write,
242    {
243        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
244        if let Some((kind, ctx)) = ctx {
245            self.inner.get_mut().set_waker(kind, ctx.waker());
246        }
247        f(&mut self.inner)
248    }
249
250    /// Returns a shared reference to the inner stream.
251    pub fn get_ref(&self) -> &S
252    where
253        S: AsyncRead + AsyncWrite + Unpin,
254    {
255        self.inner.get_ref().get_ref()
256    }
257
258    /// Returns a mutable reference to the inner stream.
259    pub fn get_mut(&mut self) -> &mut S
260    where
261        S: AsyncRead + AsyncWrite + Unpin,
262    {
263        self.inner.get_mut().get_mut()
264    }
265
266    /// Returns a reference to the configuration of the tungstenite stream.
267    pub fn get_config(&self) -> &WebSocketConfig {
268        self.inner.get_config()
269    }
270
271    /// Close the underlying web socket
272    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
273    where
274        S: AsyncRead + AsyncWrite + Unpin,
275    {
276        self.send(Message::Close(msg)).await
277    }
278}
279
280impl<T> Stream for WebSocketStream<T>
281where
282    T: AsyncRead + AsyncWrite + Unpin,
283{
284    type Item = Result<Message, WsError>;
285
286    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
287        trace!("{}:{} Stream.poll_next", file!(), line!());
288
289        // The connection has been closed or a critical error has occurred.
290        // We have already returned the error to the user, the `Stream` is unusable,
291        // so we assume that the stream has been "fused".
292        if self.ended {
293            return Poll::Ready(None);
294        }
295
296        match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
297            trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
298            cvt(s.read())
299        })) {
300            Ok(v) => Poll::Ready(Some(Ok(v))),
301            Err(e) => {
302                self.ended = true;
303                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
304                    Poll::Ready(None)
305                } else {
306                    Poll::Ready(Some(Err(e)))
307                }
308            }
309        }
310    }
311}
312
313impl<T> FusedStream for WebSocketStream<T>
314where
315    T: AsyncRead + AsyncWrite + Unpin,
316{
317    fn is_terminated(&self) -> bool {
318        self.ended
319    }
320}
321
322impl<T> Sink<Message> for WebSocketStream<T>
323where
324    T: AsyncRead + AsyncWrite + Unpin,
325{
326    type Error = WsError;
327
328    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
329        if self.ready {
330            Poll::Ready(Ok(()))
331        } else {
332            // Currently blocked so try to flush the blockage away
333            (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
334                self.ready = true;
335                r
336            })
337        }
338    }
339
340    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
341        match (*self).with_context(None, |s| s.write(item)) {
342            Ok(()) => {
343                self.ready = true;
344                Ok(())
345            }
346            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
347                // the message was accepted and queued so not an error
348                // but `poll_ready` will now start trying to flush the block
349                self.ready = false;
350                Ok(())
351            }
352            Err(e) => {
353                self.ready = true;
354                debug!("websocket start_send error: {}", e);
355                Err(e)
356            }
357        }
358    }
359
360    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
361        (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
362            self.ready = true;
363            match r {
364                // WebSocket connection has just been closed. Flushing completed, not an error.
365                Err(WsError::ConnectionClosed) => Ok(()),
366                other => other,
367            }
368        })
369    }
370
371    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
372        self.ready = true;
373        let res = if self.closing {
374            // After queueing it, we call `flush` to drive the close handshake to completion.
375            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
376        } else {
377            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
378        };
379
380        match res {
381            Ok(()) => Poll::Ready(Ok(())),
382            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
383            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
384                trace!("WouldBlock");
385                self.closing = true;
386                Poll::Pending
387            }
388            Err(err) => {
389                debug!("websocket close error: {}", err);
390                Poll::Ready(Err(err))
391            }
392        }
393    }
394}
395
/// Extracts the host of `request`'s URI as an owned string.
///
/// When the `__rustls-tls` feature is enabled, a bracketed host such as an
/// IPv6 literal (`[::1]`) is returned without its surrounding brackets.
///
/// # Errors
///
/// Returns `WsError::Url(UrlError::NoHostName)` if the URI has no host.
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    let host = request
        .uri()
        .host()
        .ok_or(WsError::Url(tungstenite::error::UrlError::NoHostName))?;

    // rustls expects IPv6 addresses without the surrounding [] brackets
    #[cfg(feature = "__rustls-tls")]
    {
        if let Some(unbracketed) = host.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
            return Ok(unbracketed.to_string());
        }
    }

    Ok(host.to_string())
}
408
#[cfg(test)]
mod tests {
    #[cfg(feature = "connect")]
    use crate::stream::MaybeTlsStream;
    use crate::{compat::AllowStd, WebSocketStream};
    use std::io::{Read, Write};
    #[cfg(feature = "connect")]
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    /// Compiles only if `T` implements the blocking `std::io` read/write traits.
    fn assert_std_io<T: Read + Write>() {}

    /// Compiles only if `T` implements tokio's async read/write extension traits.
    #[cfg(feature = "connect")]
    fn assert_async_io<T: AsyncReadExt + AsyncWriteExt>() {}

    /// Compiles only if `T` is `Unpin`.
    fn assert_unpin<T: Unpin>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        assert_std_io::<AllowStd<TcpStream>>();
        assert_unpin::<WebSocketStream<TcpStream>>();

        #[cfg(feature = "connect")]
        {
            assert_async_io::<MaybeTlsStream<TcpStream>>();
            assert_unpin::<WebSocketStream<MaybeTlsStream<TcpStream>>>();
        }
    }
}