async_graphql/http/
websocket.rs

//! WebSocket transport for GraphQL subscriptions.
2
3use std::{
4    collections::HashMap,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::{Duration, Instant},
10};
11
12use futures_timer::Delay;
13use futures_util::{
14    future::{BoxFuture, Ready},
15    stream::Stream,
16    FutureExt, StreamExt,
17};
18use pin_project_lite::pin_project;
19use serde::{Deserialize, Serialize};
20
21use crate::{Data, Error, Executor, Request, Response, Result};
22
/// All known protocols based on WebSocket, as `Sec-WebSocket-Protocol` values.
///
/// `"graphql-transport-ws"` selects [`Protocols::GraphQLWS`] and `"graphql-ws"`
/// selects the legacy [`Protocols::SubscriptionsTransportWS`]; see the
/// [`FromStr`](std::str::FromStr) implementation of [`Protocols`].
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = ["graphql-transport-ws", "graphql-ws"];
25
/// An enum representing the various forms of a WebSocket message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text WebSocket message.
    Text(String),

    /// A close message carrying the close code and the close reason.
    Close(u16, String),
}

impl WsMessage {
    /// Returns the contained [`WsMessage::Text`] value, consuming `self`.
    ///
    /// Because this function may panic, its use is generally discouraged.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [`WsMessage::Text`]. The reported panic
    /// location is the caller's, not this function's.
    #[track_caller]
    pub fn unwrap_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Close(_, _) => panic!("Not a text message"),
        }
    }

    /// Returns the contained [`WsMessage::Close`] code and reason, consuming
    /// `self`.
    ///
    /// Because this function may panic, its use is generally discouraged.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [`WsMessage::Close`]. The reported panic
    /// location is the caller's, not this function's.
    #[track_caller]
    pub fn unwrap_close(self) -> (u16, String) {
        match self {
            Self::Close(code, msg) => (code, msg),
            Self::Text(_) => panic!("Not a close message"),
        }
    }
}
67
68struct Timer {
69    interval: Duration,
70    delay: Delay,
71}
72
73impl Timer {
74    #[inline]
75    fn new(interval: Duration) -> Self {
76        Self {
77            interval,
78            delay: Delay::new(interval),
79        }
80    }
81
82    #[inline]
83    fn reset(&mut self) {
84        self.delay.reset(self.interval);
85    }
86}
87
88impl Stream for Timer {
89    type Item = ();
90
91    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        let this = &mut *self;
93        match this.delay.poll_unpin(cx) {
94            Poll::Ready(_) => {
95                this.delay.reset(this.interval);
96                Poll::Ready(Some(()))
97            }
98            Poll::Pending => Poll::Pending,
99        }
100    }
101}
102
pin_project! {
    /// A GraphQL connection over websocket.
    ///
    /// Consumes a stream of decoded [`ClientMessage`]s and yields the
    /// [`WsMessage`]s that should be sent back to the client.
    ///
    /// # References
    ///
    /// - [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md)
    /// - [graphql-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
    pub struct WebSocket<S, E, OnInit, OnPing> {
        // Connection-init callback. It is `take`n when the first
        // `connection_init` arrives, so a second one is rejected.
        on_connection_init: Option<OnInit>,
        // Ping callback; cloned for every `ping` message.
        on_ping: OnPing,
        // In-flight init callback; while `Some`, no client messages are read.
        init_fut: Option<BoxFuture<'static, Result<Data>>>,
        // In-flight ping callback; while `Some`, no client messages are read.
        ping_fut: Option<BoxFuture<'static, Result<Option<serde_json::Value>>>>,
        // Data given via `connection_data`; the init callback's result is merged into it.
        connection_data: Option<Data>,
        // Final context data; `Some` only once the handshake has been acknowledged.
        data: Option<Arc<Data>>,
        // Executes each subscription request.
        executor: E,
        // Active subscription streams keyed by the client-supplied operation id.
        streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
        // Incoming client messages.
        #[pin]
        stream: S,
        // Negotiated sub-protocol; selects message names and close behaviour.
        protocol: Protocols,
        // When the last client message was received.
        // NOTE(review): written in `poll_next` but not read anywhere in this file.
        last_msg_at: Instant,
        // Optional idle timer, restarted on every received client message.
        keepalive_timer: Option<Timer>,
        // Once `true`, the stream yields `None`.
        close: bool,
    }
}
127
/// The message stream built by [`WebSocket::new`]: raw frames decoded through
/// [`ClientMessage::from_bytes`]. A plain function pointer is used as the
/// mapper so this type can be named.
type MessageMapStream<S> =
    futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
130
/// Default connection initializer type.
///
/// The function-pointer type of [`default_on_connection_init`], used as the
/// `OnInit` parameter of a freshly constructed [`WebSocket`].
pub type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;

/// Default ping handler type.
///
/// The function-pointer type of [`default_on_ping`], used as the `OnPing`
/// parameter of a freshly constructed [`WebSocket`].
pub type DefaultOnPingType =
    fn(Option<&Data>, Option<serde_json::Value>) -> Ready<Result<Option<serde_json::Value>>>;
137
138/// Default connection initializer function.
139pub fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
140    futures_util::future::ready(Ok(Data::default()))
141}
142
143/// Default ping handler function.
144pub fn default_on_ping(
145    _: Option<&Data>,
146    _: Option<serde_json::Value>,
147) -> Ready<Result<Option<serde_json::Value>>> {
148    futures_util::future::ready(Ok(None))
149}
150
151impl<S, E> WebSocket<S, E, DefaultOnConnInitType, DefaultOnPingType>
152where
153    E: Executor,
154    S: Stream<Item = serde_json::Result<ClientMessage>>,
155{
156    /// Create a new websocket from [`ClientMessage`] stream.
157    pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
158        WebSocket {
159            on_connection_init: Some(default_on_connection_init),
160            on_ping: default_on_ping,
161            init_fut: None,
162            ping_fut: None,
163            connection_data: None,
164            data: None,
165            executor,
166            streams: HashMap::new(),
167            stream,
168            protocol,
169            last_msg_at: Instant::now(),
170            keepalive_timer: None,
171            close: false,
172        }
173    }
174}
175
176impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType, DefaultOnPingType>
177where
178    E: Executor,
179    S: Stream,
180    S::Item: AsRef<[u8]>,
181{
182    /// Create a new websocket from bytes stream.
183    pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
184        let stream = stream
185            .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
186        WebSocket::from_message_stream(executor, stream, protocol)
187    }
188}
189
190impl<S, E, OnInit, OnPing> WebSocket<S, E, OnInit, OnPing>
191where
192    E: Executor,
193    S: Stream<Item = serde_json::Result<ClientMessage>>,
194{
195    /// Specify a connection data.
196    ///
197    /// This data usually comes from HTTP requests.
198    /// When the `GQL_CONNECTION_INIT` message is received, this data will be
199    /// merged with the data returned by the closure specified by
200    /// `with_initializer` into the final subscription context data.
201    #[must_use]
202    pub fn connection_data(mut self, data: Data) -> Self {
203        self.connection_data = Some(data);
204        self
205    }
206
207    /// Specify a connection initialize callback function.
208    ///
209    /// This function if present, will be called with the data sent by the
210    /// client in the [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
211    /// From that point on the returned data will be accessible to all requests.
212    #[must_use]
213    pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F, OnPing>
214    where
215        F: FnOnce(serde_json::Value) -> R + Send + 'static,
216        R: Future<Output = Result<Data>> + Send + 'static,
217    {
218        WebSocket {
219            on_connection_init: Some(callback),
220            on_ping: self.on_ping,
221            init_fut: self.init_fut,
222            ping_fut: self.ping_fut,
223            connection_data: self.connection_data,
224            data: self.data,
225            executor: self.executor,
226            streams: self.streams,
227            stream: self.stream,
228            protocol: self.protocol,
229            last_msg_at: self.last_msg_at,
230            keepalive_timer: self.keepalive_timer,
231            close: self.close,
232        }
233    }
234
235    /// Specify a ping callback function.
236    ///
237    /// This function if present, will be called with the data sent by the
238    /// client in the [`Ping` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping).
239    ///
240    /// The function should return the data to be sent in the [`Pong` message](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong).
241    ///
242    /// NOTE: Only used for the `graphql-ws` protocol.
243    #[must_use]
244    pub fn on_ping<F, R>(self, callback: F) -> WebSocket<S, E, OnInit, F>
245    where
246        F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
247        R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
248    {
249        WebSocket {
250            on_connection_init: self.on_connection_init,
251            on_ping: callback,
252            init_fut: self.init_fut,
253            ping_fut: self.ping_fut,
254            connection_data: self.connection_data,
255            data: self.data,
256            executor: self.executor,
257            streams: self.streams,
258            stream: self.stream,
259            protocol: self.protocol,
260            last_msg_at: self.last_msg_at,
261            keepalive_timer: self.keepalive_timer,
262            close: self.close,
263        }
264    }
265
266    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
267    ///
268    /// If the ping is not acknowledged within the timeout, the connection will
269    /// be closed.
270    ///
271    /// NOTE: Only used for the `graphql-ws` protocol.
272    #[must_use]
273    pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
274        Self {
275            keepalive_timer: timeout.into().map(Timer::new),
276            ..self
277        }
278    }
279}
280
281impl<S, E, OnInit, InitFut, OnPing, PingFut> Stream for WebSocket<S, E, OnInit, OnPing>
282where
283    E: Executor,
284    S: Stream<Item = serde_json::Result<ClientMessage>>,
285    OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
286    InitFut: Future<Output = Result<Data>> + Send + 'static,
287    OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> PingFut + Clone + Send + 'static,
288    PingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
289{
290    type Item = WsMessage;
291
292    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
293        let mut this = self.project();
294
295        if *this.close {
296            return Poll::Ready(None);
297        }
298
299        if let Some(keepalive_timer) = this.keepalive_timer {
300            if let Poll::Ready(Some(())) = keepalive_timer.poll_next_unpin(cx) {
301                return match this.protocol {
302                    Protocols::SubscriptionsTransportWS => {
303                        *this.close = true;
304                        Poll::Ready(Some(WsMessage::Text(
305                            serde_json::to_string(&ServerMessage::ConnectionError {
306                                payload: Error::new("timeout"),
307                            })
308                            .unwrap(),
309                        )))
310                    }
311                    Protocols::GraphQLWS => {
312                        *this.close = true;
313                        Poll::Ready(Some(WsMessage::Close(3008, "timeout".to_string())))
314                    }
315                };
316            }
317        }
318
319        if this.init_fut.is_none() && this.ping_fut.is_none() {
320            while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
321                let message = match message {
322                    Some(message) => message,
323                    None => return Poll::Ready(None),
324                };
325
326                let message: ClientMessage = match message {
327                    Ok(message) => message,
328                    Err(err) => {
329                        *this.close = true;
330                        return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())));
331                    }
332                };
333
334                *this.last_msg_at = Instant::now();
335                if let Some(keepalive_timer) = this.keepalive_timer {
336                    keepalive_timer.reset();
337                }
338
339                match message {
340                    ClientMessage::ConnectionInit { payload } => {
341                        if let Some(on_connection_init) = this.on_connection_init.take() {
342                            *this.init_fut = Some(Box::pin(async move {
343                                on_connection_init(payload.unwrap_or_default()).await
344                            }));
345                            break;
346                        } else {
347                            *this.close = true;
348                            match this.protocol {
349                                Protocols::SubscriptionsTransportWS => {
350                                    return Poll::Ready(Some(WsMessage::Text(
351                                        serde_json::to_string(&ServerMessage::ConnectionError {
352                                            payload: Error::new(
353                                                "Too many initialisation requests.",
354                                            ),
355                                        })
356                                        .unwrap(),
357                                    )));
358                                }
359                                Protocols::GraphQLWS => {
360                                    return Poll::Ready(Some(WsMessage::Close(
361                                        4429,
362                                        "Too many initialisation requests.".to_string(),
363                                    )));
364                                }
365                            }
366                        }
367                    }
368                    ClientMessage::Start {
369                        id,
370                        payload: request,
371                    } => {
372                        if let Some(data) = this.data.clone() {
373                            this.streams.insert(
374                                id,
375                                Box::pin(this.executor.execute_stream(request, Some(data))),
376                            );
377                        } else {
378                            *this.close = true;
379                            return Poll::Ready(Some(WsMessage::Close(
380                                1011,
381                                "The handshake is not completed.".to_string(),
382                            )));
383                        }
384                    }
385                    ClientMessage::Stop { id } => {
386                        if this.streams.remove(&id).is_some() {
387                            return Poll::Ready(Some(WsMessage::Text(
388                                serde_json::to_string(&ServerMessage::Complete { id: &id })
389                                    .unwrap(),
390                            )));
391                        }
392                    }
393                    // Note: in the revised `graphql-ws` spec, there is no equivalent to the
394                    // `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is
395                    // handled by disconnecting the websocket
396                    ClientMessage::ConnectionTerminate => {
397                        *this.close = true;
398                        return Poll::Ready(None);
399                    }
400                    // Pong must be sent in response from the receiving party as soon as possible.
401                    ClientMessage::Ping { payload } => {
402                        let on_ping = this.on_ping.clone();
403                        let data = this.data.clone();
404                        *this.ping_fut =
405                            Some(Box::pin(
406                                async move { on_ping(data.as_deref(), payload).await },
407                            ));
408                        break;
409                    }
410                    ClientMessage::Pong { .. } => {
411                        // Do nothing...
412                    }
413                }
414            }
415        }
416
417        if let Some(init_fut) = this.init_fut {
418            return init_fut.poll_unpin(cx).map(|res| {
419                *this.init_fut = None;
420                match res {
421                    Ok(data) => {
422                        let mut ctx_data = this.connection_data.take().unwrap_or_default();
423                        ctx_data.merge(data);
424                        *this.data = Some(Arc::new(ctx_data));
425                        Some(WsMessage::Text(
426                            serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
427                        ))
428                    }
429                    Err(err) => {
430                        *this.close = true;
431                        match this.protocol {
432                            Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
433                                serde_json::to_string(&ServerMessage::ConnectionError {
434                                    payload: Error::new(err.message),
435                                })
436                                .unwrap(),
437                            )),
438                            Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
439                        }
440                    }
441                }
442            });
443        }
444
445        if let Some(ping_fut) = this.ping_fut {
446            return ping_fut.poll_unpin(cx).map(|res| {
447                *this.ping_fut = None;
448                match res {
449                    Ok(payload) => Some(WsMessage::Text(
450                        serde_json::to_string(&ServerMessage::Pong { payload }).unwrap(),
451                    )),
452                    Err(err) => {
453                        *this.close = true;
454                        match this.protocol {
455                            Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
456                                serde_json::to_string(&ServerMessage::ConnectionError {
457                                    payload: Error::new(err.message),
458                                })
459                                .unwrap(),
460                            )),
461                            Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
462                        }
463                    }
464                }
465            });
466        }
467
468        for (id, stream) in &mut *this.streams {
469            match Pin::new(stream).poll_next(cx) {
470                Poll::Ready(Some(payload)) => {
471                    return Poll::Ready(Some(WsMessage::Text(
472                        serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
473                    )));
474                }
475                Poll::Ready(None) => {
476                    let id = id.clone();
477                    this.streams.remove(&id);
478                    return Poll::Ready(Some(WsMessage::Text(
479                        serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
480                    )));
481                }
482                Poll::Pending => {}
483            }
484        }
485
486        Poll::Pending
487    }
488}
489
490/// Specification of which GraphQL Over WebSockets protocol is being utilized
491#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
492pub enum Protocols {
493    /// [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
494    SubscriptionsTransportWS,
495    /// [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
496    GraphQLWS,
497}
498
499impl Protocols {
500    /// Returns the `Sec-WebSocket-Protocol` header value for the protocol
501    pub fn sec_websocket_protocol(&self) -> &'static str {
502        match self {
503            Protocols::SubscriptionsTransportWS => "graphql-ws",
504            Protocols::GraphQLWS => "graphql-transport-ws",
505        }
506    }
507
508    #[inline]
509    fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
510        match self {
511            Protocols::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
512            Protocols::GraphQLWS => ServerMessage::Next { id, payload },
513        }
514    }
515}
516
517impl std::str::FromStr for Protocols {
518    type Err = Error;
519
520    fn from_str(protocol: &str) -> Result<Self, Self::Err> {
521        if protocol.eq_ignore_ascii_case("graphql-ws") {
522            Ok(Protocols::SubscriptionsTransportWS)
523        } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
524            Ok(Protocols::GraphQLWS)
525        } else {
526            Err(Error::new(format!(
527                "Unsupported Sec-WebSocket-Protocol: {}",
528                protocol
529            )))
530        }
531    }
532}
533
/// A websocket message received from the client.
///
/// Decoded from a JSON object whose `type` field holds the snake_case
/// variant name (`connection_init`, `start`, `stop`, `connection_terminate`,
/// `ping`, `pong`). The graphql-ws names `subscribe` and `complete` are
/// accepted as aliases of `start` and `stop`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::large_enum_variant)] // Request is at fault
pub enum ClientMessage {
    /// A new connection (`connection_init`).
    ConnectionInit {
        /// Optional init payload from the client
        payload: Option<serde_json::Value>,
    },
    /// The start of a Websocket subscription (`start`, or `subscribe` in
    /// graphql-ws).
    #[serde(alias = "subscribe")]
    Start {
        /// Message ID
        id: String,
        /// The GraphQL Request - this can be modified by protocol implementors
        /// to add files uploads.
        payload: Request,
    },
    /// The end of a Websocket subscription (`stop`, or `complete` in
    /// graphql-ws).
    #[serde(alias = "complete")]
    Stop {
        /// Message ID
        id: String,
    },
    /// Connection terminated by the client (`connection_terminate`).
    ConnectionTerminate,
    /// Useful for detecting failed connections, displaying latency metrics or
    /// other types of network probing.
    ///
    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping>
    Ping {
        /// Additional details about the ping.
        payload: Option<serde_json::Value>,
    },
    /// The response to the Ping message.
    ///
    /// Reference: <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong>
    Pong {
        /// Additional details about the pong.
        payload: Option<serde_json::Value>,
    },
}
577
578impl ClientMessage {
579    /// Creates a ClientMessage from an array of bytes
580    pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
581    where
582        T: AsRef<[u8]>,
583    {
584        serde_json::from_slice(message.as_ref())
585    }
586}
587
588#[derive(Serialize)]
589#[serde(tag = "type", rename_all = "snake_case")]
590enum ServerMessage<'a> {
591    ConnectionError {
592        payload: Error,
593    },
594    ConnectionAck,
595    /// subscriptions-transport-ws protocol next payload
596    Data {
597        id: &'a str,
598        payload: Response,
599    },
600    /// graphql-ws protocol next payload
601    Next {
602        id: &'a str,
603        payload: Response,
604    },
605    // Not used by this library, as it's not necessary to send
606    // Error {
607    //     id: &'a str,
608    //     payload: serde_json::Value,
609    // },
610    Complete {
611        id: &'a str,
612    },
613    /// The response to the Ping message.
614    ///
615    /// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong
616    Pong {
617        #[serde(skip_serializing_if = "Option::is_none")]
618        payload: Option<serde_json::Value>,
619    },
620    // Not used by this library
621    // #[serde(rename = "ka")]
622    // KeepAlive
623}