gloo_net/websocket/io_util.rs

1use core::cmp;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use std::io;
5
6use futures_core::{ready, Stream as _};
7use futures_io::{AsyncRead, AsyncWrite};
8use futures_sink::Sink;
9
10use crate::websocket::futures::WebSocket;
11use crate::websocket::{Message as WebSocketMessage, WebSocketError};
12
13impl WebSocket {
14    /// Returns whether there are pending bytes left after calling [`AsyncRead::poll_read`] on this WebSocket.
15    ///
16    /// When calling [`AsyncRead::poll_read`], [`Stream::poll_next`](futures_core::Stream::poll_next) is called
17    /// under the hood, and when the received item is too big to fit into the provided buffer, leftover bytes are
18    /// stored. These leftover bytes are returned by subsequent calls to [`AsyncRead::poll_read`].
19    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
20    pub fn has_pending_bytes(&self) -> bool {
21        self.read_pending_bytes.is_some()
22    }
23}
24
/// Unwraps a `Result<T, WebSocketError>` inside a `poll_*` method returning
/// `Poll<io::Result<usize>>`.
///
/// - `Ok(value)` evaluates to `value`.
/// - A `ConnectionClose` whose event reports `was_clean` makes the enclosing function return
///   `Poll::Ready(Ok(0))`.
/// - Any other error makes the enclosing function return `Poll::Ready(Err(_))`, wrapping the
///   error in an `io::Error` of kind `io::ErrorKind::Other`.
macro_rules! try_in_poll_io {
    ($expr:expr) => {{
        match $expr {
            Err(WebSocketError::ConnectionClose(event)) if event.was_clean => {
                // The socket was closed cleanly: there is nothing more to read or write.
                return Poll::Ready(Ok(0));
            }
            Err(other) => {
                return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, other)));
            }
            Ok(value) => value,
        }
    }};
}
37
38#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
39impl AsyncRead for WebSocket {
40    fn poll_read(
41        mut self: Pin<&mut Self>,
42        cx: &mut Context<'_>,
43        buf: &mut [u8],
44    ) -> Poll<io::Result<usize>> {
45        let mut data = if let Some(data) = self.as_mut().get_mut().read_pending_bytes.take() {
46            data
47        } else {
48            match ready!(self.as_mut().poll_next(cx)) {
49                Some(item) => match try_in_poll_io!(item) {
50                    WebSocketMessage::Text(s) => s.into_bytes(),
51                    WebSocketMessage::Bytes(data) => data,
52                },
53                None => return Poll::Ready(Ok(0)),
54            }
55        };
56
57        let bytes_to_copy = cmp::min(buf.len(), data.len());
58        buf[..bytes_to_copy].copy_from_slice(&data[..bytes_to_copy]);
59
60        if data.len() > bytes_to_copy {
61            data.drain(..bytes_to_copy);
62            self.get_mut().read_pending_bytes = Some(data);
63        }
64
65        Poll::Ready(Ok(bytes_to_copy))
66    }
67}
68
69#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
70impl AsyncWrite for WebSocket {
71    fn poll_write(
72        mut self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74        buf: &[u8],
75    ) -> Poll<io::Result<usize>> {
76        // try flushing preemptively
77        let _ = AsyncWrite::poll_flush(self.as_mut(), cx);
78
79        // make sure sink is ready to send
80        try_in_poll_io!(ready!(self.as_mut().poll_ready(cx)));
81
82        // actually submit new item
83        try_in_poll_io!(self.start_send(WebSocketMessage::Bytes(buf.to_vec())));
84        // ^ if no error occurred, message is accepted and queued when calling `start_send`
85        // (i.e.: `to_vec` is called only once)
86
87        Poll::Ready(Ok(buf.len()))
88    }
89
90    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
91        let res = ready!(Sink::poll_flush(self, cx));
92        Poll::Ready(ws_result_to_io_result(res))
93    }
94
95    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96        let res = ready!(Sink::poll_close(self, cx));
97        Poll::Ready(ws_result_to_io_result(res))
98    }
99}
100
101fn ws_result_to_io_result(res: Result<(), WebSocketError>) -> io::Result<()> {
102    match res {
103        Ok(()) => Ok(()),
104        Err(WebSocketError::ConnectionClose(_)) => Ok(()),
105        Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{AsyncReadExt, AsyncWriteExt, StreamExt};
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Connects to the echo server named by `WS_ECHO_SERVER_URL` (read at compile time) and
    /// discards the first message it sends.
    async fn open_echo_socket() -> WebSocket {
        let url = option_env!("WS_ECHO_SERVER_URL").expect("Did you set WS_ECHO_SERVER_URL?");
        let mut socket = WebSocket::open(url).unwrap();
        // The echo server uses its first message to send its own info; it is not an echo.
        let _ = socket.next().await.unwrap();
        socket
    }

    #[wasm_bindgen_test]
    async fn check_read_write() {
        let (mut reader, mut writer) = AsyncReadExt::split(open_echo_socket().await);

        for chunk in [b"test 1", b"test 2"] {
            writer.write_all(chunk).await.unwrap();
        }

        let mut buf = [0u8; 6];
        for expected in [b"test 1", b"test 2"] {
            reader.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, expected);
        }
    }

    #[wasm_bindgen_test]
    async fn with_pending_bytes() {
        let mut ws = open_echo_socket().await;

        ws.write_all(b"1234567890").await.unwrap();

        let mut buf = [0u8; 5];

        // The 10-byte echo does not fit into the 5-byte buffer, so half of it stays pending.
        ws.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"12345");
        assert!(ws.has_pending_bytes());

        // The second read drains the leftover bytes completely.
        ws.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"67890");
        assert!(!ws.has_pending_bytes());
    }
}