1#![cfg_attr(docsrs, feature(doc_cfg))]
152
153mod close;
154mod error;
155mod fragment;
156mod frame;
157#[cfg(feature = "upgrade")]
159#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
160pub mod handshake;
161mod mask;
162#[cfg(feature = "upgrade")]
164#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
165pub mod upgrade;
166
167use bytes::Buf;
168
169use bytes::BytesMut;
170#[cfg(feature = "unstable-split")]
171use std::future::Future;
172
173use tokio::io::AsyncRead;
174use tokio::io::AsyncReadExt;
175use tokio::io::AsyncWrite;
176use tokio::io::AsyncWriteExt;
177
178pub use crate::close::CloseCode;
179pub use crate::error::WebSocketError;
180pub use crate::fragment::FragmentCollector;
181#[cfg(feature = "unstable-split")]
182pub use crate::fragment::FragmentCollectorRead;
183pub use crate::frame::Frame;
184pub use crate::frame::OpCode;
185pub use crate::frame::Payload;
186pub use crate::mask::unmask;
187
/// Which side of the connection this endpoint plays.
///
/// The role selects the masking direction: a [`Role::Client`] writer masks
/// outgoing frames and a [`Role::Server`] reader unmasks incoming frames
/// (in both cases only while automatic masking is enabled).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Role {
  /// Server endpoint: incoming payloads are unmasked automatically.
  Server,
  /// Client endpoint: outgoing payloads are masked automatically.
  Client,
}
193
194pub(crate) struct WriteHalf {
195 role: Role,
196 closed: bool,
197 vectored: bool,
198 auto_apply_mask: bool,
199 writev_threshold: usize,
200 write_buffer: Vec<u8>,
201}
202
203pub(crate) struct ReadHalf {
204 role: Role,
205 auto_apply_mask: bool,
206 auto_close: bool,
207 auto_pong: bool,
208 writev_threshold: usize,
209 max_message_size: usize,
210 buffer: BytesMut,
211}
212
#[cfg(feature = "unstable-split")]
/// The reading half of a split connection, created by `WebSocket::split`
/// or `after_handshake_split`.
pub struct WebSocketRead<S> {
  /// Readable part of the transport.
  stream: S,
  /// Reader-side protocol state (buffer, auto-close/auto-pong, size limit).
  read_half: ReadHalf,
}
218
#[cfg(feature = "unstable-split")]
/// The writing half of a split connection, created by `WebSocket::split`
/// or `after_handshake_split`.
pub struct WebSocketWrite<S> {
  /// Writable part of the transport.
  stream: S,
  /// Writer-side protocol state (masking, closed flag, writev settings).
  write_half: WriteHalf,
}
224
#[cfg(feature = "unstable-split")]
/// Builds split read/write halves directly over two transports, for use
/// after the opening handshake has completed.
///
/// Both halves start from their post-handshake defaults for `role`.
pub fn after_handshake_split<R, W>(
  read: R,
  write: W,
  role: Role,
) -> (WebSocketRead<R>, WebSocketWrite<W>)
where
  R: AsyncRead + Unpin,
  W: AsyncWrite + Unpin,
{
  let reader_state = ReadHalf::after_handshake(role);
  let writer_state = WriteHalf::after_handshake(role);

  let reader = WebSocketRead {
    stream: read,
    read_half: reader_state,
  };
  let writer = WebSocketWrite {
    stream: write,
    write_half: writer_state,
  };

  (reader, writer)
}
247
#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketRead<S> {
  /// Consumes this half, returning the transport and its reader state.
  #[inline]
  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
    let Self { stream, read_half } = self;
    (stream, read_half)
  }

  /// Stores `threshold` as the reader's writev threshold.
  ///
  /// NOTE(review): no reader code in this file consults this value; it
  /// appears to exist for parity with `WebSocketWrite::set_writev_threshold`.
  pub fn set_writev_threshold(&mut self, threshold: usize) {
    self.read_half.writev_threshold = threshold;
  }

  /// Enables or disables automatic validation and echoing of Close frames.
  pub fn set_auto_close(&mut self, auto_close: bool) {
    self.read_half.auto_close = auto_close;
  }

  /// Enables or disables automatic Pong replies to incoming Pings.
  pub fn set_auto_pong(&mut self, auto_pong: bool) {
    self.read_half.auto_pong = auto_pong;
  }

  /// Sets the per-frame payload size limit enforced while parsing.
  pub fn set_max_message_size(&mut self, max_message_size: usize) {
    self.read_half.max_message_size = max_message_size;
  }

  /// Controls whether a server-role reader unmasks payloads automatically.
  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
    self.read_half.auto_apply_mask = auto_apply_mask;
  }

  /// Reads frames until one is available for the caller.
  ///
  /// Any reply the protocol requires (a Pong, the echo of a received Close,
  /// or a 1002 Close for a disallowed close code) is passed to `send_fn`
  /// before the read result is examined. The caller is responsible for
  /// actually transmitting it, typically via the matching write half.
  /// Pings answered by auto-pong are not returned.
  ///
  /// # Errors
  /// Protocol and I/O errors from reading, or `WebSocketError::SendError`
  /// wrapping the error returned by `send_fn`.
  pub async fn read_frame<R, E>(
    &mut self,
    send_fn: &mut impl FnMut(Frame<'f>) -> R,
  ) -> Result<Frame, WebSocketError>
  where
    S: AsyncRead + Unpin,
    E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    R: Future<Output = Result<(), E>>,
  {
    loop {
      let (outcome, reply) =
        self.read_half.read_frame_inner(&mut self.stream).await;

      // Hand the mandatory reply off first, even when `outcome` is an error.
      if let Some(reply) = reply {
        send_fn(reply)
          .await
          .map_err(|err| WebSocketError::SendError(err.into()))?;
      }

      if let Some(frame) = outcome? {
        return Ok(frame);
      }
    }
  }
}
311
#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketWrite<S> {
  /// Enables or disables the vectored write path for large payloads.
  pub fn set_writev(&mut self, vectored: bool) {
    self.write_half.vectored = vectored;
  }

  /// Sets the payload length above which vectored writes are used.
  pub fn set_writev_threshold(&mut self, threshold: usize) {
    self.write_half.writev_threshold = threshold;
  }

  /// Controls whether a client-role writer masks frames automatically.
  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
    self.write_half.auto_apply_mask = auto_apply_mask;
  }

  /// Returns `true` once a Close frame has been submitted through this half.
  pub fn is_closed(&self) -> bool {
    let state = &self.write_half;
    state.closed
  }

  /// Writes a single frame.
  ///
  /// No explicit flush is issued here; call [`flush`](Self::flush) if the
  /// transport buffers writes.
  ///
  /// # Errors
  /// `WebSocketError::ConnectionClosed` for a non-Close frame after a Close
  /// was submitted, or transport errors.
  pub async fn write_frame(
    &mut self,
    frame: Frame<'f>,
  ) -> Result<(), WebSocketError>
  where
    S: AsyncWrite + Unpin,
  {
    let Self { stream, write_half } = self;
    write_half.write_frame(stream, frame).await
  }

  /// Flushes the underlying transport.
  pub async fn flush(&mut self) -> Result<(), WebSocketError>
  where
    S: AsyncWrite + Unpin,
  {
    let transport = &mut self.stream;
    flush(transport).await
  }
}
353
354#[inline]
355async fn flush<S>(stream: &mut S) -> Result<(), WebSocketError>
356where
357 S: AsyncWrite + Unpin,
358{
359 stream.flush().await.map_err(WebSocketError::IoError)
360}
361
/// A WebSocket connection over a single transport `S` that is both read
/// from and written to.
pub struct WebSocket<S> {
  /// The underlying transport.
  stream: S,
  /// Writer-side state (masking, closed flag, writev settings).
  write_half: WriteHalf,
  /// Reader-side state (read buffer, auto-close/auto-pong, size limit).
  read_half: ReadHalf,
}
368
369impl<'f, S> WebSocket<S> {
370 pub fn after_handshake(stream: S, role: Role) -> Self
390 where
391 S: AsyncRead + AsyncWrite + Unpin,
392 {
393 Self {
394 stream,
395 write_half: WriteHalf::after_handshake(role),
396 read_half: ReadHalf::after_handshake(role),
397 }
398 }
399
400 #[cfg(feature = "unstable-split")]
404 pub fn split<R, W>(
405 self,
406 split_fn: impl Fn(S) -> (R, W),
407 ) -> (WebSocketRead<R>, WebSocketWrite<W>)
408 where
409 S: AsyncRead + AsyncWrite + Unpin,
410 R: AsyncRead + Unpin,
411 W: AsyncWrite + Unpin,
412 {
413 let (stream, read, write) = self.into_parts_internal();
414 let (r, w) = split_fn(stream);
415 (
416 WebSocketRead {
417 stream: r,
418 read_half: read,
419 },
420 WebSocketWrite {
421 stream: w,
422 write_half: write,
423 },
424 )
425 }
426
427 #[inline]
429 pub fn into_inner(self) -> S {
430 self.stream
432 }
433
434 #[inline]
436 pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) {
437 (self.stream, self.read_half, self.write_half)
438 }
439
440 pub fn set_writev(&mut self, vectored: bool) {
444 self.write_half.vectored = vectored;
445 }
446
447 pub fn set_writev_threshold(&mut self, threshold: usize) {
448 self.read_half.writev_threshold = threshold;
449 self.write_half.writev_threshold = threshold;
450 }
451
452 pub fn set_auto_close(&mut self, auto_close: bool) {
456 self.read_half.auto_close = auto_close;
457 }
458
459 pub fn set_auto_pong(&mut self, auto_pong: bool) {
463 self.read_half.auto_pong = auto_pong;
464 }
465
466 pub fn set_max_message_size(&mut self, max_message_size: usize) {
470 self.read_half.max_message_size = max_message_size;
471 }
472
473 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
477 self.read_half.auto_apply_mask = auto_apply_mask;
478 self.write_half.auto_apply_mask = auto_apply_mask;
479 }
480
481 pub fn is_closed(&self) -> bool {
482 self.write_half.closed
483 }
484
485 pub async fn write_frame(
503 &mut self,
504 frame: Frame<'f>,
505 ) -> Result<(), WebSocketError>
506 where
507 S: AsyncRead + AsyncWrite + Unpin,
508 {
509 self.write_half.write_frame(&mut self.stream, frame).await?;
510 Ok(())
511 }
512
513 pub async fn flush(&mut self) -> Result<(), WebSocketError>
519 where
520 S: AsyncWrite + Unpin,
521 {
522 flush(&mut self.stream).await
523 }
524
525 pub async fn read_frame(&mut self) -> Result<Frame<'f>, WebSocketError>
552 where
553 S: AsyncRead + AsyncWrite + Unpin,
554 {
555 loop {
556 let (res, obligated_send) =
557 self.read_half.read_frame_inner(&mut self.stream).await;
558 let is_closed = self.write_half.closed;
559 if let Some(frame) = obligated_send {
560 if !is_closed {
561 self.write_half.write_frame(&mut self.stream, frame).await?;
562 }
563 }
564 if let Some(frame) = res? {
565 if is_closed && frame.opcode != OpCode::Close {
566 return Err(WebSocketError::ConnectionClosed);
567 }
568 break Ok(frame);
569 }
570 }
571 }
572}
573
/// Largest possible frame header: 2 fixed bytes, up to 8 bytes of extended
/// payload length, and a 4-byte masking key.
const MAX_HEADER_SIZE: usize = 2 + 8 + 4;
575
576impl ReadHalf {
577 pub fn after_handshake(role: Role) -> Self {
578 let buffer = BytesMut::with_capacity(8192);
579
580 Self {
581 role,
582 auto_apply_mask: true,
583 auto_close: true,
584 auto_pong: true,
585 writev_threshold: 1024,
586 max_message_size: 64 << 20,
587 buffer,
588 }
589 }
590
591 pub(crate) async fn read_frame_inner<'f, S>(
598 &mut self,
599 stream: &mut S,
600 ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
601 where
602 S: AsyncRead + Unpin,
603 {
604 let mut frame = match self.parse_frame_header(stream).await {
605 Ok(frame) => frame,
606 Err(e) => return (Err(e), None),
607 };
608
609 if self.role == Role::Server && self.auto_apply_mask {
610 frame.unmask()
611 };
612
613 match frame.opcode {
614 OpCode::Close if self.auto_close => {
615 match frame.payload.len() {
616 0 => {}
617 1 => return (Err(WebSocketError::InvalidCloseFrame), None),
618 _ => {
619 let code = close::CloseCode::from(u16::from_be_bytes(
620 frame.payload[0..2].try_into().unwrap(),
621 ));
622
623 #[cfg(feature = "simd")]
624 if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() {
625 return (Err(WebSocketError::InvalidUTF8), None);
626 };
627
628 #[cfg(not(feature = "simd"))]
629 if std::str::from_utf8(&frame.payload[2..]).is_err() {
630 return (Err(WebSocketError::InvalidUTF8), None);
631 };
632
633 if !code.is_allowed() {
634 return (
635 Err(WebSocketError::InvalidCloseCode),
636 Some(Frame::close(1002, &frame.payload[2..])),
637 );
638 }
639 }
640 };
641
642 let obligated_send = Frame::close_raw(frame.payload.to_owned().into());
643 (Ok(Some(frame)), Some(obligated_send))
644 }
645 OpCode::Ping if self.auto_pong => {
646 (Ok(None), Some(Frame::pong(frame.payload)))
647 }
648 OpCode::Text => {
649 if frame.fin && !frame.is_utf8() {
650 (Err(WebSocketError::InvalidUTF8), None)
651 } else {
652 (Ok(Some(frame)), None)
653 }
654 }
655 _ => (Ok(Some(frame)), None),
656 }
657 }
658
659 async fn parse_frame_header<'a, S>(
660 &mut self,
661 stream: &mut S,
662 ) -> Result<Frame<'a>, WebSocketError>
663 where
664 S: AsyncRead + Unpin,
665 {
666 macro_rules! eof {
667 ($n:expr) => {{
668 if $n == 0 {
669 return Err(WebSocketError::UnexpectedEOF);
670 }
671 }};
672 }
673
674 while self.buffer.remaining() < 2 {
676 eof!(stream.read_buf(&mut self.buffer).await?);
677 }
678
679 let fin = self.buffer[0] & 0b10000000 != 0;
680 let rsv1 = self.buffer[0] & 0b01000000 != 0;
681 let rsv2 = self.buffer[0] & 0b00100000 != 0;
682 let rsv3 = self.buffer[0] & 0b00010000 != 0;
683
684 if rsv1 || rsv2 || rsv3 {
685 return Err(WebSocketError::ReservedBitsNotZero);
686 }
687
688 let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?;
689 let masked = self.buffer[1] & 0b10000000 != 0;
690
691 let length_code = self.buffer[1] & 0x7F;
692 let extra = match length_code {
693 126 => 2,
694 127 => 8,
695 _ => 0,
696 };
697
698 self.buffer.advance(2);
699 while self.buffer.remaining() < extra + masked as usize * 4 {
700 eof!(stream.read_buf(&mut self.buffer).await?);
701 }
702
703 let payload_len: usize = match extra {
704 0 => usize::from(length_code),
705 2 => self.buffer.get_u16() as usize,
706 #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
707 8 => self.buffer.get_u64() as usize,
708 #[cfg(any(
710 target_pointer_width = "8",
711 target_pointer_width = "16",
712 target_pointer_width = "32"
713 ))]
714 8 => match usize::try_from(self.buffer.get_u64()) {
715 Ok(length) => length,
716 Err(_) => return Err(WebSocketError::FrameTooLarge),
717 },
718 _ => unreachable!(),
719 };
720
721 let mask = if masked {
722 Some(self.buffer.get_u32().to_be_bytes())
723 } else {
724 None
725 };
726
727 if frame::is_control(opcode) && !fin {
728 return Err(WebSocketError::ControlFrameFragmented);
729 }
730
731 if opcode == OpCode::Ping && payload_len > 125 {
732 return Err(WebSocketError::PingFrameTooLarge);
733 }
734
735 if payload_len >= self.max_message_size {
736 return Err(WebSocketError::FrameTooLarge);
737 }
738
739 self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
741 while payload_len > self.buffer.remaining() {
742 eof!(stream.read_buf(&mut self.buffer).await?);
743 }
744
745 let payload = self.buffer.split_to(payload_len);
747 let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload));
748 Ok(frame)
749 }
750}
751
752impl WriteHalf {
753 pub fn after_handshake(role: Role) -> Self {
754 Self {
755 role,
756 closed: false,
757 auto_apply_mask: true,
758 vectored: true,
759 writev_threshold: 1024,
760 write_buffer: Vec::with_capacity(2),
761 }
762 }
763
764 pub async fn write_frame<'a, S>(
766 &'a mut self,
767 stream: &mut S,
768 mut frame: Frame<'a>,
769 ) -> Result<(), WebSocketError>
770 where
771 S: AsyncWrite + Unpin,
772 {
773 if self.role == Role::Client && self.auto_apply_mask {
774 frame.mask();
775 }
776
777 if frame.opcode == OpCode::Close {
778 self.closed = true;
779 } else if self.closed {
780 return Err(WebSocketError::ConnectionClosed);
781 }
782
783 if self.vectored && frame.payload.len() > self.writev_threshold {
784 frame.writev(stream).await?;
785 } else {
786 let text = frame.write(&mut self.write_buffer);
787 stream.write_all(text).await?;
788 }
789
790 Ok(())
791 }
792}
793
#[cfg(test)]
mod tests {
  use super::*;

  // Compile-time check built on the "ambiguous impl" trick: if the type
  // implements `Sync`, both impls of `AmbiguousIfImpl` match and the `_`
  // in `<S as AmbiguousIfImpl<_>>` cannot be inferred, which is a compile
  // error.
  //
  // NOTE(review): `S` is a generic parameter of `assert_unsync` with no
  // `Sync` bound. Inside the function `S: Sync` can therefore never be
  // proven, so only the `()` impl is selectable and this compiles for every
  // `S`. The call below appears not to constrain
  // `WebSocket<TcpStream>` at all. A working check needs the concrete type
  // written directly in the `<... as AmbiguousIfImpl<_>>` path, not routed
  // through a generic fn. TODO confirm the intended property before relying
  // on this.
  const _: () = {
    const fn assert_unsync<S>() {
      // Generic trait with a blanket impl over `()` for all types.
      trait AmbiguousIfImpl<A> {
        // Gives the trait an item that can be referenced below.
        fn some_item() {}
      }

      impl<T: ?Sized> AmbiguousIfImpl<()> for T {}

      // Marker type for the second impl, which applies only to `Sync` types.
      #[allow(dead_code)]
      struct Invalid;

      impl<T: ?Sized + Sync> AmbiguousIfImpl<Invalid> for T {}

      // Resolves only if exactly one impl applies to `S`.
      let _ = <S as AmbiguousIfImpl<_>>::some_item;
    }
    assert_unsync::<WebSocket<tokio::net::TcpStream>>();
  };
}