gloo_net/websocket/
io_util.rs1use core::cmp;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use std::io;
5
6use futures_core::{ready, Stream as _};
7use futures_io::{AsyncRead, AsyncWrite};
8use futures_sink::Sink;
9
10use crate::websocket::futures::WebSocket;
11use crate::websocket::{Message as WebSocketMessage, WebSocketError};
12
13impl WebSocket {
14 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
20 pub fn has_pending_bytes(&self) -> bool {
21 self.read_pending_bytes.is_some()
22 }
23}
24
/// Unwraps a `Result<T, WebSocketError>` inside a `poll_*` function that
/// returns `Poll<io::Result<usize>>`.
///
/// * `Ok(value)` evaluates to `value`.
/// * A *clean* connection close makes the enclosing function return
///   `Poll::Ready(Ok(0))`, i.e. "zero bytes transferred".
/// * Every other error (including an unclean close) makes the enclosing
///   function return the error wrapped in an `io::Error` of kind `Other`.
macro_rules! try_in_poll_io {
    ($expr:expr) => {{
        match $expr {
            Ok(value) => value,
            Err(err) => {
                // Both failure outcomes leave the enclosing function early; only
                // the payload of the `Poll::Ready` differs.
                return Poll::Ready(match err {
                    WebSocketError::ConnectionClose(event) if event.was_clean => Ok(0),
                    other => Err(io::Error::new(io::ErrorKind::Other, other)),
                });
            }
        }
    }};
}
37
38#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
39impl AsyncRead for WebSocket {
40 fn poll_read(
41 mut self: Pin<&mut Self>,
42 cx: &mut Context<'_>,
43 buf: &mut [u8],
44 ) -> Poll<io::Result<usize>> {
45 let mut data = if let Some(data) = self.as_mut().get_mut().read_pending_bytes.take() {
46 data
47 } else {
48 match ready!(self.as_mut().poll_next(cx)) {
49 Some(item) => match try_in_poll_io!(item) {
50 WebSocketMessage::Text(s) => s.into_bytes(),
51 WebSocketMessage::Bytes(data) => data,
52 },
53 None => return Poll::Ready(Ok(0)),
54 }
55 };
56
57 let bytes_to_copy = cmp::min(buf.len(), data.len());
58 buf[..bytes_to_copy].copy_from_slice(&data[..bytes_to_copy]);
59
60 if data.len() > bytes_to_copy {
61 data.drain(..bytes_to_copy);
62 self.get_mut().read_pending_bytes = Some(data);
63 }
64
65 Poll::Ready(Ok(bytes_to_copy))
66 }
67}
68
69#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
70impl AsyncWrite for WebSocket {
71 fn poll_write(
72 mut self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 buf: &[u8],
75 ) -> Poll<io::Result<usize>> {
76 let _ = AsyncWrite::poll_flush(self.as_mut(), cx);
78
79 try_in_poll_io!(ready!(self.as_mut().poll_ready(cx)));
81
82 try_in_poll_io!(self.start_send(WebSocketMessage::Bytes(buf.to_vec())));
84 Poll::Ready(Ok(buf.len()))
88 }
89
90 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
91 let res = ready!(Sink::poll_flush(self, cx));
92 Poll::Ready(ws_result_to_io_result(res))
93 }
94
95 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96 let res = ready!(Sink::poll_close(self, cx));
97 Poll::Ready(ws_result_to_io_result(res))
98 }
99}
100
101fn ws_result_to_io_result(res: Result<(), WebSocketError>) -> io::Result<()> {
102 match res {
103 Ok(()) => Ok(()),
104 Err(WebSocketError::ConnectionClose(_)) => Ok(()),
105 Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{AsyncReadExt, AsyncWriteExt, StreamExt};
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Opens a socket to the echo server named by the compile-time
    /// `WS_ECHO_SERVER_URL` variable and discards the first item it yields.
    ///
    /// NOTE(review): the first item is presumably a greeting sent by the echo
    /// server rather than an echo of our data — confirm against the server used.
    async fn open_echo_socket() -> WebSocket {
        let url = option_env!("WS_ECHO_SERVER_URL").expect("Did you set WS_ECHO_SERVER_URL?");
        let mut socket = WebSocket::open(url).unwrap();
        let _ = socket.next().await.unwrap();
        socket
    }

    /// Two writes are echoed back and can be read as two six-byte chunks, in order.
    #[wasm_bindgen_test]
    async fn check_read_write() {
        let socket = open_echo_socket().await;

        // `split` is named through the trait because `StreamExt` also offers a `split`.
        let (mut reader, mut writer) = AsyncReadExt::split(socket);

        let payloads: [&[u8; 6]; 2] = [b"test 1", b"test 2"];
        for payload in payloads {
            writer.write_all(payload).await.unwrap();
        }

        let mut chunk = [0u8; 6];
        for payload in payloads {
            reader.read_exact(&mut chunk).await.unwrap();
            assert_eq!(&chunk, payload);
        }
    }

    /// A message larger than the read buffer leaves pending bytes that the
    /// next read drains.
    #[wasm_bindgen_test]
    async fn with_pending_bytes() {
        let mut socket = open_echo_socket().await;

        socket.write_all(b"1234567890").await.unwrap();

        let mut half = [0u8; 5];

        // First half: the other five bytes must remain buffered in the socket.
        socket.read_exact(&mut half).await.unwrap();
        assert_eq!(&half, b"12345");
        assert!(socket.has_pending_bytes());

        // Second half: served from the buffer, which is then empty.
        socket.read_exact(&mut half).await.unwrap();
        assert_eq!(&half, b"67890");
        assert!(!socket.has_pending_bytes());
    }
}