leptos_use/
use_websocket.rs

1#![cfg_attr(feature = "ssr", allow(unused_variables, unused_imports, dead_code))]
2
3use crate::{core::ConnectionReadyState, use_interval_fn, ReconnectLimit};
4use cfg_if::cfg_if;
5use codee::{CodecError, Decoder, Encoder, HybridCoderError, HybridDecoder, HybridEncoder};
6use default_struct_builder::DefaultBuilder;
7use js_sys::Array;
8use leptos::{leptos_dom::helpers::TimeoutHandle, prelude::*};
9use std::marker::PhantomData;
10use std::sync::{atomic::AtomicBool, Arc};
11use std::time::Duration;
12use thiserror::Error;
13use wasm_bindgen::prelude::*;
14use web_sys::{BinaryType, CloseEvent, Event, MessageEvent, WebSocket};
15
16#[allow(rustdoc::bare_urls)]
17/// Creating and managing a [Websocket](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket) connection.
18///
19/// ## Demo
20///
21/// [Link to Demo](https://github.com/Synphonyte/leptos-use/tree/main/examples/use_websocket)
22///
23/// ## Usage
24///
25/// Values are (en)decoded via the given codec. You can use any of the codecs, string or binary.
26///
27/// > Please check [the codec chapter](https://leptos-use.rs/codecs.html) to see what codecs are
28/// > available and what feature flags they require.
29///
30/// ```
31/// # use leptos::prelude::*;
32/// # use codee::string::FromToStringCodec;
33/// # use leptos_use::{use_websocket, UseWebSocketReturn};
34/// # use leptos_use::core::ConnectionReadyState;
35/// #
36/// # #[component]
37/// # fn Demo() -> impl IntoView {
38/// let UseWebSocketReturn {
39///     ready_state,
40///     message,
41///     send,
42///     open,
43///     close,
44///     ..
45/// } = use_websocket::<String, String, FromToStringCodec>("wss://echo.websocket.events/");
46///
47/// let send_message = move |_| {
48///     send(&"Hello, world!".to_string());
49/// };
50///
51/// let status = move || ready_state.get().to_string();
52///
53/// let connected = move || ready_state.get() == ConnectionReadyState::Open;
54///
55/// let open_connection = move |_| {
56///     open();
57/// };
58///
59/// let close_connection = move |_| {
60///     close();
61/// };
62///
63/// view! {
64///     <div>
65///         <p>"status: " {status}</p>
66///
67///         <button on:click=send_message disabled=move || !connected()>"Send"</button>
68///         <button on:click=open_connection disabled=connected>"Open"</button>
69///         <button on:click=close_connection disabled=move || !connected()>"Close"</button>
70///
71///         <p>"Receive message: " {move || format!("{:?}", message.get())}</p>
72///     </div>
73/// }
74/// # }
75/// ```
76///
77/// Here is another example using `msgpack` for encoding and decoding. This means that only binary
78/// messages can be sent or received. For this to work you have to enable the **`msgpack_serde` feature** flag.
79///
80/// ```
/// # use leptos::prelude::*;
82/// # use codee::binary::MsgpackSerdeCodec;
83/// # use leptos_use::{use_websocket, UseWebSocketReturn};
84/// # use serde::{Deserialize, Serialize};
85/// #
86/// # #[component]
87/// # fn Demo() -> impl IntoView {
88/// #[derive(Serialize, Deserialize)]
89/// struct SomeData {
90///     name: String,
91///     count: i32,
92/// }
93///
94/// let UseWebSocketReturn {
95///     message,
96///     send,
97///     ..
98/// } = use_websocket::<SomeData, SomeData, MsgpackSerdeCodec>("wss://some.websocket.server/");
99///
100/// let send_data = move || {
101///     send(&SomeData {
102///         name: "John Doe".to_string(),
103///         count: 42,
104///     });
105/// };
106/// #
107/// # view! {}
/// # }
109/// ```
110///
111/// ### Heartbeats
112///
/// Heartbeats can be configured by the `heartbeat` option. You have to provide a heartbeat
/// type that implements the `Default` trait, together with an `Encoder` for it. This encoder doesn't have
115/// to be the same as the one used for the other websocket messages.
116///
117/// ```
/// # use leptos::prelude::*;
119/// # use codee::string::FromToStringCodec;
120/// # use leptos_use::{use_websocket_with_options, UseWebSocketOptions, UseWebSocketReturn};
121/// # use serde::{Deserialize, Serialize};
122/// #
123/// # #[component]
124/// # fn Demo() -> impl IntoView {
125/// #[derive(Default)]
126/// struct Heartbeat;
127///
128/// // Simple example for usage with `FromToStringCodec`
129/// impl std::fmt::Display for Heartbeat {
130///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131///         write!(f, "<Heartbeat>")
132///     }
133/// }
134///
135/// let UseWebSocketReturn {
136///     send,
137///     message,
138///     ..
139/// } = use_websocket_with_options::<String, String, FromToStringCodec, _, _>(
140///     "wss://echo.websocket.events/",
141///     UseWebSocketOptions::default()
142///         // Enable heartbeats every 10 seconds. In this case we use the same codec as for the
143///         // other messages. But this is not necessary.
144///         .heartbeat::<Heartbeat, FromToStringCodec>(10_000),
145/// );
146/// #
147/// # view! {}
/// # }
149/// ```
150///
151/// ## Relative Paths
152///
153/// If the provided `url` is relative, it will be resolved relative to the current page.
/// URLs are resolved as shown in the table below. Please note that the protocol of the current page
/// is taken into account as well: `http` pages use `ws`, `https` pages use `wss`.
156///
157/// | Current Page                   | Relative Url             | Resolved Url                        |
158/// |--------------------------------|--------------------------|-------------------------------------|
159/// | http://example.com/some/where  | /api/ws                  | ws://example.com/api/ws             |
160/// | https://example.com/some/where | /api/ws                  | wss://example.com/api/ws            |
161/// | https://example.com/some/where | api/ws                   | wss://example.com/some/where/api/ws |
162/// | https://example.com/some/where | //otherdomain.com/api/ws | wss://otherdomain.com/api/ws        |
163///
164///
165/// ## Usage with `provide_context`
166///
167/// The return value of `use_websocket` utilizes several type parameters which can make it
168/// cumbersome to use with `provide_context` + `expect_context`.
169/// The following example shows how to avoid type parameters with dynamic dispatch.
170/// This sacrifices a little bit of performance for the sake of ergonomics. However,
171/// compared to network transmission speeds this loss of performance is negligible.
172///
173/// First we define the `struct` that is going to be passed around as context.
174///
175/// ```
176/// # use leptos::prelude::*;
177/// use std::sync::Arc;
178///
179/// #[derive(Clone)]
180/// pub struct WebsocketContext {
181///     pub message: Signal<Option<String>>,
///     // use Arc to make it easily cloneable; `Send + Sync` is required by `provide_context`
///     send: Arc<dyn Fn(&String) + Send + Sync>,
/// }
///
/// impl WebsocketContext {
///     pub fn new(message: Signal<Option<String>>, send: Arc<dyn Fn(&String) + Send + Sync>) -> Self {
187///         Self {
188///             message,
189///             send,
190///         }
191///     }
192///
///     // create a method to avoid having to use parentheses around the field
194///     #[inline(always)]
195///     pub fn send(&self, message: &str) {
196///         (self.send)(&message.to_string())
197///     }
198/// }
199/// ```
200///
201/// Now you can provide the context like the following.
202///
203/// ```
204/// # use leptos::prelude::*;
205/// # use codee::string::FromToStringCodec;
206/// # use leptos_use::{use_websocket, UseWebSocketReturn};
207/// # use std::sync::Arc;
208/// # #[derive(Clone)]
209/// # pub struct WebsocketContext {
210/// #     pub message: Signal<Option<String>>,
211/// #     send: Arc<dyn Fn(&String) + Send + Sync>,
212/// # }
213/// #
214/// # impl WebsocketContext {
215/// #     pub fn new(message: Signal<Option<String>>, send: Arc<dyn Fn(&String) + Send + Sync>) -> Self {
216/// #         Self {
217/// #             message,
218/// #             send,
219/// #         }
220/// #     }
221/// # }
222///
223/// # #[component]
224/// # fn Demo() -> impl IntoView {
225/// let UseWebSocketReturn {
226///     message,
227///     send,
228///     ..
/// } = use_websocket::<String, String, FromToStringCodec>("ws://some.websocket.io");
230///
231/// provide_context(WebsocketContext::new(message, Arc::new(send.clone())));
232/// #
233/// # view! {}
234/// # }
235/// ```
236///
237/// Finally let's use the context:
238///
239/// ```
240/// # use leptos::prelude::*;
241/// # use leptos_use::{use_websocket, UseWebSocketReturn};
242/// # use std::sync::Arc;
243/// # #[derive(Clone)]
244/// # pub struct WebsocketContext {
245/// #     pub message: Signal<Option<String>>,
246/// #     send: Arc<dyn Fn(&String)>,
247/// # }
248/// #
249/// # impl WebsocketContext {
250/// #     #[inline(always)]
251/// #     pub fn send(&self, message: &str) {
252/// #         (self.send)(&message.to_string())
253/// #     }
254/// # }
255///
256/// # #[component]
257/// # fn Demo() -> impl IntoView {
258/// let websocket = expect_context::<WebsocketContext>();
259///
260/// websocket.send("Hello World!");
261/// #
262/// # view! {}
263/// # }
264/// ```
265///
266/// ## Server-Side Rendering
267///
268/// On the server the returned functions amount to no-ops.
269pub fn use_websocket<Tx, Rx, C>(
270    url: &str,
271) -> UseWebSocketReturn<
272    Tx,
273    Rx,
274    impl Fn() + Clone + Send + Sync + 'static,
275    impl Fn() + Clone + Send + Sync + 'static,
276    impl Fn(&Tx) + Clone + Send + Sync + 'static,
277>
278where
279    Tx: Send + Sync + 'static,
280    Rx: Send + Sync + 'static,
281    C: Encoder<Tx> + Decoder<Rx>,
282    C: HybridEncoder<Tx, <C as Encoder<Tx>>::Encoded, Error = <C as Encoder<Tx>>::Error>,
283    C: HybridDecoder<Rx, <C as Decoder<Rx>>::Encoded, Error = <C as Decoder<Rx>>::Error>,
284{
285    use_websocket_with_options::<Tx, Rx, C, (), DummyEncoder>(url, UseWebSocketOptions::default())
286}
287
288/// Version of [`use_websocket`] that takes `UseWebSocketOptions`. See [`use_websocket`] for how to use.
289#[allow(clippy::type_complexity)]
290pub fn use_websocket_with_options<Tx, Rx, C, Hb, HbCodec>(
291    url: &str,
292    options: UseWebSocketOptions<
293        Rx,
294        HybridCoderError<<C as Encoder<Tx>>::Error>,
295        HybridCoderError<<C as Decoder<Rx>>::Error>,
296        Hb,
297        HbCodec,
298    >,
299) -> UseWebSocketReturn<
300    Tx,
301    Rx,
302    impl Fn() + Clone + Send + Sync + 'static,
303    impl Fn() + Clone + Send + Sync + 'static,
304    impl Fn(&Tx) + Clone + Send + Sync + 'static,
305>
306where
307    Tx: Send + Sync + 'static,
308    Rx: Send + Sync + 'static,
309    C: Encoder<Tx> + Decoder<Rx>,
310    C: HybridEncoder<Tx, <C as Encoder<Tx>>::Encoded, Error = <C as Encoder<Tx>>::Error>,
311    C: HybridDecoder<Rx, <C as Decoder<Rx>>::Encoded, Error = <C as Decoder<Rx>>::Error>,
312    Hb: Default + Send + Sync + 'static,
313    HbCodec: Encoder<Hb> + Send + Sync,
314    HbCodec: HybridEncoder<
315        Hb,
316        <HbCodec as Encoder<Hb>>::Encoded,
317        Error = <HbCodec as Encoder<Hb>>::Error,
318    >,
319    <HbCodec as Encoder<Hb>>::Error: std::fmt::Debug,
320{
321    let url = normalize_url(url);
322
323    let UseWebSocketOptions {
324        on_open,
325        on_message,
326        on_message_raw,
327        on_message_raw_bytes,
328        on_error,
329        on_close,
330        reconnect_limit,
331        reconnect_interval,
332        immediate,
333        protocols,
334        heartbeat,
335    } = options;
336
337    let (ready_state, set_ready_state) = signal(ConnectionReadyState::Closed);
338    let (message, set_message) = signal(None);
339    let ws_signal = RwSignal::new_local(None::<WebSocket>);
340
341    let reconnect_timer_ref: StoredValue<Option<TimeoutHandle>> = StoredValue::new(None);
342
343    let reconnect_times_ref: StoredValue<u64> = StoredValue::new(0);
344    let manually_closed_ref: StoredValue<bool> = StoredValue::new(false);
345
346    let unmounted = Arc::new(AtomicBool::new(false));
347
348    let connect_ref: StoredValue<Option<Arc<dyn Fn() + Send + Sync>>> = StoredValue::new(None);
349
350    let send_str = move |data: &str| {
351        if ready_state.get_untracked() == ConnectionReadyState::Open {
352            if let Some(web_socket) = ws_signal.get_untracked() {
353                let _ = web_socket.send_with_str(data);
354            }
355        }
356    };
357
358    let send_bytes = move |data: &[u8]| {
359        if ready_state.get_untracked() == ConnectionReadyState::Open {
360            if let Some(web_socket) = ws_signal.get_untracked() {
361                let _ = web_socket.send_with_u8_array(data);
362            }
363        }
364    };
365
366    let send = {
367        let on_error = Arc::clone(&on_error);
368
369        move |value: &Tx| {
370            let on_error = Arc::clone(&on_error);
371
372            send_with_codec::<Tx, C>(value, send_str, send_bytes, move |err| {
373                on_error(UseWebSocketError::Codec(CodecError::Encode(err)));
374            });
375        }
376    };
377
378    let heartbeat_interval_ref = StoredValue::new_local(None::<(Arc<dyn Fn()>, Arc<dyn Fn()>)>);
379
380    let stop_heartbeat = move || {
381        if let Some((pause, _)) = heartbeat_interval_ref.get_value() {
382            pause();
383        }
384    };
385
386    #[cfg(not(feature = "ssr"))]
387    {
388        use crate::utils::Pausable;
389
390        let start_heartbeat = {
391            let on_error = Arc::clone(&on_error);
392
393            move || {
394                if let Some(heartbeat) = &heartbeat {
395                    if let Some((pause, resume)) = heartbeat_interval_ref.get_value() {
396                        pause();
397                        resume();
398                    } else {
399                        let on_error = Arc::clone(&on_error);
400
401                        let Pausable { pause, resume, .. } = use_interval_fn(
402                            move || {
403                                send_with_codec::<Hb, HbCodec>(
404                                    &Hb::default(),
405                                    send_str,
406                                    send_bytes,
407                                    {
408                                        let on_error = Arc::clone(&on_error);
409
410                                        move |err| {
411                                            on_error(UseWebSocketError::HeartbeatCodec(format!(
412                                                "Failed to encode heartbeat data: {err:?}"
413                                            )))
414                                        }
415                                    },
416                                )
417                            },
418                            heartbeat.interval,
419                        );
420
421                        heartbeat_interval_ref.set_value(Some((Arc::new(pause), Arc::new(resume))));
422                    }
423                }
424            }
425        };
426
427        let reconnect_ref: StoredValue<Option<Arc<dyn Fn() + Send + Sync>>> =
428            StoredValue::new(None);
429        reconnect_ref.set_value({
430            let unmounted = Arc::clone(&unmounted);
431
432            Some(Arc::new(move || {
433                let unmounted = Arc::clone(&unmounted);
434
435                if !manually_closed_ref.get_value()
436                    && !reconnect_limit.is_exceeded_by(reconnect_times_ref.get_value())
437                    && ws_signal
438                        .get_untracked()
439                        .is_some_and(|ws: WebSocket| ws.ready_state() != WebSocket::OPEN)
440                    && reconnect_timer_ref.get_value().is_none()
441                {
442                    reconnect_timer_ref.set_value(
443                        set_timeout_with_handle(
444                            move || {
445                                if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
446                                    return;
447                                }
448                                if let Some(connect) = connect_ref.get_value() {
449                                    connect();
450                                    reconnect_times_ref.update_value(|current| *current += 1);
451                                }
452                            },
453                            Duration::from_millis(reconnect_interval),
454                        )
455                        .ok(),
456                    );
457                }
458            }))
459        });
460
461        connect_ref.set_value({
462            let unmounted = Arc::clone(&unmounted);
463            let on_error = Arc::clone(&on_error);
464
465            Some(Arc::new(move || {
466                if let Some(reconnect_timer) = reconnect_timer_ref.get_value() {
467                    reconnect_timer.clear();
468                    reconnect_timer_ref.set_value(None);
469                }
470
471                if let Some(web_socket) = ws_signal.get_untracked() {
472                    let _ = web_socket.close();
473                }
474
475                let web_socket = {
476                    protocols.with_untracked(|protocols| {
477                        protocols.as_ref().map_or_else(
478                            || WebSocket::new(&url).unwrap_throw(),
479                            |protocols| {
480                                let array = protocols
481                                    .iter()
482                                    .map(|p| JsValue::from(p.clone()))
483                                    .collect::<Array>();
484                                WebSocket::new_with_str_sequence(&url, &JsValue::from(&array))
485                                    .unwrap_throw()
486                            },
487                        )
488                    })
489                };
490                web_socket.set_binary_type(BinaryType::Arraybuffer);
491                set_ready_state.set(ConnectionReadyState::Connecting);
492
493                // onopen handler
494                {
495                    let unmounted = Arc::clone(&unmounted);
496                    let on_open = Arc::clone(&on_open);
497
498                    let onopen_closure = Closure::wrap(Box::new({
499                        let start_heartbeat = start_heartbeat.clone();
500
501                        move |e: Event| {
502                            if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
503                                return;
504                            }
505
506                            #[cfg(debug_assertions)]
507                            let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
508
509                            on_open(e);
510
511                            #[cfg(debug_assertions)]
512                            drop(zone);
513
514                            set_ready_state.set(ConnectionReadyState::Open);
515
516                            start_heartbeat();
517                        }
518                    })
519                        as Box<dyn FnMut(Event)>);
520                    web_socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
521                    // Forget the closure to keep it alive
522                    onopen_closure.forget();
523                }
524
525                // onmessage handler
526                {
527                    let unmounted = Arc::clone(&unmounted);
528                    let on_message = Arc::clone(&on_message);
529                    let on_message_raw = Arc::clone(&on_message_raw);
530                    let on_message_raw_bytes = Arc::clone(&on_message_raw_bytes);
531                    let on_error = Arc::clone(&on_error);
532
533                    let onmessage_closure = Closure::wrap(Box::new(move |e: MessageEvent| {
534                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
535                            return;
536                        }
537
538                        e.data().dyn_into::<js_sys::ArrayBuffer>().map_or_else(
539                            |_| {
540                                e.data().dyn_into::<js_sys::JsString>().map_or_else(
541                                    |_| {
542                                        unreachable!(
543                                            "message event, received Unknown: {:?}",
544                                            e.data()
545                                        );
546                                    },
547                                    |txt| {
548                                        let txt = String::from(&txt);
549
550                                        #[cfg(debug_assertions)]
551                                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
552
553                                        on_message_raw(&txt);
554
555                                        #[cfg(debug_assertions)]
556                                        drop(zone);
557
558                                        match C::decode_str(&txt) {
559                                            Ok(val) => {
560                                                #[cfg(debug_assertions)]
561                                                let prev = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
562
563                                                on_message(&val);
564
565                                                #[cfg(debug_assertions)]
566                                                drop(prev);
567
568                                                set_message.set(Some(val));
569                                            }
570                                            Err(err) => {
571                                                on_error(CodecError::Decode(err).into());
572                                            }
573                                        }
574                                    },
575                                );
576                            },
577                            |array_buffer| {
578                                let array = js_sys::Uint8Array::new(&array_buffer);
579                                let array = array.to_vec();
580
581                                #[cfg(debug_assertions)]
582                                let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
583
584                                on_message_raw_bytes(&array);
585
586                                #[cfg(debug_assertions)]
587                                drop(zone);
588
589                                match C::decode_bin(array.as_slice()) {
590                                    Ok(val) => {
591                                        #[cfg(debug_assertions)]
592                                        let prev = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
593
594                                        on_message(&val);
595
596                                        #[cfg(debug_assertions)]
597                                        drop(prev);
598
599                                        set_message.set(Some(val));
600                                    }
601                                    Err(err) => {
602                                        on_error(CodecError::Decode(err).into());
603                                    }
604                                }
605                            },
606                        );
607                    })
608                        as Box<dyn FnMut(MessageEvent)>);
609                    web_socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
610                    onmessage_closure.forget();
611                }
612
613                // onerror handler
614                {
615                    let unmounted = Arc::clone(&unmounted);
616                    let on_error = Arc::clone(&on_error);
617
618                    let onerror_closure = Closure::wrap(Box::new(move |e: Event| {
619                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
620                            return;
621                        }
622
623                        stop_heartbeat();
624
625                        if let Some(reconnect) = &reconnect_ref.get_value() {
626                            reconnect();
627                        }
628
629                        #[cfg(debug_assertions)]
630                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
631
632                        on_error(UseWebSocketError::Event(e));
633
634                        #[cfg(debug_assertions)]
635                        drop(zone);
636
637                        set_ready_state.set(ConnectionReadyState::Closed);
638                    })
639                        as Box<dyn FnMut(Event)>);
640                    web_socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
641                    onerror_closure.forget();
642                }
643
644                // onclose handler
645                {
646                    let unmounted = Arc::clone(&unmounted);
647                    let on_close = Arc::clone(&on_close);
648
649                    let onclose_closure = Closure::wrap(Box::new(move |e: CloseEvent| {
650                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
651                            return;
652                        }
653
654                        stop_heartbeat();
655
656                        if let Some(reconnect) = &reconnect_ref.get_value() {
657                            reconnect();
658                        }
659
660                        #[cfg(debug_assertions)]
661                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
662
663                        on_close(e);
664
665                        #[cfg(debug_assertions)]
666                        drop(zone);
667
668                        set_ready_state.set(ConnectionReadyState::Closed);
669                    })
670                        as Box<dyn FnMut(CloseEvent)>);
671                    web_socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
672                    onclose_closure.forget();
673                }
674
675                ws_signal.set(Some(web_socket));
676            }))
677        });
678    }
679
680    // Open connection
681    let open = move || {
682        reconnect_times_ref.set_value(0);
683        if let Some(connect) = connect_ref.get_value() {
684            connect();
685        }
686    };
687
688    // Close connection
689    let close = {
690        reconnect_timer_ref.set_value(None);
691
692        move || {
693            stop_heartbeat();
694            manually_closed_ref.set_value(true);
695            if let Some(web_socket) = ws_signal.get_untracked() {
696                let _ = web_socket.close();
697            }
698        }
699    };
700
701    // Open connection (not called if option `manual` is true)
702    Effect::new(move |_| {
703        if immediate {
704            open();
705        }
706    });
707
708    // clean up (unmount)
709    on_cleanup(move || {
710        unmounted.store(true, std::sync::atomic::Ordering::Relaxed);
711        close();
712    });
713
714    UseWebSocketReturn {
715        ready_state: ready_state.into(),
716        message: message.into(),
717        ws: ws_signal.into(),
718        open,
719        close,
720        send,
721        _marker: PhantomData,
722    }
723}
724
725fn send_with_codec<T, Codec>(
726    value: &T,
727    send_str: impl Fn(&str),
728    send_bytes: impl Fn(&[u8]),
729    on_error: impl Fn(HybridCoderError<<Codec as Encoder<T>>::Error>),
730) where
731    Codec: Encoder<T>,
732    Codec: HybridEncoder<T, <Codec as Encoder<T>>::Encoded, Error = <Codec as Encoder<T>>::Error>,
733{
734    if Codec::is_binary_encoder() {
735        match Codec::encode_bin(value) {
736            Ok(val) => send_bytes(&val),
737            Err(err) => on_error(err),
738        }
739    } else {
740        match Codec::encode_str(value) {
741            Ok(val) => send_str(&val),
742            Err(err) => on_error(err),
743        }
744    }
745}
746
747type ArcFnBytes = Arc<dyn Fn(&[u8]) + Send + Sync>;
748
749/// Options for [`use_websocket_with_options`].
750#[derive(DefaultBuilder)]
751pub struct UseWebSocketOptions<Rx, E, D, Hb, HbCodec>
752where
753    Rx: ?Sized,
754    Hb: Default + Send + Sync + 'static,
755    HbCodec: Encoder<Hb>,
756    HbCodec: HybridEncoder<
757        Hb,
758        <HbCodec as Encoder<Hb>>::Encoded,
759        Error = <HbCodec as Encoder<Hb>>::Error,
760    >,
761{
762    /// Heartbeat options
763    #[builder(skip)]
764    heartbeat: Option<HeartbeatOptions<Hb, HbCodec>>,
765    /// `WebSocket` connect callback.
766    on_open: Arc<dyn Fn(Event) + Send + Sync>,
767    /// `WebSocket` message callback for typed message decoded by codec.
768    #[builder(skip)]
769    on_message: Arc<dyn Fn(&Rx) + Send + Sync>,
770    /// `WebSocket` message callback for text.
771    on_message_raw: Arc<dyn Fn(&str) + Send + Sync>,
772    /// `WebSocket` message callback for binary.
773    on_message_raw_bytes: ArcFnBytes,
774    /// `WebSocket` error callback.
775    #[builder(skip)]
776    on_error: Arc<dyn Fn(UseWebSocketError<E, D>) + Send + Sync>,
777    /// `WebSocket` close callback.
778    on_close: Arc<dyn Fn(CloseEvent) + Send + Sync>,
779    /// Retry times. Defaults to `ReconnectLimit::Limited(3)`. Use `ReconnectLimit::Infinite` for
780    /// infinite retries.
781    reconnect_limit: ReconnectLimit,
782    /// Retry interval in ms. Defaults to 3000.
783    reconnect_interval: u64,
784    /// If `true` the `WebSocket` connection will immediately be opened when calling this function.
785    /// If `false` you have to manually call the `open` function.
786    /// Defaults to `true`.
787    immediate: bool,
788    /// Sub protocols. See [MDN Docs](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/WebSocket#protocols).
789    ///
790    /// Can be set as a signal to support protocols only available after the initial render.
791    ///
792    /// Note that protocols are only updated on the next websocket open() call, not whenever the signal is updated.
793    /// Therefore "lazy" protocols should use the `immediate(false)` option and manually call `open()`.
794    #[builder(into)]
795    protocols: Signal<Option<Vec<String>>>,
796}
797
798impl<Rx: ?Sized, E, D, Hb, HbCodec> UseWebSocketOptions<Rx, E, D, Hb, HbCodec>
799where
800    Hb: Default + Send + Sync + 'static,
801    HbCodec: Encoder<Hb>,
802    HbCodec: HybridEncoder<
803        Hb,
804        <HbCodec as Encoder<Hb>>::Encoded,
805        Error = <HbCodec as Encoder<Hb>>::Error,
806    >,
807{
808    /// `WebSocket` error callback.
809    pub fn on_error<F>(self, handler: F) -> Self
810    where
811        F: Fn(UseWebSocketError<E, D>) + Send + Sync + 'static,
812    {
813        Self {
814            on_error: Arc::new(handler),
815            ..self
816        }
817    }
818
819    /// `WebSocket` message callback for typed message decoded by codec.
820    pub fn on_message<F>(self, handler: F) -> Self
821    where
822        F: Fn(&Rx) + Send + Sync + 'static,
823    {
824        Self {
825            on_message: Arc::new(handler),
826            ..self
827        }
828    }
829
830    /// Set the data, codec and interval at which the heartbeat is sent. The heartbeat
831    /// is the default value of the `NewHb` type.
832    pub fn heartbeat<NewHb, NewHbCodec>(
833        self,
834        interval: u64,
835    ) -> UseWebSocketOptions<Rx, E, D, NewHb, NewHbCodec>
836    where
837        NewHb: Default + Send + Sync + 'static,
838        NewHbCodec: Encoder<NewHb>,
839        NewHbCodec: HybridEncoder<
840            NewHb,
841            <NewHbCodec as Encoder<NewHb>>::Encoded,
842            Error = <NewHbCodec as Encoder<NewHb>>::Error,
843        >,
844    {
845        UseWebSocketOptions {
846            heartbeat: Some(HeartbeatOptions {
847                data: PhantomData::<NewHb>,
848                interval,
849                codec: PhantomData::<NewHbCodec>,
850            }),
851            on_open: self.on_open,
852            on_message: self.on_message,
853            on_message_raw: self.on_message_raw,
854            on_message_raw_bytes: self.on_message_raw_bytes,
855            on_close: self.on_close,
856            on_error: self.on_error,
857            reconnect_limit: self.reconnect_limit,
858            reconnect_interval: self.reconnect_interval,
859            immediate: self.immediate,
860            protocols: self.protocols,
861        }
862    }
863}
864
865impl<Rx: ?Sized, E, D> Default for UseWebSocketOptions<Rx, E, D, (), DummyEncoder> {
866    fn default() -> Self {
867        Self {
868            heartbeat: None,
869            on_open: Arc::new(|_| {}),
870            on_message: Arc::new(|_| {}),
871            on_message_raw: Arc::new(|_| {}),
872            on_message_raw_bytes: Arc::new(|_| {}),
873            on_error: Arc::new(|_| {}),
874            on_close: Arc::new(|_| {}),
875            reconnect_limit: ReconnectLimit::default(),
876            reconnect_interval: 3000,
877            immediate: true,
878            protocols: Default::default(),
879        }
880    }
881}
882
883pub struct DummyEncoder;
884
885impl Encoder<()> for DummyEncoder {
886    type Encoded = String;
887    type Error = ();
888
889    fn encode(_: &()) -> Result<Self::Encoded, Self::Error> {
890        Ok("".to_string())
891    }
892}
893
894/// Options for heartbeats
895pub struct HeartbeatOptions<Hb, HbCodec>
896where
897    Hb: Default + Send + Sync + 'static,
898    HbCodec: Encoder<Hb>,
899    HbCodec: HybridEncoder<
900        Hb,
901        <HbCodec as Encoder<Hb>>::Encoded,
902        Error = <HbCodec as Encoder<Hb>>::Error,
903    >,
904{
905    /// Heartbeat data that will be sent to the server
906    data: PhantomData<Hb>,
907    /// Heartbeat interval in ms. A heartbeat will be sent every `interval` ms.
908    interval: u64,
909    /// Codec used to encode the heartbeat data
910    codec: PhantomData<HbCodec>,
911}
912
913impl<Hb, HbCodec> Clone for HeartbeatOptions<Hb, HbCodec>
914where
915    Hb: Default + Send + Sync + 'static,
916    HbCodec: Encoder<Hb>,
917    HbCodec: HybridEncoder<
918        Hb,
919        <HbCodec as Encoder<Hb>>::Encoded,
920        Error = <HbCodec as Encoder<Hb>>::Error,
921    >,
922{
923    fn clone(&self) -> Self {
924        *self
925    }
926}
927
928impl<Hb, HbCodec> Copy for HeartbeatOptions<Hb, HbCodec>
929where
930    Hb: Default + Send + Sync + 'static,
931    HbCodec: Encoder<Hb>,
932    HbCodec: HybridEncoder<
933        Hb,
934        <HbCodec as Encoder<Hb>>::Encoded,
935        Error = <HbCodec as Encoder<Hb>>::Error,
936    >,
937{
938}
939
940/// Return type of [`use_websocket`].
941#[derive(Clone)]
942pub struct UseWebSocketReturn<Tx, Rx, OpenFn, CloseFn, SendFn>
943where
944    Tx: Send + Sync + 'static,
945    Rx: Send + Sync + 'static,
946    OpenFn: Fn() + Clone + Send + Sync + 'static,
947    CloseFn: Fn() + Clone + Send + Sync + 'static,
948    SendFn: Fn(&Tx) + Clone + Send + Sync + 'static,
949{
950    /// The current state of the `WebSocket` connection.
951    pub ready_state: Signal<ConnectionReadyState>,
952    /// Latest message received from `WebSocket`.
953    pub message: Signal<Option<Rx>>,
954    /// The `WebSocket` instance.
955    pub ws: Signal<Option<WebSocket>, LocalStorage>,
956    /// Opens the `WebSocket` connection
957    pub open: OpenFn,
958    /// Closes the `WebSocket` connection
959    pub close: CloseFn,
960    /// Sends data through the socket
961    pub send: SendFn,
962
963    _marker: PhantomData<Tx>,
964}
965
966#[derive(Error, Debug)]
967pub enum UseWebSocketError<E, D> {
968    #[error("WebSocket error event")]
969    Event(Event),
970    #[error("WebSocket codec error: {0}")]
971    Codec(#[from] CodecError<E, D>),
972    #[error("WebSocket heartbeat codec error: {0}")]
973    HeartbeatCodec(String),
974}
975
976fn normalize_url(url: &str) -> String {
977    cfg_if! { if #[cfg(feature = "ssr")] {
978        url.to_string()
979    } else {
980        if url.starts_with("ws://") || url.starts_with("wss://") {
981            url.to_string()
982        } else if url.starts_with("//") {
983            format!("{}{}", detect_protocol(), url)
984        } else if url.starts_with('/') {
985            format!(
986                "{}//{}{}",
987                detect_protocol(),
988                window().location().host().expect("Host not found"),
989                url
990            )
991        } else {
992            let mut path = window().location().pathname().expect("Pathname not found");
993            if !path.ends_with('/') {
994                path.push('/')
995            }
996            format!(
997                "{}//{}{}{}",
998                detect_protocol(),
999                window().location().host().expect("Host not found"),
1000                path,
1001                url
1002            )
1003        }
1004    }}
1005}
1006
1007fn detect_protocol() -> String {
1008    cfg_if! { if #[cfg(feature = "ssr")] {
1009        "ws".to_string()
1010    } else {
1011        window().location().protocol().expect("Protocol not found").replace("http", "ws")
1012    }}
1013}