1use std::{
4 collections::HashMap,
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::{Duration, Instant},
10};
11
12use futures_timer::Delay;
13use futures_util::{
14 future::{BoxFuture, Ready},
15 stream::Stream,
16 FutureExt, StreamExt,
17};
18use pin_project_lite::pin_project;
19use serde::{Deserialize, Serialize};
20
21use crate::{Data, Error, Executor, Request, Response, Result};
22
/// Both `Sec-WebSocket-Protocol` values this module understands; each one
/// parses into a [`Protocols`] variant via `str::parse`.
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = [
    "graphql-transport-ws",
    "graphql-ws",
];
25
/// A frame yielded by [`WebSocket`] for the transport layer to send to the client.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text frame; in this module it always carries a JSON-serialized server message.
    Text(String),

    /// A close frame with a status code and a reason string.
    Close(u16, String),
}

impl WsMessage {
    /// Consumes the message and returns its text payload.
    ///
    /// # Panics
    ///
    /// Panics with `"Not a text message"` if the message is [`WsMessage::Close`].
    /// The panic is reported at the caller's location (`#[track_caller]`),
    /// matching `Option::unwrap`.
    #[track_caller]
    pub fn unwrap_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Close(_, _) => panic!("Not a text message"),
        }
    }

    /// Consumes the message and returns its close code and reason.
    ///
    /// # Panics
    ///
    /// Panics with `"Not a close message"` if the message is [`WsMessage::Text`].
    /// The panic is reported at the caller's location (`#[track_caller]`).
    #[track_caller]
    pub fn unwrap_close(self) -> (u16, String) {
        match self {
            Self::Close(code, msg) => (code, msg),
            Self::Text(_) => panic!("Not a close message"),
        }
    }
}
67
68struct Timer {
69 interval: Duration,
70 delay: Delay,
71}
72
73impl Timer {
74 #[inline]
75 fn new(interval: Duration) -> Self {
76 Self {
77 interval,
78 delay: Delay::new(interval),
79 }
80 }
81
82 #[inline]
83 fn reset(&mut self) {
84 self.delay.reset(self.interval);
85 }
86}
87
88impl Stream for Timer {
89 type Item = ();
90
91 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92 let this = &mut *self;
93 match this.delay.poll_unpin(cx) {
94 Poll::Ready(_) => {
95 this.delay.reset(this.interval);
96 Poll::Ready(Some(()))
97 }
98 Poll::Pending => Poll::Pending,
99 }
100 }
101}
102
pin_project! {
    /// A GraphQL connection over a WebSocket.
    ///
    /// Reads parsed client messages from `stream` and, as a [`Stream`], yields
    /// the [`WsMessage`] frames to send back to the client.
    pub struct WebSocket<S, E, OnInit, OnPing> {
        // Callback for the first `connection_init`; taken (left `None`) when used,
        // so a second `connection_init` is rejected.
        on_connection_init: Option<OnInit>,
        // Callback cloned and invoked for every client `ping`.
        on_ping: OnPing,
        // In-flight `on_connection_init` future; while set, no client messages are read.
        init_fut: Option<BoxFuture<'static, Result<Data>>>,
        // In-flight `on_ping` future; while set, no client messages are read.
        ping_fut: Option<BoxFuture<'static, Result<Option<serde_json::Value>>>>,
        // Data supplied via `connection_data`; the init callback's result is merged into it.
        connection_data: Option<Data>,
        // Context data used to execute operations; `None` until the init callback succeeds.
        data: Option<Arc<Data>>,
        executor: E,
        // Active subscriptions keyed by the client-chosen operation id.
        streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
        // Incoming (already parsed) client messages.
        #[pin]
        stream: S,
        protocol: Protocols,
        // Time of the last valid client message.
        // NOTE(review): written in `poll_next` but never read in this file.
        last_msg_at: Instant,
        // Restarted on every valid client message; when it fires the connection is failed.
        keepalive_timer: Option<Timer>,
        // Once set, `poll_next` returns `None`.
        close: bool,
    }
}
127
128type MessageMapStream<S> =
129 futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
130
131pub type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;
133
134pub type DefaultOnPingType =
136 fn(Option<&Data>, Option<serde_json::Value>) -> Ready<Result<Option<serde_json::Value>>>;
137
138pub fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
140 futures_util::future::ready(Ok(Data::default()))
141}
142
143pub fn default_on_ping(
145 _: Option<&Data>,
146 _: Option<serde_json::Value>,
147) -> Ready<Result<Option<serde_json::Value>>> {
148 futures_util::future::ready(Ok(None))
149}
150
151impl<S, E> WebSocket<S, E, DefaultOnConnInitType, DefaultOnPingType>
152where
153 E: Executor,
154 S: Stream<Item = serde_json::Result<ClientMessage>>,
155{
156 pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
158 WebSocket {
159 on_connection_init: Some(default_on_connection_init),
160 on_ping: default_on_ping,
161 init_fut: None,
162 ping_fut: None,
163 connection_data: None,
164 data: None,
165 executor,
166 streams: HashMap::new(),
167 stream,
168 protocol,
169 last_msg_at: Instant::now(),
170 keepalive_timer: None,
171 close: false,
172 }
173 }
174}
175
176impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType, DefaultOnPingType>
177where
178 E: Executor,
179 S: Stream,
180 S::Item: AsRef<[u8]>,
181{
182 pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
184 let stream = stream
185 .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
186 WebSocket::from_message_stream(executor, stream, protocol)
187 }
188}
189
190impl<S, E, OnInit, OnPing> WebSocket<S, E, OnInit, OnPing>
191where
192 E: Executor,
193 S: Stream<Item = serde_json::Result<ClientMessage>>,
194{
195 #[must_use]
202 pub fn connection_data(mut self, data: Data) -> Self {
203 self.connection_data = Some(data);
204 self
205 }
206
207 #[must_use]
213 pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F, OnPing>
214 where
215 F: FnOnce(serde_json::Value) -> R + Send + 'static,
216 R: Future<Output = Result<Data>> + Send + 'static,
217 {
218 WebSocket {
219 on_connection_init: Some(callback),
220 on_ping: self.on_ping,
221 init_fut: self.init_fut,
222 ping_fut: self.ping_fut,
223 connection_data: self.connection_data,
224 data: self.data,
225 executor: self.executor,
226 streams: self.streams,
227 stream: self.stream,
228 protocol: self.protocol,
229 last_msg_at: self.last_msg_at,
230 keepalive_timer: self.keepalive_timer,
231 close: self.close,
232 }
233 }
234
235 #[must_use]
244 pub fn on_ping<F, R>(self, callback: F) -> WebSocket<S, E, OnInit, F>
245 where
246 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
247 R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
248 {
249 WebSocket {
250 on_connection_init: self.on_connection_init,
251 on_ping: callback,
252 init_fut: self.init_fut,
253 ping_fut: self.ping_fut,
254 connection_data: self.connection_data,
255 data: self.data,
256 executor: self.executor,
257 streams: self.streams,
258 stream: self.stream,
259 protocol: self.protocol,
260 last_msg_at: self.last_msg_at,
261 keepalive_timer: self.keepalive_timer,
262 close: self.close,
263 }
264 }
265
266 #[must_use]
273 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
274 Self {
275 keepalive_timer: timeout.into().map(Timer::new),
276 ..self
277 }
278 }
279}
280
281impl<S, E, OnInit, InitFut, OnPing, PingFut> Stream for WebSocket<S, E, OnInit, OnPing>
282where
283 E: Executor,
284 S: Stream<Item = serde_json::Result<ClientMessage>>,
285 OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
286 InitFut: Future<Output = Result<Data>> + Send + 'static,
287 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> PingFut + Clone + Send + 'static,
288 PingFut: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
289{
290 type Item = WsMessage;
291
292 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
293 let mut this = self.project();
294
295 if *this.close {
296 return Poll::Ready(None);
297 }
298
299 if let Some(keepalive_timer) = this.keepalive_timer {
300 if let Poll::Ready(Some(())) = keepalive_timer.poll_next_unpin(cx) {
301 return match this.protocol {
302 Protocols::SubscriptionsTransportWS => {
303 *this.close = true;
304 Poll::Ready(Some(WsMessage::Text(
305 serde_json::to_string(&ServerMessage::ConnectionError {
306 payload: Error::new("timeout"),
307 })
308 .unwrap(),
309 )))
310 }
311 Protocols::GraphQLWS => {
312 *this.close = true;
313 Poll::Ready(Some(WsMessage::Close(3008, "timeout".to_string())))
314 }
315 };
316 }
317 }
318
319 if this.init_fut.is_none() && this.ping_fut.is_none() {
320 while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
321 let message = match message {
322 Some(message) => message,
323 None => return Poll::Ready(None),
324 };
325
326 let message: ClientMessage = match message {
327 Ok(message) => message,
328 Err(err) => {
329 *this.close = true;
330 return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())));
331 }
332 };
333
334 *this.last_msg_at = Instant::now();
335 if let Some(keepalive_timer) = this.keepalive_timer {
336 keepalive_timer.reset();
337 }
338
339 match message {
340 ClientMessage::ConnectionInit { payload } => {
341 if let Some(on_connection_init) = this.on_connection_init.take() {
342 *this.init_fut = Some(Box::pin(async move {
343 on_connection_init(payload.unwrap_or_default()).await
344 }));
345 break;
346 } else {
347 *this.close = true;
348 match this.protocol {
349 Protocols::SubscriptionsTransportWS => {
350 return Poll::Ready(Some(WsMessage::Text(
351 serde_json::to_string(&ServerMessage::ConnectionError {
352 payload: Error::new(
353 "Too many initialisation requests.",
354 ),
355 })
356 .unwrap(),
357 )));
358 }
359 Protocols::GraphQLWS => {
360 return Poll::Ready(Some(WsMessage::Close(
361 4429,
362 "Too many initialisation requests.".to_string(),
363 )));
364 }
365 }
366 }
367 }
368 ClientMessage::Start {
369 id,
370 payload: request,
371 } => {
372 if let Some(data) = this.data.clone() {
373 this.streams.insert(
374 id,
375 Box::pin(this.executor.execute_stream(request, Some(data))),
376 );
377 } else {
378 *this.close = true;
379 return Poll::Ready(Some(WsMessage::Close(
380 1011,
381 "The handshake is not completed.".to_string(),
382 )));
383 }
384 }
385 ClientMessage::Stop { id } => {
386 if this.streams.remove(&id).is_some() {
387 return Poll::Ready(Some(WsMessage::Text(
388 serde_json::to_string(&ServerMessage::Complete { id: &id })
389 .unwrap(),
390 )));
391 }
392 }
393 ClientMessage::ConnectionTerminate => {
397 *this.close = true;
398 return Poll::Ready(None);
399 }
400 ClientMessage::Ping { payload } => {
402 let on_ping = this.on_ping.clone();
403 let data = this.data.clone();
404 *this.ping_fut =
405 Some(Box::pin(
406 async move { on_ping(data.as_deref(), payload).await },
407 ));
408 break;
409 }
410 ClientMessage::Pong { .. } => {
411 }
413 }
414 }
415 }
416
417 if let Some(init_fut) = this.init_fut {
418 return init_fut.poll_unpin(cx).map(|res| {
419 *this.init_fut = None;
420 match res {
421 Ok(data) => {
422 let mut ctx_data = this.connection_data.take().unwrap_or_default();
423 ctx_data.merge(data);
424 *this.data = Some(Arc::new(ctx_data));
425 Some(WsMessage::Text(
426 serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
427 ))
428 }
429 Err(err) => {
430 *this.close = true;
431 match this.protocol {
432 Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
433 serde_json::to_string(&ServerMessage::ConnectionError {
434 payload: Error::new(err.message),
435 })
436 .unwrap(),
437 )),
438 Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
439 }
440 }
441 }
442 });
443 }
444
445 if let Some(ping_fut) = this.ping_fut {
446 return ping_fut.poll_unpin(cx).map(|res| {
447 *this.ping_fut = None;
448 match res {
449 Ok(payload) => Some(WsMessage::Text(
450 serde_json::to_string(&ServerMessage::Pong { payload }).unwrap(),
451 )),
452 Err(err) => {
453 *this.close = true;
454 match this.protocol {
455 Protocols::SubscriptionsTransportWS => Some(WsMessage::Text(
456 serde_json::to_string(&ServerMessage::ConnectionError {
457 payload: Error::new(err.message),
458 })
459 .unwrap(),
460 )),
461 Protocols::GraphQLWS => Some(WsMessage::Close(1002, err.message)),
462 }
463 }
464 }
465 });
466 }
467
468 for (id, stream) in &mut *this.streams {
469 match Pin::new(stream).poll_next(cx) {
470 Poll::Ready(Some(payload)) => {
471 return Poll::Ready(Some(WsMessage::Text(
472 serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
473 )));
474 }
475 Poll::Ready(None) => {
476 let id = id.clone();
477 this.streams.remove(&id);
478 return Poll::Ready(Some(WsMessage::Text(
479 serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
480 )));
481 }
482 Poll::Pending => {}
483 }
484 }
485
486 Poll::Pending
487 }
488}
489
490#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
492pub enum Protocols {
493 SubscriptionsTransportWS,
495 GraphQLWS,
497}
498
499impl Protocols {
500 pub fn sec_websocket_protocol(&self) -> &'static str {
502 match self {
503 Protocols::SubscriptionsTransportWS => "graphql-ws",
504 Protocols::GraphQLWS => "graphql-transport-ws",
505 }
506 }
507
508 #[inline]
509 fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
510 match self {
511 Protocols::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
512 Protocols::GraphQLWS => ServerMessage::Next { id, payload },
513 }
514 }
515}
516
517impl std::str::FromStr for Protocols {
518 type Err = Error;
519
520 fn from_str(protocol: &str) -> Result<Self, Self::Err> {
521 if protocol.eq_ignore_ascii_case("graphql-ws") {
522 Ok(Protocols::SubscriptionsTransportWS)
523 } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
524 Ok(Protocols::GraphQLWS)
525 } else {
526 Err(Error::new(format!(
527 "Unsupported Sec-WebSocket-Protocol: {}",
528 protocol
529 )))
530 }
531 }
532}
533
/// A message received from the client, in either supported protocol.
///
/// Deserialized from a JSON object whose `type` field is the variant name in
/// snake_case. `subscribe` and `complete` are accepted as aliases for
/// `start` and `stop`.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
// Presumably allowed because `Start` carries a whole `Request` and is much
// larger than the other variants. TODO confirm whether boxing is worthwhile.
#[allow(clippy::large_enum_variant)]
pub enum ClientMessage {
    /// `connection_init`: starts the handshake. The payload (JSON `null` when
    /// absent) is passed to the connection-init callback.
    ConnectionInit {
        payload: Option<serde_json::Value>,
    },
    /// `start` / `subscribe`: begins executing `payload` as operation `id`.
    #[serde(alias = "subscribe")]
    Start {
        id: String,
        payload: Request,
    },
    /// `stop` / `complete`: cancels operation `id`.
    #[serde(alias = "complete")]
    Stop {
        id: String,
    },
    /// `connection_terminate`: ends the connection.
    ConnectionTerminate,
    /// `ping`: answered with a `pong` whose payload comes from the ping callback.
    Ping {
        payload: Option<serde_json::Value>,
    },
    /// `pong`: ignored apart from restarting the keepalive timer.
    Pong {
        payload: Option<serde_json::Value>,
    },
}
577
578impl ClientMessage {
579 pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
581 where
582 T: AsRef<[u8]>,
583 {
584 serde_json::from_slice(message.as_ref())
585 }
586}
587
/// A message sent to the client, serialized as a JSON object whose `type`
/// field is the variant name in snake_case.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
    /// `connection_error`: a fatal connection error (emitted only for
    /// subscriptions-transport-ws in this file).
    ConnectionError {
        payload: Error,
    },
    /// `connection_ack`: the connection-init callback succeeded.
    ConnectionAck,
    /// `data`: a subscription result for operation `id` (subscriptions-transport-ws).
    Data {
        id: &'a str,
        payload: Response,
    },
    /// `next`: a subscription result for operation `id` (graphql-transport-ws).
    Next {
        id: &'a str,
        payload: Response,
    },
    /// `complete`: operation `id` finished or was stopped by the client.
    Complete {
        id: &'a str,
    },
    /// `pong`: reply to a client `ping`. The `payload` key is omitted when `None`.
    Pong {
        #[serde(skip_serializing_if = "Option::is_none")]
        payload: Option<serde_json::Value>,
    },
}