1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
use std::{
collections::VecDeque,
future::poll_fn,
io, mem,
pin::Pin,
task::{Context, Poll},
};
use actix_codec::{Decoder, Encoder};
use actix_http::{
ws::{Codec, Frame, Message, ProtocolError},
Payload,
};
use actix_web::{
web::{Bytes, BytesMut},
Error,
};
use bytestring::ByteString;
use futures_core::stream::Stream;
use tokio::sync::mpsc::Receiver;
use crate::AggregatedMessageStream;
/// Response body for a WebSocket.
pub struct StreamingBody {
session_rx: Receiver<Message>,
messages: VecDeque<Message>,
buf: BytesMut,
codec: Codec,
closing: bool,
}
impl StreamingBody {
pub(super) fn new(session_rx: Receiver<Message>) -> Self {
StreamingBody {
session_rx,
messages: VecDeque::new(),
buf: BytesMut::new(),
codec: Codec::new(),
closing: false,
}
}
}
/// Stream of messages from a WebSocket client.
pub struct MessageStream {
payload: Payload,
messages: VecDeque<Message>,
buf: BytesMut,
codec: Codec,
closing: bool,
}
impl MessageStream {
pub(super) fn new(payload: Payload) -> Self {
MessageStream {
payload,
messages: VecDeque::new(),
buf: BytesMut::new(),
codec: Codec::new(),
closing: false,
}
}
/// Sets the maximum permitted size for received WebSocket frames, in bytes.
///
/// By default, up to 64KiB is allowed.
///
/// Any received frames larger than the permitted value will return
/// `Err(ProtocolError::Overflow)` instead.
///
/// ```no_run
/// # use actix_ws::MessageStream;
/// # fn test(stream: MessageStream) {
/// // increase permitted frame size from 64KB to 1MB
/// let stream = stream.max_frame_size(1024 * 1024);
/// # }
/// ```
#[must_use]
pub fn max_frame_size(mut self, max_size: usize) -> Self {
self.codec = self.codec.max_size(max_size);
self
}
/// Returns a stream wrapper that collects continuation frames into their equivalent aggregated
/// forms, i.e., binary or text.
///
/// By default, continuations will be aggregated up to 1MiB in size (customizable with
/// [`AggregatedMessageStream::max_continuation_size()`]). The stream implementation returns an
/// error if this size is exceeded.
#[must_use]
pub fn aggregate_continuations(self) -> AggregatedMessageStream {
AggregatedMessageStream::new(self)
}
/// Waits for the next item from the message stream
///
/// This is a convenience for calling the [`Stream`](Stream::poll_next()) implementation.
///
/// ```no_run
/// # use actix_ws::MessageStream;
/// # async fn test(mut stream: MessageStream) {
/// while let Some(Ok(msg)) = stream.recv().await {
/// // handle message
/// }
/// # }
/// ```
#[must_use]
pub async fn recv(&mut self) -> Option<Result<Message, ProtocolError>> {
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
}
impl Stream for StreamingBody {
type Item = Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.closing {
return Poll::Ready(None);
}
loop {
match Pin::new(&mut this.session_rx).poll_recv(cx) {
Poll::Ready(Some(msg)) => {
this.messages.push_back(msg);
}
Poll::Ready(None) => {
this.closing = true;
break;
}
Poll::Pending => break,
}
}
while let Some(msg) = this.messages.pop_front() {
if let Err(err) = this.codec.encode(msg, &mut this.buf) {
return Poll::Ready(Some(Err(err.into())));
}
}
if !this.buf.is_empty() {
return Poll::Ready(Some(Ok(mem::take(&mut this.buf).freeze())));
}
Poll::Pending
}
}
impl Stream for MessageStream {
type Item = Result<Message, ProtocolError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
// Return the first message in the queue if one exists
//
// This is faster than polling and parsing
if let Some(msg) = this.messages.pop_front() {
return Poll::Ready(Some(Ok(msg)));
}
if !this.closing {
// Read in bytes until there's nothing left to read
loop {
match Pin::new(&mut this.payload).poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.buf.extend_from_slice(&bytes);
}
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::other(err)))));
}
Poll::Ready(None) => {
this.closing = true;
break;
}
Poll::Pending => break,
}
}
}
// Create messages until there's no more bytes left
while let Some(frame) = this.codec.decode(&mut this.buf)? {
let message = match frame {
Frame::Text(bytes) => {
ByteString::try_from(bytes)
.map(Message::Text)
.map_err(|err| {
ProtocolError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
})?
}
Frame::Binary(bytes) => Message::Binary(bytes),
Frame::Ping(bytes) => Message::Ping(bytes),
Frame::Pong(bytes) => Message::Pong(bytes),
Frame::Close(reason) => Message::Close(reason),
Frame::Continuation(item) => Message::Continuation(item),
};
this.messages.push_back(message);
}
// Return the first message in the queue
if let Some(msg) = this.messages.pop_front() {
return Poll::Ready(Some(Ok(msg)));
}
// If we've exhausted our message queue and we're closing, close the stream
if this.closing {
return Poll::Ready(None);
}
Poll::Pending
}
}