gloo_net/websocket/
futures.rs

//! A wrapper around the browser's `WebSocket` API that uses the Futures API, for use in async Rust.
2//!
3//! # Example
4//!
5//! ```rust
6//! use gloo_net::websocket::{Message, futures::WebSocket};
7//! use wasm_bindgen_futures::spawn_local;
8//! use futures::{SinkExt, StreamExt};
9//!
10//! # macro_rules! console_log {
11//! #    ($($expr:expr),*) => {{}};
12//! # }
13//! # fn no_run() {
14//! let mut ws = WebSocket::open("wss://echo.websocket.org").unwrap();
15//! let (mut write, mut read) = ws.split();
16//!
17//! spawn_local(async move {
18//!     write.send(Message::Text(String::from("test"))).await.unwrap();
19//!     write.send(Message::Text(String::from("test 2"))).await.unwrap();
20//! });
21//!
22//! spawn_local(async move {
23//!     while let Some(msg) = read.next().await {
24//!         console_log!(format!("1. {:?}", msg))
25//!     }
26//!     console_log!("WebSocket Closed")
27//! })
28//! # }
29//! ```
30use crate::js_to_js_error;
31use crate::websocket::{events::CloseEvent, Message, State, WebSocketError};
32use futures_channel::mpsc;
33use futures_core::{ready, Stream};
34use futures_sink::Sink;
35use gloo_utils::errors::JsError;
36use pin_project::{pin_project, pinned_drop};
37use std::cell::RefCell;
38use std::pin::Pin;
39use std::rc::Rc;
40use std::task::{Context, Poll, Waker};
41use wasm_bindgen::prelude::*;
42use wasm_bindgen::JsCast;
43use web_sys::{BinaryType, MessageEvent};
44
45/// Wrapper around browser's WebSocket API.
46#[allow(missing_debug_implementations)]
47#[pin_project(PinnedDrop)]
48pub struct WebSocket {
49    ws: web_sys::WebSocket,
50    sink_waker: Rc<RefCell<Option<Waker>>>,
51    #[pin]
52    message_receiver: mpsc::UnboundedReceiver<StreamMessage>,
53    #[allow(clippy::type_complexity)]
54    closures: (
55        Closure<dyn FnMut()>,
56        Closure<dyn FnMut(MessageEvent)>,
57        Closure<dyn FnMut(web_sys::Event)>,
58        Closure<dyn FnMut(web_sys::CloseEvent)>,
59    ),
60    /// Leftover bytes when using `AsyncRead`.
61    ///
62    /// These bytes are drained and returned in subsequent calls to `poll_read`.
63    #[cfg(feature = "io-util")]
64    pub(super) read_pending_bytes: Option<Vec<u8>>, // Same size as `Vec<u8>` alone thanks to niche optimization
65}
66
67impl WebSocket {
68    /// Establish a WebSocket connection.
69    ///
70    /// This function may error in the following cases:
71    /// - The port to which the connection is being attempted is being blocked.
72    /// - The URL is invalid.
73    ///
74    /// The error returned is [`JsError`]. See the
75    /// [MDN Documentation](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/WebSocket#exceptions_thrown)
76    /// to learn more.
77    pub fn open(url: &str) -> Result<Self, JsError> {
78        Self::setup(web_sys::WebSocket::new(url))
79    }
80
81    /// Establish a WebSocket connection.
82    ///
83    /// This function may error in the following cases:
84    /// - The port to which the connection is being attempted is being blocked.
85    /// - The URL is invalid.
86    /// - The specified protocol is not supported
87    ///
88    /// The error returned is [`JsError`]. See the
89    /// [MDN Documentation](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/WebSocket#exceptions_thrown)
90    /// to learn more.
91    pub fn open_with_protocol(url: &str, protocol: &str) -> Result<Self, JsError> {
92        Self::setup(web_sys::WebSocket::new_with_str(url, protocol))
93    }
94
95    /// Establish a WebSocket connection.
96    ///
97    /// This function may error in the following cases:
98    /// - The port to which the connection is being attempted is being blocked.
99    /// - The URL is invalid.
100    /// - The specified protocols are not supported
101    /// - The protocols cannot be converted to a JSON string list
102    ///
103    /// The error returned is [`JsError`]. See the
104    /// [MDN Documentation](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/WebSocket#exceptions_thrown)
105    /// to learn more.
106    ///
107    /// This function requires `json` features because protocols are parsed by `serde` into `JsValue`.
108    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
109    #[cfg(feature = "json")]
110    pub fn open_with_protocols<S: AsRef<str> + serde::Serialize>(
111        url: &str,
112        protocols: &[S],
113    ) -> Result<Self, JsError> {
114        let json = <JsValue as gloo_utils::format::JsValueSerdeExt>::from_serde(protocols)
115            .map_err(|err| {
116                js_sys::Error::new(&format!(
117                    "Failed to convert protocols to Javascript value: {err}"
118                ))
119            })?;
120        Self::setup(web_sys::WebSocket::new_with_str_sequence(url, &json))
121    }
122
123    fn setup(ws: Result<web_sys::WebSocket, JsValue>) -> Result<Self, JsError> {
124        let waker: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
125        let ws = ws.map_err(js_to_js_error)?;
126
127        // We rely on this because the other type Blob can be converted to Vec<u8> only through a
128        // promise which makes it awkward to use in our event callbacks where we want to guarantee
129        // the order of the events stays the same.
130        ws.set_binary_type(BinaryType::Arraybuffer);
131
132        let (sender, receiver) = mpsc::unbounded();
133
134        let open_callback: Closure<dyn FnMut()> = {
135            let waker = Rc::clone(&waker);
136            Closure::wrap(Box::new(move || {
137                if let Some(waker) = waker.borrow_mut().take() {
138                    waker.wake();
139                }
140            }) as Box<dyn FnMut()>)
141        };
142
143        ws.add_event_listener_with_callback_and_add_event_listener_options(
144            "open",
145            open_callback.as_ref().unchecked_ref(),
146            web_sys::AddEventListenerOptions::new().once(true),
147        )
148        .map_err(js_to_js_error)?;
149
150        let message_callback: Closure<dyn FnMut(MessageEvent)> = {
151            let sender = sender.clone();
152            Closure::wrap(Box::new(move |e: MessageEvent| {
153                let msg = parse_message(e);
154                let _ = sender.unbounded_send(StreamMessage::Message(msg));
155            }) as Box<dyn FnMut(MessageEvent)>)
156        };
157
158        ws.add_event_listener_with_callback("message", message_callback.as_ref().unchecked_ref())
159            .map_err(js_to_js_error)?;
160
161        let error_callback: Closure<dyn FnMut(web_sys::Event)> = {
162            let sender = sender.clone();
163            let waker = Rc::clone(&waker);
164            Closure::wrap(Box::new(move |_e: web_sys::Event| {
165                if let Some(waker) = waker.borrow_mut().take() {
166                    waker.wake();
167                }
168                let _ = sender.unbounded_send(StreamMessage::ErrorEvent);
169            }) as Box<dyn FnMut(web_sys::Event)>)
170        };
171
172        ws.add_event_listener_with_callback("error", error_callback.as_ref().unchecked_ref())
173            .map_err(js_to_js_error)?;
174
175        let close_callback: Closure<dyn FnMut(web_sys::CloseEvent)> = {
176            Closure::wrap(Box::new(move |e: web_sys::CloseEvent| {
177                let close_event = CloseEvent {
178                    code: e.code(),
179                    reason: e.reason(),
180                    was_clean: e.was_clean(),
181                };
182                let _ = sender.unbounded_send(StreamMessage::CloseEvent(close_event));
183                let _ = sender.unbounded_send(StreamMessage::ConnectionClose);
184            }) as Box<dyn FnMut(web_sys::CloseEvent)>)
185        };
186
187        ws.add_event_listener_with_callback_and_add_event_listener_options(
188            "close",
189            close_callback.as_ref().unchecked_ref(),
190            web_sys::AddEventListenerOptions::new().once(true),
191        )
192        .map_err(js_to_js_error)?;
193
194        Ok(Self {
195            ws,
196            sink_waker: waker,
197            message_receiver: receiver,
198            closures: (
199                open_callback,
200                message_callback,
201                error_callback,
202                close_callback,
203            ),
204            #[cfg(feature = "io-util")]
205            read_pending_bytes: None,
206        })
207    }
208
209    /// Closes the websocket.
210    ///
211    /// See the [MDN Documentation](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#parameters)
212    /// to learn about parameters passed to this function and when it can return an `Err(_)`
213    pub fn close(self, code: Option<u16>, reason: Option<&str>) -> Result<(), JsError> {
214        let result = match (code, reason) {
215            (None, None) => self.ws.close(),
216            (Some(code), None) => self.ws.close_with_code(code),
217            (Some(code), Some(reason)) => self.ws.close_with_code_and_reason(code, reason),
218            // default code is 1005 so we use it,
219            // see: https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#parameters
220            (None, Some(reason)) => self.ws.close_with_code_and_reason(1005, reason),
221        };
222        result.map_err(js_to_js_error)
223    }
224
225    /// The current state of the websocket.
226    pub fn state(&self) -> State {
227        let ready_state = self.ws.ready_state();
228        match ready_state {
229            0 => State::Connecting,
230            1 => State::Open,
231            2 => State::Closing,
232            3 => State::Closed,
233            _ => unreachable!(),
234        }
235    }
236
237    /// The extensions in use.
238    pub fn extensions(&self) -> String {
239        self.ws.extensions()
240    }
241
242    /// The sub-protocol in use.
243    pub fn protocol(&self) -> String {
244        self.ws.protocol()
245    }
246}
247
248impl TryFrom<web_sys::WebSocket> for WebSocket {
249    type Error = JsError;
250
251    fn try_from(ws: web_sys::WebSocket) -> Result<Self, Self::Error> {
252        Self::setup(Ok(ws))
253    }
254}
255
256#[derive(Clone)]
257enum StreamMessage {
258    ErrorEvent,
259    CloseEvent(CloseEvent),
260    Message(Message),
261    ConnectionClose,
262}
263
264fn parse_message(event: MessageEvent) -> Message {
265    if let Ok(array_buffer) = event.data().dyn_into::<js_sys::ArrayBuffer>() {
266        let array = js_sys::Uint8Array::new(&array_buffer);
267        Message::Bytes(array.to_vec())
268    } else if let Ok(txt) = event.data().dyn_into::<js_sys::JsString>() {
269        Message::Text(String::from(&txt))
270    } else {
271        unreachable!("message event, received Unknown: {:?}", event.data());
272    }
273}
274
275impl Sink<Message> for WebSocket {
276    type Error = WebSocketError;
277
278    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279        let ready_state = self.ws.ready_state();
280        if ready_state == 0 {
281            *self.sink_waker.borrow_mut() = Some(cx.waker().clone());
282            Poll::Pending
283        } else {
284            Poll::Ready(Ok(()))
285        }
286    }
287
288    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
289        let result = match item {
290            Message::Bytes(bytes) => self.ws.send_with_u8_array(&bytes),
291            Message::Text(message) => self.ws.send_with_str(&message),
292        };
293        match result {
294            Ok(_) => Ok(()),
295            Err(e) => Err(WebSocketError::MessageSendError(js_to_js_error(e))),
296        }
297    }
298
299    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
300        Poll::Ready(Ok(()))
301    }
302
303    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
304        Poll::Ready(Ok(()))
305    }
306}
307
308impl Stream for WebSocket {
309    type Item = Result<Message, WebSocketError>;
310
311    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
312        let msg = ready!(self.project().message_receiver.poll_next(cx));
313        match msg {
314            Some(StreamMessage::Message(msg)) => Poll::Ready(Some(Ok(msg))),
315            Some(StreamMessage::ErrorEvent) => {
316                Poll::Ready(Some(Err(WebSocketError::ConnectionError)))
317            }
318            Some(StreamMessage::CloseEvent(e)) => {
319                Poll::Ready(Some(Err(WebSocketError::ConnectionClose(e))))
320            }
321            Some(StreamMessage::ConnectionClose) => Poll::Ready(None),
322            None => Poll::Ready(None),
323        }
324    }
325}
326
327#[pinned_drop]
328impl PinnedDrop for WebSocket {
329    fn drop(self: Pin<&mut Self>) {
330        self.ws.close().unwrap();
331
332        for (ty, cb) in [
333            ("open", self.closures.0.as_ref()),
334            ("message", self.closures.1.as_ref()),
335            ("error", self.closures.2.as_ref()),
336        ] {
337            let _ = self
338                .ws
339                .remove_event_listener_with_callback(ty, cb.unchecked_ref());
340        }
341
342        if let Ok(close_event) = web_sys::CloseEvent::new_with_event_init_dict(
343            "close",
344            web_sys::CloseEventInit::new()
345                .code(1000)
346                .reason("client dropped"),
347        ) {
348            let _ = self.ws.dispatch_event(&close_event);
349        }
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Sends two text frames to an echo server and expects them back in order.
    #[wasm_bindgen_test]
    async fn websocket_works() {
        let url = option_env!("WS_ECHO_SERVER_URL").expect("Did you set WS_ECHO_SERVER_URL?");
        let (mut outgoing, mut incoming) = WebSocket::open(url).unwrap().split();

        let payloads = ["test 1", "test 2"];

        for payload in payloads {
            outgoing
                .send(Message::Text(String::from(payload)))
                .await
                .unwrap();
        }

        // The echo server sends its own info as the first message; skip it.
        let _ = incoming.next().await;

        for payload in payloads {
            let echoed = incoming.next().await.unwrap().unwrap();
            assert_eq!(echoed, Message::Text(payload.to_string()));
        }
    }
}