1use super::Task;
5use crate::callback::Callback;
6use crate::format::{Binary, FormatError, Text};
7use cfg_if::cfg_if;
8use cfg_match::cfg_match;
9use std::fmt;
10cfg_if! {
11 if #[cfg(feature = "std_web")] {
12 use stdweb::traits::IMessageEvent;
13 use stdweb::web::event::{SocketCloseEvent, SocketErrorEvent, SocketMessageEvent, SocketOpenEvent};
14 use stdweb::web::{IEventTarget, SocketBinaryType, SocketReadyState, WebSocket};
15 } else if #[cfg(feature = "web_sys")] {
16 use gloo::events::EventListener;
17 use js_sys::Uint8Array;
18 use wasm_bindgen::JsCast;
19 use web_sys::{BinaryType, Event, MessageEvent, WebSocket};
20 }
21}
22
/// The status of a WebSocket connection, reported through the `notification` callback.
///
/// The type is a plain fieldless enum, so it is `Copy` and `Eq` and can be compared and passed
/// around freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WebSocketStatus {
    /// The socket fired its `open` event.
    Opened,
    /// The socket fired its `close` event.
    Closed,
    /// The socket fired its `error` event, or sending a message failed.
    Error,
}
33
34#[derive(Clone, Debug, PartialEq, thiserror::Error)]
35pub enum WebSocketError {
37 #[error("{0}")]
38 CreationError(String),
40}
41
42#[must_use = "the connection will be closed when the task is dropped"]
44pub struct WebSocketTask {
45 ws: WebSocket,
46 notification: Callback<WebSocketStatus>,
47 #[cfg(feature = "web_sys")]
48 #[allow(dead_code)]
49 listeners: [EventListener; 4],
50}
51
#[cfg(feature = "web_sys")]
impl WebSocketTask {
    /// Builds a task from its socket, status callback and the four listeners that must stay
    /// alive with it.
    ///
    /// The listeners are stored in the order `listener_0`, then the three in `listeners`.
    fn new(
        ws: WebSocket,
        notification: Callback<WebSocketStatus>,
        listener_0: EventListener,
        listeners: [EventListener; 3],
    ) -> WebSocketTask {
        let [second, third, fourth] = listeners;
        let all_listeners = [listener_0, second, third, fourth];
        Self {
            ws,
            notification,
            listeners: all_listeners,
        }
    }
}
68
69impl fmt::Debug for WebSocketTask {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.write_str("WebSocketTask")
72 }
73}
74
/// Stateless entry point for opening WebSocket connections.
///
/// All functionality lives in the associated `connect*` functions.
#[derive(Debug, Default)]
pub struct WebSocketService {}
78
79impl WebSocketService {
80 pub fn connect<OUT: 'static>(
83 url: &str,
84 callback: Callback<OUT>,
85 notification: Callback<WebSocketStatus>,
86 ) -> Result<WebSocketTask, WebSocketError>
87 where
88 OUT: From<Text> + From<Binary>,
89 {
90 cfg_match! {
91 feature = "std_web" => ({
92 let ws = Self::connect_common(url, ¬ification)?.0;
93 ws.add_event_listener(move |event: SocketMessageEvent| {
94 process_both(&event, &callback);
95 });
96 Ok(WebSocketTask { ws, notification })
97 }),
98 feature = "web_sys" => ({
99 let ConnectCommon(ws, listeners) = Self::connect_common(url, ¬ification)?;
100 let listener = EventListener::new(&ws, "message", move |event: &Event| {
101 let event = event.dyn_ref::<MessageEvent>().unwrap();
102 process_both(&event, &callback);
103 });
104 Ok(WebSocketTask::new(ws, notification, listener, listeners))
105 }),
106 }
107 }
108
109 pub fn connect_binary<OUT: 'static>(
114 url: &str,
115 callback: Callback<OUT>,
116 notification: Callback<WebSocketStatus>,
117 ) -> Result<WebSocketTask, WebSocketError>
118 where
119 OUT: From<Binary>,
120 {
121 cfg_match! {
122 feature = "std_web" => ({
123 let ws = Self::connect_common(url, ¬ification)?.0;
124 ws.add_event_listener(move |event: SocketMessageEvent| {
125 process_binary(&event, &callback);
126 });
127 Ok(WebSocketTask { ws, notification })
128 }),
129 feature = "web_sys" => ({
130 let ConnectCommon(ws, listeners) = Self::connect_common(url, ¬ification)?;
131 let listener = EventListener::new(&ws, "message", move |event: &Event| {
132 let event = event.dyn_ref::<MessageEvent>().unwrap();
133 process_binary(&event, &callback);
134 });
135 Ok(WebSocketTask::new(ws, notification, listener, listeners))
136 }),
137 }
138 }
139
140 pub fn connect_text<OUT: 'static>(
145 url: &str,
146 callback: Callback<OUT>,
147 notification: Callback<WebSocketStatus>,
148 ) -> Result<WebSocketTask, WebSocketError>
149 where
150 OUT: From<Text>,
151 {
152 cfg_match! {
153 feature = "std_web" => ({
154 let ws = Self::connect_common(url, ¬ification)?.0;
155 ws.add_event_listener(move |event: SocketMessageEvent| {
156 process_text(&event, &callback);
157 });
158 Ok(WebSocketTask { ws, notification })
159 }),
160 feature = "web_sys" => ({
161 let ConnectCommon(ws, listeners) = Self::connect_common(url, ¬ification)?;
162 let listener = EventListener::new(&ws, "message", move |event: &Event| {
163 let event = event.dyn_ref::<MessageEvent>().unwrap();
164 process_text(&event, &callback);
165 });
166 Ok(WebSocketTask::new(ws, notification, listener, listeners))
167 }),
168 }
169 }
170
171 fn connect_common(
172 url: &str,
173 notification: &Callback<WebSocketStatus>,
174 ) -> Result<ConnectCommon, WebSocketError> {
175 let ws = WebSocket::new(url);
176
177 let ws = ws.map_err(
178 #[cfg(feature = "std_web")]
179 |_| WebSocketError::CreationError("Error opening a WebSocket connection.".to_string()),
180 #[cfg(feature = "web_sys")]
181 |ws_error| {
182 WebSocketError::CreationError(
183 ws_error
184 .unchecked_into::<js_sys::Error>()
185 .to_string()
186 .as_string()
187 .unwrap(),
188 )
189 },
190 )?;
191
192 cfg_match! {
193 feature = "std_web" => ws.set_binary_type(SocketBinaryType::ArrayBuffer),
194 feature = "web_sys" => ws.set_binary_type(BinaryType::Arraybuffer),
195 };
196 let notify = notification.clone();
197 let listener_open =
198 move |#[cfg(feature = "std_web")] _: SocketOpenEvent,
199 #[cfg(feature = "web_sys")] _: &Event| {
200 notify.emit(WebSocketStatus::Opened);
201 };
202 let notify = notification.clone();
203 let listener_close =
204 move |#[cfg(feature = "std_web")] _: SocketCloseEvent,
205 #[cfg(feature = "web_sys")] _: &Event| {
206 notify.emit(WebSocketStatus::Closed);
207 };
208 let notify = notification.clone();
209 let listener_error =
210 move |#[cfg(feature = "std_web")] _: SocketErrorEvent,
211 #[cfg(feature = "web_sys")] _: &Event| {
212 notify.emit(WebSocketStatus::Error);
213 };
214 #[cfg_attr(feature = "std_web", allow(clippy::let_unit_value, unused_variables))]
215 {
216 let listeners = cfg_match! {
217 feature = "std_web" => ({
218 ws.add_event_listener(listener_open);
219 ws.add_event_listener(listener_close);
220 ws.add_event_listener(listener_error);
221 }),
222 feature = "web_sys" => [
223 EventListener::new(&ws, "open", listener_open),
224 EventListener::new(&ws, "close", listener_close),
225 EventListener::new(&ws, "error", listener_error),
226 ],
227 };
228 Ok(ConnectCommon(
229 ws,
230 #[cfg(feature = "web_sys")]
231 listeners,
232 ))
233 }
234 }
235}
236
237struct ConnectCommon(WebSocket, #[cfg(feature = "web_sys")] [EventListener; 3]);
238
239fn process_binary<OUT: 'static>(
240 #[cfg(feature = "std_web")] event: &SocketMessageEvent,
241 #[cfg(feature = "web_sys")] event: &MessageEvent,
242 callback: &Callback<OUT>,
243) where
244 OUT: From<Binary>,
245{
246 #[cfg(feature = "std_web")]
247 let bytes = event.data().into_array_buffer();
248
249 #[cfg(feature = "web_sys")]
250 let bytes = if !event.data().is_string() {
251 Some(event.data())
252 } else {
253 None
254 };
255
256 let data = if let Some(bytes) = bytes {
257 let bytes: Vec<u8> = cfg_match! {
258 feature = "std_web" => bytes.into(),
259 feature = "web_sys" => Uint8Array::new(&bytes).to_vec(),
260 };
261 Ok(bytes)
262 } else {
263 Err(FormatError::ReceivedTextForBinary.into())
264 };
265
266 let out = OUT::from(data);
267 callback.emit(out);
268}
269
270fn process_text<OUT: 'static>(
271 #[cfg(feature = "std_web")] event: &SocketMessageEvent,
272 #[cfg(feature = "web_sys")] event: &MessageEvent,
273 callback: &Callback<OUT>,
274) where
275 OUT: From<Text>,
276{
277 let text = cfg_match! {
278 feature = "std_web" => event.data().into_text(),
279 feature = "web_sys" => event.data().as_string(),
280 };
281
282 let data = if let Some(text) = text {
283 Ok(text)
284 } else {
285 Err(FormatError::ReceivedBinaryForText.into())
286 };
287
288 let out = OUT::from(data);
289 callback.emit(out);
290}
291
292fn process_both<OUT: 'static>(
293 #[cfg(feature = "std_web")] event: &SocketMessageEvent,
294 #[cfg(feature = "web_sys")] event: &MessageEvent,
295 callback: &Callback<OUT>,
296) where
297 OUT: From<Text> + From<Binary>,
298{
299 #[cfg(feature = "std_web")]
300 let is_text = event.data().into_text().is_some();
301
302 #[cfg(feature = "web_sys")]
303 let is_text = event.data().is_string();
304
305 if is_text {
306 process_text(event, callback);
307 } else {
308 process_binary(event, callback);
309 }
310}
311
312impl WebSocketTask {
313 pub fn send<IN>(&mut self, data: IN)
315 where
316 IN: Into<Text>,
317 {
318 if let Ok(body) = data.into() {
319 let result = cfg_match! {
320 feature = "std_web" => self.ws.send_text(&body),
321 feature = "web_sys" => self.ws.send_with_str(&body),
322 };
323
324 if result.is_err() {
325 self.notification.emit(WebSocketStatus::Error);
326 }
327 }
328 }
329
330 pub fn send_binary<IN>(&mut self, data: IN)
332 where
333 IN: Into<Binary>,
334 {
335 if let Ok(body) = data.into() {
336 let result = cfg_match! {
337 feature = "std_web" => self.ws.send_bytes(&body),
338 feature = "web_sys" => self.ws.send_with_u8_array(&body),
339 };
340
341 if result.is_err() {
342 self.notification.emit(WebSocketStatus::Error);
343 }
344 }
345 }
346}
347
348impl Task for WebSocketTask {
349 fn is_active(&self) -> bool {
350 cfg_match! {
351 feature = "std_web" => matches!(self.ws.ready_state(), SocketReadyState::Connecting | SocketReadyState::Open),
352 feature = "web_sys" => matches!(self.ws.ready_state(), WebSocket::CONNECTING | WebSocket::OPEN),
353 }
354 }
355}
356
357impl Drop for WebSocketTask {
358 fn drop(&mut self) {
359 if self.is_active() {
360 cfg_match! {
361 feature = "std_web" => self.ws.close(),
362 feature = "web_sys" => self.ws.close().ok(),
363 };
364 }
365 }
366}
367
#[cfg(test)]
#[cfg(all(feature = "wasm_test", feature = "echo_server_test"))]
mod tests {
    use super::*;
    use crate::callback::{test_util::CallbackFuture, Callback};
    use crate::format::{FormatError, Json};
    use crate::services::TimeoutService;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;
    use wasm_bindgen_test::{wasm_bindgen_test as test, wasm_bindgen_test_configure};

    wasm_bindgen_test_configure!(run_in_browser);

    /// URL of the echo server, taken from `ECHO_SERVER_URL` at compile time.
    const fn echo_server_url() -> &'static str {
        env!("ECHO_SERVER_URL")
    }

    /// Sleeps for 10 ms, then calls `ready()` on `cb_future`.
    ///
    /// NOTE(review): presumably this discards an initial message the echo server sends on
    /// connect, so the next await sees the echoed message — confirm against the server in use.
    async fn ignore_first_message<T>(cb_future: &CallbackFuture<T>) {
        let sleep_future = CallbackFuture::<()>::default();
        let _sleep_task =
            TimeoutService::spawn(Duration::from_millis(10), sleep_future.clone().into());
        sleep_future.await;
        cb_future.ready();
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Message {
        test: String,
    }

    #[test]
    async fn connect() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let mut task = WebSocketService::connect(url, callback, notification).unwrap();
        assert_eq!(status_future.await, WebSocketStatus::Opened);
        ignore_first_message(&cb_future).await;

        let msg = Message {
            test: String::from("hello"),
        };

        // A `connect` task accepts both text and binary frames.
        task.send(Json(&msg));
        match cb_future.clone().await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for text message: {:?}", err),
        }

        task.send_binary(Json(&msg));
        match cb_future.await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for binary message: {:?}", err),
        }
    }

    #[test]
    #[cfg(feature = "web_sys")]
    async fn test_invalid_url_error() {
        let url = "syntactically-invalid";
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();
        let task = WebSocketService::connect_text(url, callback, notification);

        let WebSocketError::CreationError(creation_err) =
            task.expect_err("connecting to a syntactically invalid URL should fail");
        assert!(
            creation_err.starts_with("SyntaxError:"),
            "unexpected creation error message: {}",
            creation_err
        );
    }

    #[test]
    async fn connect_text() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let mut task = WebSocketService::connect_text(url, callback, notification).unwrap();
        assert_eq!(status_future.await, WebSocketStatus::Opened);
        ignore_first_message(&cb_future).await;

        let msg = Message {
            test: String::from("hello"),
        };

        task.send(Json(&msg));
        match cb_future.clone().await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for text message: {:?}", err),
        }

        // A text-only task must reject binary frames.
        task.send_binary(Json(&msg));
        match cb_future.await {
            Json(Ok(received)) => panic!("expected an error, received {:?}", received),
            Json(Err(err)) => assert_eq!(
                err.to_string(),
                FormatError::ReceivedBinaryForText.to_string()
            ),
        }
    }

    #[test]
    async fn connect_binary() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let mut task = WebSocketService::connect_binary(url, callback, notification).unwrap();
        assert_eq!(status_future.await, WebSocketStatus::Opened);
        ignore_first_message(&cb_future).await;

        let msg = Message {
            test: String::from("hello"),
        };

        task.send_binary(Json(&msg));
        match cb_future.clone().await {
            Json(Ok(received)) => assert_eq!(received, msg),
            Json(Err(err)) => panic!("unexpected error for binary message: {:?}", err),
        }

        // A binary-only task must reject text frames.
        task.send(Json(&msg));
        match cb_future.await {
            Json(Ok(received)) => panic!("expected an error, received {:?}", received),
            Json(Err(err)) => assert_eq!(
                err.to_string(),
                FormatError::ReceivedTextForBinary.to_string()
            ),
        }
    }

    #[test]
    #[cfg(feature = "web_sys")]
    async fn is_active_while_connecting() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let task = WebSocketService::connect_text(url, callback, notification).unwrap();

        // Checked before awaiting `Opened`, i.e. while the socket may still be connecting.
        assert!(task.is_active());

        assert_eq!(status_future.await, WebSocketStatus::Opened);
    }

    #[test]
    #[cfg(feature = "web_sys")]
    async fn drop_while_still_connecting() {
        let url = echo_server_url();
        let cb_future = CallbackFuture::<Json<Result<Message, anyhow::Error>>>::default();
        let callback: Callback<_> = cb_future.clone().into();
        let status_future = CallbackFuture::<WebSocketStatus>::default();
        let notification: Callback<_> = status_future.clone().into();

        let task = WebSocketService::connect_text(url, callback, notification).unwrap();
        // Keep our own handle to the socket so its state can be inspected after the drop.
        let ws = task.ws.clone();

        drop(task);

        // Dropping the task must have started closing the socket.
        let ws_ready_state = ws.ready_state();
        assert!(matches!(
            ws_ready_state,
            WebSocket::CLOSING | WebSocket::CLOSED
        ));
    }
}