fastwebsockets/
lib.rs

1// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! _fastwebsockets_ is a minimal, fast WebSocket implementation for servers and clients.
16//!
17//! [https://github.com/denoland/fastwebsockets](https://github.com/denoland/fastwebsockets)
18//!
//! Passes the _Autobahn|TestSuite_ and is fuzzed with LLVM's _libfuzzer_.
20//!
21//! You can use it as a raw websocket frame parser and deal with spec compliance yourself, or you can use it as a full-fledged websocket server.
22//!
23//! # Example
24//!
25//! ```
26//! use tokio::net::TcpStream;
27//! use fastwebsockets::{WebSocket, OpCode, Role};
28//! use anyhow::Result;
29//!
30//! async fn handle(
31//!   socket: TcpStream,
32//! ) -> Result<()> {
33//!   let mut ws = WebSocket::after_handshake(socket, Role::Server);
34//!   ws.set_writev(false);
35//!   ws.set_auto_close(true);
36//!   ws.set_auto_pong(true);
37//!
38//!   loop {
39//!     let frame = ws.read_frame().await?;
40//!     match frame.opcode {
41//!       OpCode::Close => break,
42//!       OpCode::Text | OpCode::Binary => {
43//!         ws.write_frame(frame).await?;
44//!       }
45//!       _ => {}
46//!     }
47//!   }
48//!   Ok(())
49//! }
50//! ```
51//!
52//! ## Fragmentation
53//!
//! By default, fastwebsockets hands the application raw frames exactly as they
//! arrive: a fragmented message is delivered as several frames, and only the
//! last one has FIN set. Other crates, like tungstenite, give you a single
//! message with all the frames concatenated.
57//!
//! For concatenated frames, use `FragmentCollector`:
59//! ```
60//! use fastwebsockets::{FragmentCollector, WebSocket, Role};
61//! use tokio::net::TcpStream;
62//! use anyhow::Result;
63//!
64//! async fn handle(
65//!   socket: TcpStream,
66//! ) -> Result<()> {
67//!   let mut ws = WebSocket::after_handshake(socket, Role::Server);
68//!   let mut ws = FragmentCollector::new(ws);
69//!   let incoming = ws.read_frame().await?;
70//!   // Always returns full messages
71//!   assert!(incoming.fin);
72//!   Ok(())
73//! }
74//! ```
75//!
76//! _permessage-deflate is not supported yet._
77//!
78//! ## HTTP Upgrades
79//!
80//! Enable the `upgrade` feature to do server-side upgrades and client-side
81//! handshakes.
82//!
83//! This feature is powered by [hyper](https://docs.rs/hyper).
84//!
85//! ```
86//! use fastwebsockets::upgrade::upgrade;
87//! use http_body_util::Empty;
88//! use hyper::{Request, body::{Incoming, Bytes}, Response};
89//! use anyhow::Result;
90//!
91//! async fn server_upgrade(
92//!   mut req: Request<Incoming>,
93//! ) -> Result<Response<Empty<Bytes>>> {
94//!   let (response, fut) = upgrade(&mut req)?;
95//!
96//!   tokio::spawn(async move {
97//!     let ws = fut.await;
98//!     // Do something with the websocket
99//!   });
100//!
101//!   Ok(response)
102//! }
103//! ```
104//!
105//! Use the `handshake` module for client-side handshakes.
106//!
107//! ```
108//! use fastwebsockets::handshake;
109//! use fastwebsockets::FragmentCollector;
110//! use hyper::{Request, body::Bytes, upgrade::Upgraded, header::{UPGRADE, CONNECTION}};
111//! use http_body_util::Empty;
112//! use hyper_util::rt::TokioIo;
113//! use tokio::net::TcpStream;
114//! use std::future::Future;
115//! use anyhow::Result;
116//!
117//! async fn connect() -> Result<FragmentCollector<TokioIo<Upgraded>>> {
118//!   let stream = TcpStream::connect("localhost:9001").await?;
119//!
120//!   let req = Request::builder()
121//!     .method("GET")
122//!     .uri("http://localhost:9001/")
123//!     .header("Host", "localhost:9001")
124//!     .header(UPGRADE, "websocket")
125//!     .header(CONNECTION, "upgrade")
126//!     .header(
127//!       "Sec-WebSocket-Key",
128//!       fastwebsockets::handshake::generate_key(),
129//!     )
130//!     .header("Sec-WebSocket-Version", "13")
131//!     .body(Empty::<Bytes>::new())?;
132//!
133//!   let (ws, _) = handshake::client(&SpawnExecutor, req, stream).await?;
134//!   Ok(FragmentCollector::new(ws))
135//! }
136//!
137//! // Tie hyper's executor to tokio runtime
138//! struct SpawnExecutor;
139//!
140//! impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
141//! where
142//!   Fut: Future + Send + 'static,
143//!   Fut::Output: Send + 'static,
144//! {
145//!   fn execute(&self, fut: Fut) {
146//!     tokio::task::spawn(fut);
147//!   }
148//! }
149//! ```
150
151#![cfg_attr(docsrs, feature(doc_cfg))]
152
153mod close;
154mod error;
155mod fragment;
156mod frame;
157/// Client handshake.
158#[cfg(feature = "upgrade")]
159#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
160pub mod handshake;
161mod mask;
162/// HTTP upgrades.
163#[cfg(feature = "upgrade")]
164#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
165pub mod upgrade;
166
167use bytes::Buf;
168
169use bytes::BytesMut;
170#[cfg(feature = "unstable-split")]
171use std::future::Future;
172
173use tokio::io::AsyncRead;
174use tokio::io::AsyncReadExt;
175use tokio::io::AsyncWrite;
176use tokio::io::AsyncWriteExt;
177
178pub use crate::close::CloseCode;
179pub use crate::error::WebSocketError;
180pub use crate::fragment::FragmentCollector;
181#[cfg(feature = "unstable-split")]
182pub use crate::fragment::FragmentCollectorRead;
183pub use crate::frame::Frame;
184pub use crate::frame::OpCode;
185pub use crate::frame::Payload;
186pub use crate::mask::unmask;
187
/// The side of the connection this endpoint plays.
///
/// Servers unmask incoming frames and clients mask outgoing frames (when
/// automatic masking is enabled), as required by RFC 6455.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Role {
  /// The accepting side of the connection.
  Server,
  /// The connecting side of the connection.
  Client,
}
193
194pub(crate) struct WriteHalf {
195  role: Role,
196  closed: bool,
197  vectored: bool,
198  auto_apply_mask: bool,
199  writev_threshold: usize,
200  write_buffer: Vec<u8>,
201}
202
203pub(crate) struct ReadHalf {
204  role: Role,
205  auto_apply_mask: bool,
206  auto_close: bool,
207  auto_pong: bool,
208  writev_threshold: usize,
209  max_message_size: usize,
210  buffer: BytesMut,
211}
212
#[cfg(feature = "unstable-split")]
/// The receiving half of a split WebSocket connection.
///
/// Created by `WebSocket::split` or [`after_handshake_split`].
pub struct WebSocketRead<S> {
  /// The readable side of the underlying transport.
  stream: S,
  /// Receive-side protocol state.
  read_half: ReadHalf,
}
218
#[cfg(feature = "unstable-split")]
/// The sending half of a split WebSocket connection.
///
/// Created by `WebSocket::split` or [`after_handshake_split`].
pub struct WebSocketWrite<S> {
  /// The writable side of the underlying transport.
  stream: S,
  /// Send-side protocol state.
  write_half: WriteHalf,
}
224
#[cfg(feature = "unstable-split")]
/// Builds a `WebSocketRead`/`WebSocketWrite` pair from separate read and write
/// streams whose WebSocket handshake has already been completed.
///
/// Both halves start out with the default settings for `role`.
pub fn after_handshake_split<R, W>(
  read: R,
  write: W,
  role: Role,
) -> (WebSocketRead<R>, WebSocketWrite<W>)
where
  R: AsyncRead + Unpin,
  W: AsyncWrite + Unpin,
{
  let read_half = ReadHalf::after_handshake(role);
  let write_half = WriteHalf::after_handshake(role);

  let reader = WebSocketRead {
    stream: read,
    read_half,
  };
  let writer = WebSocketWrite {
    stream: write,
    write_half,
  };

  (reader, writer)
}
247
#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketRead<S> {
  /// Consumes the `WebSocketRead` and returns the underlying stream together
  /// with its read-side state.
  #[inline]
  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
    (self.stream, self.read_half)
  }

  /// Sets the writev threshold stored in the read-side state.
  ///
  /// NOTE(review): the read path in this file never consults this value; the
  /// threshold that actually controls vectored writes is set with
  /// `WebSocketWrite::set_writev_threshold`.
  pub fn set_writev_threshold(&mut self, threshold: usize) {
    self.read_half.writev_threshold = threshold;
  }

  /// Sets whether a received Close frame is validated and answered
  /// automatically. The reply is handed to `send_fn` in
  /// [`read_frame`](Self::read_frame). When set to `false`, the application
  /// has to send close frames itself.
  ///
  /// Default: `true`
  pub fn set_auto_close(&mut self, auto_close: bool) {
    self.read_half.auto_close = auto_close;
  }

  /// Sets whether a received Ping frame is answered automatically with a
  /// Pong. Answered pings are not returned from
  /// [`read_frame`](Self::read_frame); the Pong is handed to `send_fn`.
  ///
  /// Default: `true`
  pub fn set_auto_pong(&mut self, auto_pong: bool) {
    self.read_half.auto_pong = auto_pong;
  }

  /// Sets the maximum payload size, in bytes, accepted for a single incoming
  /// frame. Larger frames make [`read_frame`](Self::read_frame) fail with
  /// [`WebSocketError::FrameTooLarge`].
  ///
  /// Default: 64 MiB
  pub fn set_max_message_size(&mut self, max_message_size: usize) {
    self.read_half.max_message_size = max_message_size;
  }

  /// Sets whether incoming frames are unmasked automatically when acting as
  /// a server.
  ///
  /// Default: `true`
  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
    self.read_half.auto_apply_mask = auto_apply_mask;
  }

  /// Reads the next frame from the stream.
  ///
  /// Frames that the protocol obliges this side to send in reply are passed
  /// to `send_fn`. These are a Pong for a Ping when auto-pong is on, and a
  /// Close reply when auto-close is on. `send_fn` should forward them to the
  /// matching write half. Pings answered this way are not returned; reading
  /// continues with the next frame.
  ///
  /// # Errors
  ///
  /// Propagates any protocol or I/O error from reading. Returns
  /// [`WebSocketError::SendError`] if `send_fn` fails.
  pub async fn read_frame<R, E>(
    &mut self,
    send_fn: &mut impl FnMut(Frame<'f>) -> R,
  ) -> Result<Frame, WebSocketError>
  where
    S: AsyncRead + Unpin,
    E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    R: Future<Output = Result<(), E>>,
  {
    loop {
      let (res, obligated_send) =
        self.read_half.read_frame_inner(&mut self.stream).await;
      // Deliver the mandatory reply first, even if the read itself failed
      // (e.g. a Close with a disallowed code yields an error plus a reply).
      if let Some(frame) = obligated_send {
        let res = send_fn(frame).await;
        res.map_err(|e| WebSocketError::SendError(e.into()))?;
      }
      if let Some(frame) = res? {
        break Ok(frame);
      }
    }
  }
}
311
#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketWrite<S> {
  /// Enables or disables vectored writes. They are only a preference: payloads
  /// at or below the writev threshold are always encoded into a buffer first.
  ///
  /// Default: `true`
  pub fn set_writev(&mut self, vectored: bool) {
    self.write_half.vectored = vectored;
  }

  /// Sets the payload size, in bytes, above which vectored writes are used
  /// (when enabled).
  ///
  /// Default: `1024`
  pub fn set_writev_threshold(&mut self, threshold: usize) {
    self.write_half.writev_threshold = threshold;
  }

  /// Enables or disables automatic masking of outgoing frames when acting as
  /// a client.
  ///
  /// Default: `true`
  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
    self.write_half.auto_apply_mask = auto_apply_mask;
  }

  /// Returns `true` once a Close frame has been written through this half.
  pub fn is_closed(&self) -> bool {
    self.write_half.closed
  }

  /// Encodes `frame` and writes it to the underlying stream.
  ///
  /// # Errors
  ///
  /// Fails with [`WebSocketError::ConnectionClosed`] when a Close frame was
  /// already sent and `frame` is not a Close frame, or with an I/O error.
  pub async fn write_frame(
    &mut self,
    frame: Frame<'f>,
  ) -> Result<(), WebSocketError>
  where
    S: AsyncWrite + Unpin,
  {
    let Self { stream, write_half } = self;
    write_half.write_frame(stream, frame).await
  }

  /// Flushes the underlying stream so buffered frames reach the peer.
  pub async fn flush(&mut self) -> Result<(), WebSocketError>
  where
    S: AsyncWrite + Unpin,
  {
    let stream = &mut self.stream;
    flush(stream).await
  }
}
353
354#[inline]
355async fn flush<S>(stream: &mut S) -> Result<(), WebSocketError>
356where
357  S: AsyncWrite + Unpin,
358{
359  stream.flush().await.map_err(WebSocketError::IoError)
360}
361
362/// WebSocket protocol implementation over an async stream.
363pub struct WebSocket<S> {
364  stream: S,
365  write_half: WriteHalf,
366  read_half: ReadHalf,
367}
368
369impl<'f, S> WebSocket<S> {
370  /// Creates a new `WebSocket` from a stream that has already completed the WebSocket handshake.
371  ///
372  /// Use the `upgrade` feature to handle server upgrades and client handshakes.
373  ///
374  /// # Example
375  ///
376  /// ```
377  /// use tokio::net::TcpStream;
378  /// use fastwebsockets::{WebSocket, OpCode, Role};
379  /// use anyhow::Result;
380  ///
381  /// async fn handle_client(
382  ///   socket: TcpStream,
383  /// ) -> Result<()> {
384  ///   let mut ws = WebSocket::after_handshake(socket, Role::Server);
385  ///   // ...
386  ///   Ok(())
387  /// }
388  /// ```
389  pub fn after_handshake(stream: S, role: Role) -> Self
390  where
391    S: AsyncRead + AsyncWrite + Unpin,
392  {
393    Self {
394      stream,
395      write_half: WriteHalf::after_handshake(role),
396      read_half: ReadHalf::after_handshake(role),
397    }
398  }
399
400  /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not
401  /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that
402  /// is returned.
403  #[cfg(feature = "unstable-split")]
404  pub fn split<R, W>(
405    self,
406    split_fn: impl Fn(S) -> (R, W),
407  ) -> (WebSocketRead<R>, WebSocketWrite<W>)
408  where
409    S: AsyncRead + AsyncWrite + Unpin,
410    R: AsyncRead + Unpin,
411    W: AsyncWrite + Unpin,
412  {
413    let (stream, read, write) = self.into_parts_internal();
414    let (r, w) = split_fn(stream);
415    (
416      WebSocketRead {
417        stream: r,
418        read_half: read,
419      },
420      WebSocketWrite {
421        stream: w,
422        write_half: write,
423      },
424    )
425  }
426
427  /// Consumes the `WebSocket` and returns the underlying stream.
428  #[inline]
429  pub fn into_inner(self) -> S {
430    // self.write_half.into_inner().stream
431    self.stream
432  }
433
434  /// Consumes the `WebSocket` and returns the underlying stream.
435  #[inline]
436  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) {
437    (self.stream, self.read_half, self.write_half)
438  }
439
440  /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used.
441  ///
442  /// Default: `true`
443  pub fn set_writev(&mut self, vectored: bool) {
444    self.write_half.vectored = vectored;
445  }
446
447  pub fn set_writev_threshold(&mut self, threshold: usize) {
448    self.read_half.writev_threshold = threshold;
449    self.write_half.writev_threshold = threshold;
450  }
451
452  /// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames.
453  ///
454  /// Default: `true`
455  pub fn set_auto_close(&mut self, auto_close: bool) {
456    self.read_half.auto_close = auto_close;
457  }
458
459  /// Sets whether to automatically send a pong frame when a ping frame is received.
460  ///
461  /// Default: `true`
462  pub fn set_auto_pong(&mut self, auto_pong: bool) {
463    self.read_half.auto_pong = auto_pong;
464  }
465
466  /// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed.
467  ///
468  /// Default: 64 MiB
469  pub fn set_max_message_size(&mut self, max_message_size: usize) {
470    self.read_half.max_message_size = max_message_size;
471  }
472
473  /// Sets whether to automatically apply the mask to the frame payload.
474  ///
475  /// Default: `true`
476  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
477    self.read_half.auto_apply_mask = auto_apply_mask;
478    self.write_half.auto_apply_mask = auto_apply_mask;
479  }
480
481  pub fn is_closed(&self) -> bool {
482    self.write_half.closed
483  }
484
485  /// Writes a frame to the stream.
486  ///
487  /// # Example
488  ///
489  /// ```
490  /// use fastwebsockets::{WebSocket, Frame, OpCode};
491  /// use tokio::net::TcpStream;
492  /// use anyhow::Result;
493  ///
494  /// async fn send(
495  ///   ws: &mut WebSocket<TcpStream>
496  /// ) -> Result<()> {
497  ///   let mut frame = Frame::binary(vec![0x01, 0x02, 0x03].into());
498  ///   ws.write_frame(frame).await?;
499  ///   Ok(())
500  /// }
501  /// ```
502  pub async fn write_frame(
503    &mut self,
504    frame: Frame<'f>,
505  ) -> Result<(), WebSocketError>
506  where
507    S: AsyncRead + AsyncWrite + Unpin,
508  {
509    self.write_half.write_frame(&mut self.stream, frame).await?;
510    Ok(())
511  }
512
513  /// Flushes the data from the underlying stream.
514  ///
515  /// if the underlying stream is buffered (i.e: TlsStream<TcpStream>), it is needed to call flush
516  /// to be sure that the written frame are correctly pushed down to the bottom stream/channel.
517  ///
518  pub async fn flush(&mut self) -> Result<(), WebSocketError>
519  where
520    S: AsyncWrite + Unpin,
521  {
522    flush(&mut self.stream).await
523  }
524
525  /// Reads a frame from the stream.
526  ///
527  /// This method will unmask the frame payload. For fragmented frames, use `FragmentCollector::read_frame`.
528  ///
529  /// Text frames payload is guaranteed to be valid UTF-8.
530  ///
531  /// # Example
532  ///
533  /// ```
534  /// use fastwebsockets::{OpCode, WebSocket, Frame};
535  /// use tokio::net::TcpStream;
536  /// use anyhow::Result;
537  ///
538  /// async fn echo(
539  ///   ws: &mut WebSocket<TcpStream>
540  /// ) -> Result<()> {
541  ///   let frame = ws.read_frame().await?;
542  ///   match frame.opcode {
543  ///     OpCode::Text | OpCode::Binary => {
544  ///       ws.write_frame(frame).await?;
545  ///     }
546  ///     _ => {}
547  ///   }
548  ///   Ok(())
549  /// }
550  /// ```
551  pub async fn read_frame(&mut self) -> Result<Frame<'f>, WebSocketError>
552  where
553    S: AsyncRead + AsyncWrite + Unpin,
554  {
555    loop {
556      let (res, obligated_send) =
557        self.read_half.read_frame_inner(&mut self.stream).await;
558      let is_closed = self.write_half.closed;
559      if let Some(frame) = obligated_send {
560        if !is_closed {
561          self.write_half.write_frame(&mut self.stream, frame).await?;
562        }
563      }
564      if let Some(frame) = res? {
565        if is_closed && frame.opcode != OpCode::Close {
566          return Err(WebSocketError::ConnectionClosed);
567        }
568        break Ok(frame);
569      }
570    }
571  }
572}
573
/// Largest possible frame header in bytes. It is made of 2 fixed bytes, up to
/// 8 bytes of extended payload length, and a 4-byte masking key.
const MAX_HEADER_SIZE: usize = 2 + 8 + 4;
575
576impl ReadHalf {
577  pub fn after_handshake(role: Role) -> Self {
578    let buffer = BytesMut::with_capacity(8192);
579
580    Self {
581      role,
582      auto_apply_mask: true,
583      auto_close: true,
584      auto_pong: true,
585      writev_threshold: 1024,
586      max_message_size: 64 << 20,
587      buffer,
588    }
589  }
590
591  /// Attempt to read a single frame from from the incoming stream, returning any send obligations if
592  /// `auto_close` or `auto_pong` are enabled. Callers to this function are obligated to send the
593  /// frame in the latter half of the tuple if one is specified, unless the write half of this socket
594  /// has been closed.
595  ///
596  /// XXX: Do not expose this method to the public API.
597  pub(crate) async fn read_frame_inner<'f, S>(
598    &mut self,
599    stream: &mut S,
600  ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
601  where
602    S: AsyncRead + Unpin,
603  {
604    let mut frame = match self.parse_frame_header(stream).await {
605      Ok(frame) => frame,
606      Err(e) => return (Err(e), None),
607    };
608
609    if self.role == Role::Server && self.auto_apply_mask {
610      frame.unmask()
611    };
612
613    match frame.opcode {
614      OpCode::Close if self.auto_close => {
615        match frame.payload.len() {
616          0 => {}
617          1 => return (Err(WebSocketError::InvalidCloseFrame), None),
618          _ => {
619            let code = close::CloseCode::from(u16::from_be_bytes(
620              frame.payload[0..2].try_into().unwrap(),
621            ));
622
623            #[cfg(feature = "simd")]
624            if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() {
625              return (Err(WebSocketError::InvalidUTF8), None);
626            };
627
628            #[cfg(not(feature = "simd"))]
629            if std::str::from_utf8(&frame.payload[2..]).is_err() {
630              return (Err(WebSocketError::InvalidUTF8), None);
631            };
632
633            if !code.is_allowed() {
634              return (
635                Err(WebSocketError::InvalidCloseCode),
636                Some(Frame::close(1002, &frame.payload[2..])),
637              );
638            }
639          }
640        };
641
642        let obligated_send = Frame::close_raw(frame.payload.to_owned().into());
643        (Ok(Some(frame)), Some(obligated_send))
644      }
645      OpCode::Ping if self.auto_pong => {
646        (Ok(None), Some(Frame::pong(frame.payload)))
647      }
648      OpCode::Text => {
649        if frame.fin && !frame.is_utf8() {
650          (Err(WebSocketError::InvalidUTF8), None)
651        } else {
652          (Ok(Some(frame)), None)
653        }
654      }
655      _ => (Ok(Some(frame)), None),
656    }
657  }
658
659  async fn parse_frame_header<'a, S>(
660    &mut self,
661    stream: &mut S,
662  ) -> Result<Frame<'a>, WebSocketError>
663  where
664    S: AsyncRead + Unpin,
665  {
666    macro_rules! eof {
667      ($n:expr) => {{
668        if $n == 0 {
669          return Err(WebSocketError::UnexpectedEOF);
670        }
671      }};
672    }
673
674    // Read the first two bytes
675    while self.buffer.remaining() < 2 {
676      eof!(stream.read_buf(&mut self.buffer).await?);
677    }
678
679    let fin = self.buffer[0] & 0b10000000 != 0;
680    let rsv1 = self.buffer[0] & 0b01000000 != 0;
681    let rsv2 = self.buffer[0] & 0b00100000 != 0;
682    let rsv3 = self.buffer[0] & 0b00010000 != 0;
683
684    if rsv1 || rsv2 || rsv3 {
685      return Err(WebSocketError::ReservedBitsNotZero);
686    }
687
688    let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?;
689    let masked = self.buffer[1] & 0b10000000 != 0;
690
691    let length_code = self.buffer[1] & 0x7F;
692    let extra = match length_code {
693      126 => 2,
694      127 => 8,
695      _ => 0,
696    };
697
698    self.buffer.advance(2);
699    while self.buffer.remaining() < extra + masked as usize * 4 {
700      eof!(stream.read_buf(&mut self.buffer).await?);
701    }
702
703    let payload_len: usize = match extra {
704      0 => usize::from(length_code),
705      2 => self.buffer.get_u16() as usize,
706      #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
707      8 => self.buffer.get_u64() as usize,
708      // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing
709      #[cfg(any(
710        target_pointer_width = "8",
711        target_pointer_width = "16",
712        target_pointer_width = "32"
713      ))]
714      8 => match usize::try_from(self.buffer.get_u64()) {
715        Ok(length) => length,
716        Err(_) => return Err(WebSocketError::FrameTooLarge),
717      },
718      _ => unreachable!(),
719    };
720
721    let mask = if masked {
722      Some(self.buffer.get_u32().to_be_bytes())
723    } else {
724      None
725    };
726
727    if frame::is_control(opcode) && !fin {
728      return Err(WebSocketError::ControlFrameFragmented);
729    }
730
731    if opcode == OpCode::Ping && payload_len > 125 {
732      return Err(WebSocketError::PingFrameTooLarge);
733    }
734
735    if payload_len >= self.max_message_size {
736      return Err(WebSocketError::FrameTooLarge);
737    }
738
739    // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time
740    self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
741    while payload_len > self.buffer.remaining() {
742      eof!(stream.read_buf(&mut self.buffer).await?);
743    }
744
745    // if we read too much it will stay in the buffer, for the next call to this method
746    let payload = self.buffer.split_to(payload_len);
747    let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload));
748    Ok(frame)
749  }
750}
751
752impl WriteHalf {
753  pub fn after_handshake(role: Role) -> Self {
754    Self {
755      role,
756      closed: false,
757      auto_apply_mask: true,
758      vectored: true,
759      writev_threshold: 1024,
760      write_buffer: Vec::with_capacity(2),
761    }
762  }
763
764  /// Writes a frame to the provided stream.
765  pub async fn write_frame<'a, S>(
766    &'a mut self,
767    stream: &mut S,
768    mut frame: Frame<'a>,
769  ) -> Result<(), WebSocketError>
770  where
771    S: AsyncWrite + Unpin,
772  {
773    if self.role == Role::Client && self.auto_apply_mask {
774      frame.mask();
775    }
776
777    if frame.opcode == OpCode::Close {
778      self.closed = true;
779    } else if self.closed {
780      return Err(WebSocketError::ConnectionClosed);
781    }
782
783    if self.vectored && frame.payload.len() > self.writev_threshold {
784      frame.writev(stream).await?;
785    } else {
786      let text = frame.write(&mut self.write_buffer);
787      stream.write_all(text).await?;
788    }
789
790    Ok(())
791  }
792}
793
#[cfg(test)]
mod tests {
  use super::*;

  // Compile-time probe intended to assert that `WebSocket<TcpStream>` does
  // NOT implement `Sync`. It uses the "ambiguous impl" trick popularised by
  // the `static_assertions` crate.
  //
  // NOTE(review): the probe lives inside a generic `const fn`, so it is
  // type-checked once for an unbounded `S`. There `S: Sync` cannot be proven
  // and only the `()` impl applies. The call below with a concrete type does
  // not re-run type checking, so this likely compiles whether or not the type
  // is `Sync`. The fields visible in this file (`TcpStream`, `Vec<u8>`,
  // `BytesMut`, scalars) also suggest `WebSocket<TcpStream>` *is* `Sync`.
  // Confirm the intent before relying on this check.
  const _: () = {
    const fn assert_unsync<S>() {
      // Generic trait with a blanket impl over `()` for all types.
      trait AmbiguousIfImpl<A> {
        // Required for actually being able to reference the trait.
        fn some_item() {}
      }

      impl<T: ?Sized> AmbiguousIfImpl<()> for T {}

      // Marker for a second blanket impl that only applies to `Sync` types.
      #[allow(dead_code)]
      struct Invalid;

      impl<T: ?Sized + Sync> AmbiguousIfImpl<Invalid> for T {}

      // If only the `()` impl applies, type inference can resolve `_` and
      // this compiles. If `S` also implements `AmbiguousIfImpl<Invalid>`
      // (i.e. `S: Sync`), `_` is ambiguous and compilation should fail.
      let _ = <S as AmbiguousIfImpl<_>>::some_item;
    }
    assert_unsync::<WebSocket<tokio::net::TcpStream>>();
  };
}