1use crate::js_to_js_error;
31use crate::websocket::{events::CloseEvent, Message, State, WebSocketError};
32use futures_channel::mpsc;
33use futures_core::{ready, Stream};
34use futures_sink::Sink;
35use gloo_utils::errors::JsError;
36use pin_project::{pin_project, pinned_drop};
37use std::cell::RefCell;
38use std::pin::Pin;
39use std::rc::Rc;
40use std::task::{Context, Poll, Waker};
41use wasm_bindgen::prelude::*;
42use wasm_bindgen::JsCast;
43use web_sys::{BinaryType, MessageEvent};
44
/// A browser `WebSocket` exposed as a [`Stream`] of incoming [`Message`]s and a [`Sink`] for
/// outgoing ones.
///
/// The JavaScript event listeners registered for this socket forward events into an unbounded
/// channel that the `Stream` implementation reads from.
// NOTE(review): the lint is silenced without a recorded reason; `Debug` is simply not
// implemented for this type — confirm whether that is intentional.
#[allow(missing_debug_implementations)]
#[pin_project(PinnedDrop)]
pub struct WebSocket {
    /// The underlying browser socket.
    ws: web_sys::WebSocket,
    /// Waker stored by `poll_ready` while the socket is still connecting; taken and woken by
    /// the `open` and `error` listeners so the sink is polled again.
    sink_waker: Rc<RefCell<Option<Waker>>>,
    /// Receiving end of the channel fed by the `message`, `error` and `close` listeners.
    #[pin]
    message_receiver: mpsc::UnboundedReceiver<StreamMessage>,
    /// The Rust closures backing the JS listeners, in the order open, message, error, close.
    /// They must outlive their registration, so they are owned by the socket.
    #[allow(clippy::type_complexity)]
    closures: (
        Closure<dyn FnMut()>,
        Closure<dyn FnMut(MessageEvent)>,
        Closure<dyn FnMut(web_sys::Event)>,
        Closure<dyn FnMut(web_sys::CloseEvent)>,
    ),
    /// Bytes received but not yet handed to a reader.
    // NOTE(review): presumably consumed by an io-util (`AsyncRead`) impl in a sibling module,
    // given `pub(super)` — confirm.
    #[cfg(feature = "io-util")]
    pub(super) read_pending_bytes: Option<Vec<u8>>,
}
66
67impl WebSocket {
68 pub fn open(url: &str) -> Result<Self, JsError> {
78 Self::setup(web_sys::WebSocket::new(url))
79 }
80
81 pub fn open_with_protocol(url: &str, protocol: &str) -> Result<Self, JsError> {
92 Self::setup(web_sys::WebSocket::new_with_str(url, protocol))
93 }
94
95 #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
109 #[cfg(feature = "json")]
110 pub fn open_with_protocols<S: AsRef<str> + serde::Serialize>(
111 url: &str,
112 protocols: &[S],
113 ) -> Result<Self, JsError> {
114 let json = <JsValue as gloo_utils::format::JsValueSerdeExt>::from_serde(protocols)
115 .map_err(|err| {
116 js_sys::Error::new(&format!(
117 "Failed to convert protocols to Javascript value: {err}"
118 ))
119 })?;
120 Self::setup(web_sys::WebSocket::new_with_str_sequence(url, &json))
121 }
122
123 fn setup(ws: Result<web_sys::WebSocket, JsValue>) -> Result<Self, JsError> {
124 let waker: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
125 let ws = ws.map_err(js_to_js_error)?;
126
127 ws.set_binary_type(BinaryType::Arraybuffer);
131
132 let (sender, receiver) = mpsc::unbounded();
133
134 let open_callback: Closure<dyn FnMut()> = {
135 let waker = Rc::clone(&waker);
136 Closure::wrap(Box::new(move || {
137 if let Some(waker) = waker.borrow_mut().take() {
138 waker.wake();
139 }
140 }) as Box<dyn FnMut()>)
141 };
142
143 ws.add_event_listener_with_callback_and_add_event_listener_options(
144 "open",
145 open_callback.as_ref().unchecked_ref(),
146 web_sys::AddEventListenerOptions::new().once(true),
147 )
148 .map_err(js_to_js_error)?;
149
150 let message_callback: Closure<dyn FnMut(MessageEvent)> = {
151 let sender = sender.clone();
152 Closure::wrap(Box::new(move |e: MessageEvent| {
153 let msg = parse_message(e);
154 let _ = sender.unbounded_send(StreamMessage::Message(msg));
155 }) as Box<dyn FnMut(MessageEvent)>)
156 };
157
158 ws.add_event_listener_with_callback("message", message_callback.as_ref().unchecked_ref())
159 .map_err(js_to_js_error)?;
160
161 let error_callback: Closure<dyn FnMut(web_sys::Event)> = {
162 let sender = sender.clone();
163 let waker = Rc::clone(&waker);
164 Closure::wrap(Box::new(move |_e: web_sys::Event| {
165 if let Some(waker) = waker.borrow_mut().take() {
166 waker.wake();
167 }
168 let _ = sender.unbounded_send(StreamMessage::ErrorEvent);
169 }) as Box<dyn FnMut(web_sys::Event)>)
170 };
171
172 ws.add_event_listener_with_callback("error", error_callback.as_ref().unchecked_ref())
173 .map_err(js_to_js_error)?;
174
175 let close_callback: Closure<dyn FnMut(web_sys::CloseEvent)> = {
176 Closure::wrap(Box::new(move |e: web_sys::CloseEvent| {
177 let close_event = CloseEvent {
178 code: e.code(),
179 reason: e.reason(),
180 was_clean: e.was_clean(),
181 };
182 let _ = sender.unbounded_send(StreamMessage::CloseEvent(close_event));
183 let _ = sender.unbounded_send(StreamMessage::ConnectionClose);
184 }) as Box<dyn FnMut(web_sys::CloseEvent)>)
185 };
186
187 ws.add_event_listener_with_callback_and_add_event_listener_options(
188 "close",
189 close_callback.as_ref().unchecked_ref(),
190 web_sys::AddEventListenerOptions::new().once(true),
191 )
192 .map_err(js_to_js_error)?;
193
194 Ok(Self {
195 ws,
196 sink_waker: waker,
197 message_receiver: receiver,
198 closures: (
199 open_callback,
200 message_callback,
201 error_callback,
202 close_callback,
203 ),
204 #[cfg(feature = "io-util")]
205 read_pending_bytes: None,
206 })
207 }
208
209 pub fn close(self, code: Option<u16>, reason: Option<&str>) -> Result<(), JsError> {
214 let result = match (code, reason) {
215 (None, None) => self.ws.close(),
216 (Some(code), None) => self.ws.close_with_code(code),
217 (Some(code), Some(reason)) => self.ws.close_with_code_and_reason(code, reason),
218 (None, Some(reason)) => self.ws.close_with_code_and_reason(1005, reason),
221 };
222 result.map_err(js_to_js_error)
223 }
224
225 pub fn state(&self) -> State {
227 let ready_state = self.ws.ready_state();
228 match ready_state {
229 0 => State::Connecting,
230 1 => State::Open,
231 2 => State::Closing,
232 3 => State::Closed,
233 _ => unreachable!(),
234 }
235 }
236
237 pub fn extensions(&self) -> String {
239 self.ws.extensions()
240 }
241
242 pub fn protocol(&self) -> String {
244 self.ws.protocol()
245 }
246}
247
248impl TryFrom<web_sys::WebSocket> for WebSocket {
249 type Error = JsError;
250
251 fn try_from(ws: web_sys::WebSocket) -> Result<Self, Self::Error> {
252 Self::setup(Ok(ws))
253 }
254}
255
256#[derive(Clone)]
257enum StreamMessage {
258 ErrorEvent,
259 CloseEvent(CloseEvent),
260 Message(Message),
261 ConnectionClose,
262}
263
264fn parse_message(event: MessageEvent) -> Message {
265 if let Ok(array_buffer) = event.data().dyn_into::<js_sys::ArrayBuffer>() {
266 let array = js_sys::Uint8Array::new(&array_buffer);
267 Message::Bytes(array.to_vec())
268 } else if let Ok(txt) = event.data().dyn_into::<js_sys::JsString>() {
269 Message::Text(String::from(&txt))
270 } else {
271 unreachable!("message event, received Unknown: {:?}", event.data());
272 }
273}
274
275impl Sink<Message> for WebSocket {
276 type Error = WebSocketError;
277
278 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279 let ready_state = self.ws.ready_state();
280 if ready_state == 0 {
281 *self.sink_waker.borrow_mut() = Some(cx.waker().clone());
282 Poll::Pending
283 } else {
284 Poll::Ready(Ok(()))
285 }
286 }
287
288 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
289 let result = match item {
290 Message::Bytes(bytes) => self.ws.send_with_u8_array(&bytes),
291 Message::Text(message) => self.ws.send_with_str(&message),
292 };
293 match result {
294 Ok(_) => Ok(()),
295 Err(e) => Err(WebSocketError::MessageSendError(js_to_js_error(e))),
296 }
297 }
298
299 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
300 Poll::Ready(Ok(()))
301 }
302
303 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
304 Poll::Ready(Ok(()))
305 }
306}
307
308impl Stream for WebSocket {
309 type Item = Result<Message, WebSocketError>;
310
311 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
312 let msg = ready!(self.project().message_receiver.poll_next(cx));
313 match msg {
314 Some(StreamMessage::Message(msg)) => Poll::Ready(Some(Ok(msg))),
315 Some(StreamMessage::ErrorEvent) => {
316 Poll::Ready(Some(Err(WebSocketError::ConnectionError)))
317 }
318 Some(StreamMessage::CloseEvent(e)) => {
319 Poll::Ready(Some(Err(WebSocketError::ConnectionClose(e))))
320 }
321 Some(StreamMessage::ConnectionClose) => Poll::Ready(None),
322 None => Poll::Ready(None),
323 }
324 }
325}
326
327#[pinned_drop]
328impl PinnedDrop for WebSocket {
329 fn drop(self: Pin<&mut Self>) {
330 self.ws.close().unwrap();
331
332 for (ty, cb) in [
333 ("open", self.closures.0.as_ref()),
334 ("message", self.closures.1.as_ref()),
335 ("error", self.closures.2.as_ref()),
336 ] {
337 let _ = self
338 .ws
339 .remove_event_listener_with_callback(ty, cb.unchecked_ref());
340 }
341
342 if let Ok(close_event) = web_sys::CloseEvent::new_with_event_init_dict(
343 "close",
344 web_sys::CloseEventInit::new()
345 .code(1000)
346 .reason("client dropped"),
347 ) {
348 let _ = self.ws.dispatch_event(&close_event);
349 }
350 }
351}
352
/// Browser integration tests; they require a reachable echo server.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Sends two text frames and expects them echoed back in order.
    #[wasm_bindgen_test]
    async fn websocket_works() {
        // The URL is baked in at compile time; the test panics if it was not provided.
        let ws_echo_server_url =
            option_env!("WS_ECHO_SERVER_URL").expect("Did you set WS_ECHO_SERVER_URL?");

        let ws = WebSocket::open(ws_echo_server_url).unwrap();
        let (mut sender, mut receiver) = ws.split();

        sender
            .send(Message::Text(String::from("test 1")))
            .await
            .unwrap();
        sender
            .send(Message::Text(String::from("test 2")))
            .await
            .unwrap();

        // The first received item is discarded.
        // NOTE(review): presumably the echo server sends an initial greeting before echoing;
        // confirm against the server used in CI.
        let _ = receiver.next().await;

        assert_eq!(
            receiver.next().await.unwrap().unwrap(),
            Message::Text("test 1".to_string())
        );
        assert_eq!(
            receiver.next().await.unwrap().unwrap(),
            Message::Text("test 2".to_string())
        );
    }
}