yew_stdweb/services/websocket.rs

1//! A service to connect to a server through the
2//! [`WebSocket` Protocol](https://tools.ietf.org/html/rfc6455).
3
4use super::Task;
5use crate::callback::Callback;
6use crate::format::{Binary, FormatError, Text};
7use cfg_if::cfg_if;
8use cfg_match::cfg_match;
9use std::fmt;
10cfg_if! {
11    if #[cfg(feature = "std_web")] {
12        use stdweb::traits::IMessageEvent;
13        use stdweb::web::event::{SocketCloseEvent, SocketErrorEvent, SocketMessageEvent, SocketOpenEvent};
14        use stdweb::web::{IEventTarget, SocketBinaryType, SocketReadyState, WebSocket};
15    } else if #[cfg(feature = "web_sys")] {
16        use gloo::events::EventListener;
17        use js_sys::Uint8Array;
18        use wasm_bindgen::JsCast;
19        use web_sys::{BinaryType, Event, MessageEvent, WebSocket};
20    }
21}
22
/// The status of a WebSocket connection. Used for status notifications.
///
/// The enum carries no data, so it is cheap to copy and has total equality; it can
/// also be used as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum WebSocketStatus {
    /// Fired when a WebSocket connection has opened.
    Opened,
    /// Fired when a WebSocket connection has closed.
    Closed,
    /// Fired when a WebSocket connection has failed, or when sending a frame over it fails.
    Error,
}
33
34#[derive(Clone, Debug, PartialEq, thiserror::Error)]
35/// An error encountered by a WebSocket.
36pub enum WebSocketError {
37    #[error("{0}")]
38    /// An error encountered when creating the WebSocket.
39    CreationError(String),
40}
41
42/// A handle to control the WebSocket connection. Implements `Task` and could be canceled.
43#[must_use = "the connection will be closed when the task is dropped"]
44pub struct WebSocketTask {
45    ws: WebSocket,
46    notification: Callback<WebSocketStatus>,
47    #[cfg(feature = "web_sys")]
48    #[allow(dead_code)]
49    listeners: [EventListener; 4],
50}
51
#[cfg(feature = "web_sys")]
impl WebSocketTask {
    /// Builds a task from its socket, its status callback, one leading listener and the
    /// three listeners produced by the shared connection setup. The leading listener is
    /// stored first, followed by the other three in their given order.
    fn new(
        ws: WebSocket,
        notification: Callback<WebSocketStatus>,
        listener_0: EventListener,
        listeners: [EventListener; 3],
    ) -> WebSocketTask {
        let [first, second, third] = listeners;
        let all_listeners = [listener_0, first, second, third];
        Self {
            ws,
            notification,
            listeners: all_listeners,
        }
    }
}
68
69impl fmt::Debug for WebSocketTask {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.write_str("WebSocketTask")
72    }
73}
74
/// Entry point for opening WebSocket connections.
///
/// The service holds no state: connections are opened through its associated functions
/// `connect`, `connect_binary` and `connect_text`.
#[derive(Debug, Default)]
pub struct WebSocketService {}
78
79impl WebSocketService {
80    /// Connects to a server through a WebSocket connection. Needs two callbacks; one is passed
81    /// data, the other is passed updates about the WebSocket's status.
82    pub fn connect<OUT: 'static>(
83        url: &str,
84        callback: Callback<OUT>,
85        notification: Callback<WebSocketStatus>,
86    ) -> Result<WebSocketTask, WebSocketError>
87    where
88        OUT: From<Text> + From<Binary>,
89    {
90        cfg_match! {
91            feature = "std_web" => ({
92                let ws = Self::connect_common(url, &notification)?.0;
93                ws.add_event_listener(move |event: SocketMessageEvent| {
94                    process_both(&event, &callback);
95                });
96                Ok(WebSocketTask { ws, notification })
97            }),
98            feature = "web_sys" => ({
99                let ConnectCommon(ws, listeners) = Self::connect_common(url, &notification)?;
100                let listener = EventListener::new(&ws, "message", move |event: &Event| {
101                    let event = event.dyn_ref::<MessageEvent>().unwrap();
102                    process_both(&event, &callback);
103                });
104                Ok(WebSocketTask::new(ws, notification, listener, listeners))
105            }),
106        }
107    }
108
109    /// Connects to a server through a WebSocket connection, like connect,
110    /// but only processes binary frames. Text frames are silently
111    /// ignored. Needs two functions to generate data and notification
112    /// messages.
113    pub fn connect_binary<OUT: 'static>(
114        url: &str,
115        callback: Callback<OUT>,
116        notification: Callback<WebSocketStatus>,
117    ) -> Result<WebSocketTask, WebSocketError>
118    where
119        OUT: From<Binary>,
120    {
121        cfg_match! {
122            feature = "std_web" => ({
123                let ws = Self::connect_common(url, &notification)?.0;
124                ws.add_event_listener(move |event: SocketMessageEvent| {
125                    process_binary(&event, &callback);
126                });
127                Ok(WebSocketTask { ws, notification })
128            }),
129            feature = "web_sys" => ({
130                let ConnectCommon(ws, listeners) = Self::connect_common(url, &notification)?;
131                let listener = EventListener::new(&ws, "message", move |event: &Event| {
132                    let event = event.dyn_ref::<MessageEvent>().unwrap();
133                    process_binary(&event, &callback);
134                });
135                Ok(WebSocketTask::new(ws, notification, listener, listeners))
136            }),
137        }
138    }
139
140    /// Connects to a server through a WebSocket connection, like connect,
141    /// but only processes text frames. Binary frames are silently
142    /// ignored. Needs two functions to generate data and notification
143    /// messages.
144    pub fn connect_text<OUT: 'static>(
145        url: &str,
146        callback: Callback<OUT>,
147        notification: Callback<WebSocketStatus>,
148    ) -> Result<WebSocketTask, WebSocketError>
149    where
150        OUT: From<Text>,
151    {
152        cfg_match! {
153            feature = "std_web" => ({
154                let ws = Self::connect_common(url, &notification)?.0;
155                ws.add_event_listener(move |event: SocketMessageEvent| {
156                    process_text(&event, &callback);
157                });
158                Ok(WebSocketTask { ws, notification })
159            }),
160            feature = "web_sys" => ({
161                let ConnectCommon(ws, listeners) = Self::connect_common(url, &notification)?;
162                let listener = EventListener::new(&ws, "message", move |event: &Event| {
163                    let event = event.dyn_ref::<MessageEvent>().unwrap();
164                    process_text(&event, &callback);
165                });
166                Ok(WebSocketTask::new(ws, notification, listener, listeners))
167            }),
168        }
169    }
170
171    fn connect_common(
172        url: &str,
173        notification: &Callback<WebSocketStatus>,
174    ) -> Result<ConnectCommon, WebSocketError> {
175        let ws = WebSocket::new(url);
176
177        let ws = ws.map_err(
178            #[cfg(feature = "std_web")]
179            |_| WebSocketError::CreationError("Error opening a WebSocket connection.".to_string()),
180            #[cfg(feature = "web_sys")]
181            |ws_error| {
182                WebSocketError::CreationError(
183                    ws_error
184                        .unchecked_into::<js_sys::Error>()
185                        .to_string()
186                        .as_string()
187                        .unwrap(),
188                )
189            },
190        )?;
191
192        cfg_match! {
193            feature = "std_web" => ws.set_binary_type(SocketBinaryType::ArrayBuffer),
194            feature = "web_sys" => ws.set_binary_type(BinaryType::Arraybuffer),
195        };
196        let notify = notification.clone();
197        let listener_open =
198            move |#[cfg(feature = "std_web")] _: SocketOpenEvent,
199                  #[cfg(feature = "web_sys")] _: &Event| {
200                notify.emit(WebSocketStatus::Opened);
201            };
202        let notify = notification.clone();
203        let listener_close =
204            move |#[cfg(feature = "std_web")] _: SocketCloseEvent,
205                  #[cfg(feature = "web_sys")] _: &Event| {
206                notify.emit(WebSocketStatus::Closed);
207            };
208        let notify = notification.clone();
209        let listener_error =
210            move |#[cfg(feature = "std_web")] _: SocketErrorEvent,
211                  #[cfg(feature = "web_sys")] _: &Event| {
212                notify.emit(WebSocketStatus::Error);
213            };
214        #[cfg_attr(feature = "std_web", allow(clippy::let_unit_value, unused_variables))]
215        {
216            let listeners = cfg_match! {
217                feature = "std_web" => ({
218                    ws.add_event_listener(listener_open);
219                    ws.add_event_listener(listener_close);
220                    ws.add_event_listener(listener_error);
221                }),
222                feature = "web_sys" => [
223                    EventListener::new(&ws, "open", listener_open),
224                    EventListener::new(&ws, "close", listener_close),
225                    EventListener::new(&ws, "error", listener_error),
226                ],
227            };
228            Ok(ConnectCommon(
229                ws,
230                #[cfg(feature = "web_sys")]
231                listeners,
232            ))
233        }
234    }
235}
236
237struct ConnectCommon(WebSocket, #[cfg(feature = "web_sys")] [EventListener; 3]);
238
239fn process_binary<OUT: 'static>(
240    #[cfg(feature = "std_web")] event: &SocketMessageEvent,
241    #[cfg(feature = "web_sys")] event: &MessageEvent,
242    callback: &Callback<OUT>,
243) where
244    OUT: From<Binary>,
245{
246    #[cfg(feature = "std_web")]
247    let bytes = event.data().into_array_buffer();
248
249    #[cfg(feature = "web_sys")]
250    let bytes = if !event.data().is_string() {
251        Some(event.data())
252    } else {
253        None
254    };
255
256    let data = if let Some(bytes) = bytes {
257        let bytes: Vec<u8> = cfg_match! {
258            feature = "std_web" => bytes.into(),
259            feature = "web_sys" => Uint8Array::new(&bytes).to_vec(),
260        };
261        Ok(bytes)
262    } else {
263        Err(FormatError::ReceivedTextForBinary.into())
264    };
265
266    let out = OUT::from(data);
267    callback.emit(out);
268}
269
270fn process_text<OUT: 'static>(
271    #[cfg(feature = "std_web")] event: &SocketMessageEvent,
272    #[cfg(feature = "web_sys")] event: &MessageEvent,
273    callback: &Callback<OUT>,
274) where
275    OUT: From<Text>,
276{
277    let text = cfg_match! {
278        feature = "std_web" => event.data().into_text(),
279        feature = "web_sys" => event.data().as_string(),
280    };
281
282    let data = if let Some(text) = text {
283        Ok(text)
284    } else {
285        Err(FormatError::ReceivedBinaryForText.into())
286    };
287
288    let out = OUT::from(data);
289    callback.emit(out);
290}
291
292fn process_both<OUT: 'static>(
293    #[cfg(feature = "std_web")] event: &SocketMessageEvent,
294    #[cfg(feature = "web_sys")] event: &MessageEvent,
295    callback: &Callback<OUT>,
296) where
297    OUT: From<Text> + From<Binary>,
298{
299    #[cfg(feature = "std_web")]
300    let is_text = event.data().into_text().is_some();
301
302    #[cfg(feature = "web_sys")]
303    let is_text = event.data().is_string();
304
305    if is_text {
306        process_text(event, callback);
307    } else {
308        process_binary(event, callback);
309    }
310}
311
312impl WebSocketTask {
313    /// Sends data to a WebSocket connection.
314    pub fn send<IN>(&mut self, data: IN)
315    where
316        IN: Into<Text>,
317    {
318        if let Ok(body) = data.into() {
319            let result = cfg_match! {
320                feature = "std_web" => self.ws.send_text(&body),
321                feature = "web_sys" => self.ws.send_with_str(&body),
322            };
323
324            if result.is_err() {
325                self.notification.emit(WebSocketStatus::Error);
326            }
327        }
328    }
329
330    /// Sends binary data to a WebSocket connection.
331    pub fn send_binary<IN>(&mut self, data: IN)
332    where
333        IN: Into<Binary>,
334    {
335        if let Ok(body) = data.into() {
336            let result = cfg_match! {
337                feature = "std_web" => self.ws.send_bytes(&body),
338                feature = "web_sys" => self.ws.send_with_u8_array(&body),
339            };
340
341            if result.is_err() {
342                self.notification.emit(WebSocketStatus::Error);
343            }
344        }
345    }
346}
347
348impl Task for WebSocketTask {
349    fn is_active(&self) -> bool {
350        cfg_match! {
351            feature = "std_web" => matches!(self.ws.ready_state(), SocketReadyState::Connecting | SocketReadyState::Open),
352            feature = "web_sys" => matches!(self.ws.ready_state(), WebSocket::CONNECTING | WebSocket::OPEN),
353        }
354    }
355}
356
357impl Drop for WebSocketTask {
358    fn drop(&mut self) {
359        if self.is_active() {
360            cfg_match! {
361                feature = "std_web" => self.ws.close(),
362                feature = "web_sys" => self.ws.close().ok(),
363            };
364        }
365    }
366}
367
#[cfg(test)]
#[cfg(all(feature = "wasm_test", feature = "echo_server_test"))]
mod tests {
    use super::*;
    use crate::callback::{test_util::CallbackFuture, Callback};
    use crate::format::{FormatError, Json};
    use crate::services::TimeoutService;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;
    use wasm_bindgen_test::{wasm_bindgen_test as test, wasm_bindgen_test_configure};

    wasm_bindgen_test_configure!(run_in_browser);

    const fn echo_server_url() -> &'static str {
        // we can't do this at runtime because we're running in the browser.
        env!("ECHO_SERVER_URL")
    }

    // Ignore the first response from the echo server
    async fn ignore_first_message<T>(cb_future: &CallbackFuture<T>) {
        let sleep_future = CallbackFuture::<()>::default();
        let _sleep_task =
            TimeoutService::spawn(Duration::from_millis(10), sleep_future.clone().into());
        sleep_future.await;
        cb_future.ready();
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Message {
        test: String,
    }

    #[test]
    async fn connect() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let mut task = WebSocketService::connect(url, callback, notification).unwrap();
        assert_eq!(status_future.await, WebSocketStatus::Opened);
        ignore_first_message(&cb_future).await;

        let msg = Message {
            test: String::from("hello"),
        };

        task.send(Json(&msg));
        match cb_future.clone().await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for a text frame: {}", err),
        }

        task.send_binary(Json(&msg));
        match cb_future.await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for a binary frame: {}", err),
        }
    }

    #[test]
    #[cfg(feature = "web_sys")]
    async fn test_invalid_url_error() {
        let url = "syntactically-invalid";
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();
        match WebSocketService::connect_text(url, callback, notification) {
            Err(WebSocketError::CreationError(creation_err)) => assert!(
                creation_err.starts_with("SyntaxError:"),
                "unexpected creation error: {}",
                creation_err
            ),
            Ok(task) => panic!("connecting to an invalid URL succeeded: {:?}", task),
        }
    }

    #[test]
    async fn connect_text() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let mut task = WebSocketService::connect_text(url, callback, notification).unwrap();
        assert_eq!(status_future.await, WebSocketStatus::Opened);
        ignore_first_message(&cb_future).await;

        let msg = Message {
            test: String::from("hello"),
        };

        task.send(Json(&msg));
        match cb_future.clone().await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for a text frame: {}", err),
        }

        task.send_binary(Json(&msg));
        match cb_future.await {
            Json(Ok(received)) => panic!(
                "expected an error for a binary frame, received {:?}",
                received
            ),
            Json(Err(err)) => assert_eq!(
                err.to_string(),
                FormatError::ReceivedBinaryForText.to_string()
            ),
        }
    }

    #[test]
    async fn connect_binary() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let mut task = WebSocketService::connect_binary(url, callback, notification).unwrap();
        assert_eq!(status_future.await, WebSocketStatus::Opened);
        ignore_first_message(&cb_future).await;

        let msg = Message {
            test: String::from("hello"),
        };

        task.send_binary(Json(&msg));
        match cb_future.clone().await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for a binary frame: {}", err),
        }

        task.send(Json(&msg));
        match cb_future.await {
            Json(Ok(received)) => panic!(
                "expected an error for a text frame, received {:?}",
                received
            ),
            Json(Err(err)) => assert_eq!(
                err.to_string(),
                FormatError::ReceivedTextForBinary.to_string()
            ),
        }
    }

    #[test]
    #[cfg(feature = "web_sys")]
    async fn is_active_while_connecting() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let task = WebSocketService::connect_text(url, callback, notification).unwrap();

        // NOTE: There's a bit of a race here between checking `is_active`
        // and the WebSocket completing the connection handshake.
        // The handshake *should* take sufficient time to complete that we
        // can see it still in the `WebSocket::CONNECTING` state, but it's
        // not guaranteed.  If someone has a way to guarantee we capture
        // the WebSocket in the connecting state, please update this test.
        assert!(task.is_active());

        assert_eq!(status_future.await, WebSocketStatus::Opened);
    }

    #[test]
    #[cfg(feature = "web_sys")]
    async fn drop_while_still_connecting() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let task = WebSocketService::connect_text(url, callback, notification).unwrap();
        let ws = task.ws.clone();

        // NOTE: There's a bit of a race here between dropping the
        // `WebSocketTask` and the WebSocket completing the connection
        // handshake.  The handshake *should* take sufficient time to complete
        // that we can see it still in the `WebSocket::CONNECTING` state, but
        // it's not guaranteed.  If someone has a way to guarantee we capture
        // the WebSocket in the connecting state, please update this test.
        drop(task);

        // This test only runs with `web_sys`, so only the web-sys ready states apply.
        let ws_ready_state = ws.ready_state();
        assert!(matches!(
            ws_ready_state,
            WebSocket::CLOSING | WebSocket::CLOSED
        ));
    }
}