extern crate bytes;
extern crate tokio_codec;
use std::borrow::Borrow;
use std::io::Cursor;
use std::marker::PhantomData;
use std::mem;
use self::bytes::BufMut;
use self::bytes::BytesMut;
use self::tokio_codec::Decoder;
use self::tokio_codec::Encoder;
use crate::dataframe::DataFrame;
use crate::message::OwnedMessage;
use crate::result::WebSocketError;
use crate::ws::dataframe::DataFrame as DataFrameTrait;
use crate::ws::message::Message as MessageTrait;
use crate::ws::util::header::read_header;
/// Default cap on the payload length of a single incoming data frame (100 MiB).
const DEFAULT_MAX_DATAFRAME_SIZE: usize = 100 * 1024 * 1024;
/// Default cap on the accumulated size of a fragmented incoming message (200 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 200 * 1024 * 1024;
/// Cap on the number of not-yet-finished fragments buffered for one message.
const MAX_DATAFRAMES_IN_ONE_MESSAGE: usize = 1024 * 1024;
/// Bytes charged per buffered fragment on top of its payload when enforcing
/// the message size limit, so that many small fragments still count against it.
const PER_DATAFRAME_OVERHEAD: usize = 64;
/// Which end of a WebSocket connection a codec is working for.
///
/// The codecs in this module derive their masking from it: a codec built for
/// `Client` masks the frames it writes, one built for `Server` does not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Context {
    /// The codec runs on the server side of the connection.
    Server,
    /// The codec runs on the client side of the connection.
    Client,
}
/// Codec that splits a byte stream into individual WebSocket data frames and
/// writes data frames back out as bytes.
///
/// `D` is the type accepted by the encoder; the decoder always yields
/// `DataFrame` values.
pub struct DataFrameCodec<D> {
    /// `true` when acting as the server; outgoing frames are masked only when this is `false`.
    is_server: bool,
    /// Records the encoder's item type without storing a value of it.
    frame_type: PhantomData<D>,
    /// Largest payload length (in bytes) the decoder accepts for one frame.
    max_dataframe_size: u32,
}
impl DataFrameCodec<DataFrame> {
pub fn default(context: Context) -> Self {
DataFrameCodec::new(context)
}
}
impl<D> DataFrameCodec<D> {
    /// Creates a data frame codec for one side of a connection, limiting
    /// incoming frame payloads to `DEFAULT_MAX_DATAFRAME_SIZE` bytes.
    pub fn new(context: Context) -> DataFrameCodec<D> {
        Self::new_with_limits(context, DEFAULT_MAX_DATAFRAME_SIZE)
    }

    /// Creates a data frame codec with a custom limit on incoming payloads.
    ///
    /// The limit is stored as a `u32`; values above `u32::MAX` are clamped to
    /// `u32::MAX`. The decoder answers any frame whose header announces a
    /// longer payload with a protocol error.
    pub fn new_with_limits(context: Context, max_dataframe_size: usize) -> DataFrameCodec<D> {
        let limit = if max_dataframe_size > u32::MAX as usize {
            u32::MAX
        } else {
            max_dataframe_size as u32
        };
        let is_server = match context {
            Context::Server => true,
            Context::Client => false,
        };
        DataFrameCodec {
            is_server,
            frame_type: PhantomData,
            max_dataframe_size: limit,
        }
    }
}
impl<D> Decoder for DataFrameCodec<D> {
    type Item = DataFrame;
    type Error = WebSocketError;

    /// Tries to take one complete data frame off the front of `src`.
    ///
    /// Returns `Ok(None)` without consuming anything when `read_header`
    /// reports `NoDataAvailable` (presumably an incomplete header), or when
    /// the payload announced by the header is not fully buffered yet. Once the
    /// whole frame is present, its bytes are removed from `src` and the
    /// payload is handed to `DataFrame::read_dataframe_body`.
    ///
    /// # Errors
    ///
    /// Propagates any other header error, returns a `ProtocolError` when the
    /// announced payload exceeds the configured limit, and propagates errors
    /// from `read_dataframe_body`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Parse the header through a read-only view so nothing is consumed
        // if the frame turns out to be incomplete.
        let mut view = Cursor::new(src.as_ref());
        let header = match read_header(&mut view) {
            Err(WebSocketError::NoDataAvailable) => return Ok(None),
            parsed => parsed?,
        };
        let header_len = view.position();

        // Reject oversized frames as soon as the header is known, before
        // waiting for the payload to arrive.
        if header.len > u64::from(self.max_dataframe_size) {
            return Err(WebSocketError::ProtocolError(
                "Exceeded maximum incoming DataFrame size",
            ));
        }

        let frame_len = header_len + header.len;
        if (src.len() as u64) < frame_len {
            return Ok(None);
        }

        // Detach the whole frame, then keep only the payload after the header.
        let frame_bytes = src.split_to(frame_len as usize);
        let body = frame_bytes[header_len as usize..].to_vec();
        let frame = DataFrame::read_dataframe_body(header, body, self.is_server)?;
        Ok(Some(frame))
    }
}
impl<D> Encoder for DataFrameCodec<D>
where
    D: Borrow<dyn DataFrameTrait>,
{
    type Item = D;
    type Error = WebSocketError;

    /// Writes one data frame to `dst`, masked when this codec acts as a client.
    ///
    /// Before writing, `dst` is grown so it has room for the number of bytes
    /// `frame_size` predicts for the frame.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let mask = !self.is_server;
        let frame = item.borrow();
        let needed = frame.frame_size(mask);
        if dst.remaining_mut() < needed {
            dst.reserve(needed);
        }
        frame.write_to(&mut dst.writer(), mask)
    }
}
/// Codec that reassembles incoming data frames into whole messages and
/// serializes outgoing messages of type `M`.
///
/// Incoming data fragments are buffered until the final one arrives; the
/// decoder always yields `OwnedMessage` values.
pub struct MessageCodec<M>
where
    M: MessageTrait,
{
    /// Fragments of the message currently being reassembled.
    buffer: Vec<DataFrame>,
    /// Frame-level codec that splits the byte stream into data frames.
    dataframe_codec: DataFrameCodec<DataFrame>,
    /// Records the outgoing message type; `fn(M)` keeps the codec from acting
    /// as if it owned an `M`.
    message_type: PhantomData<fn(M)>,
    /// Limit on the accumulated fragment size (payload plus per-fragment
    /// overhead) of a message being reassembled.
    max_message_size: u32,
}
impl MessageCodec<OwnedMessage> {
pub fn default(context: Context) -> Self {
Self::new(context)
}
}
impl<M> MessageCodec<M>
where
    M: MessageTrait,
{
    /// Creates a message codec for one side of a connection using the default
    /// limits (`DEFAULT_MAX_DATAFRAME_SIZE` per frame, `DEFAULT_MAX_MESSAGE_SIZE`
    /// per message).
    pub fn new(context: Context) -> MessageCodec<M> {
        Self::new_with_limits(
            context,
            DEFAULT_MAX_DATAFRAME_SIZE,
            DEFAULT_MAX_MESSAGE_SIZE,
        )
    }

    /// Creates a message codec with custom limits for incoming traffic.
    ///
    /// * `max_dataframe_size` — largest payload accepted for a single frame
    ///   (forwarded to the inner `DataFrameCodec`).
    /// * `max_message_size` — limit on the accumulated size of a fragmented
    ///   message; values above `u32::MAX` are clamped to `u32::MAX`.
    pub fn new_with_limits(
        context: Context,
        max_dataframe_size: usize,
        max_message_size: usize,
    ) -> MessageCodec<M> {
        let message_limit = if max_message_size > u32::MAX as usize {
            u32::MAX
        } else {
            max_message_size as u32
        };
        let frames = DataFrameCodec::new_with_limits(context, max_dataframe_size);
        MessageCodec {
            buffer: Vec::new(),
            dataframe_codec: frames,
            message_type: PhantomData,
            max_message_size: message_limit,
        }
    }
}
impl<M> Decoder for MessageCodec<M>
where
    M: MessageTrait,
{
    type Item = OwnedMessage;
    type Error = WebSocketError;

    /// Decodes frames from `src` until a complete message is available.
    ///
    /// Control frames (opcodes 8–15) are returned immediately as their own
    /// message, even in the middle of a fragmented data message. Data frames
    /// are buffered until one with `finished` set arrives; the buffered
    /// fragments are then combined with `OwnedMessage::from_dataframes`.
    /// Returns `Ok(None)` when `src` holds no further complete frame.
    ///
    /// # Errors
    ///
    /// Returns a `ProtocolError` when a continuation frame arrives with no
    /// message in progress, when a new data frame starts while another message
    /// is still unfinished, when too many fragments are buffered, or when a
    /// fragmented message exceeds `max_message_size`. Errors from the frame
    /// codec and from `from_dataframes` are propagated.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Charge fragments buffered by earlier calls exactly like fragments
        // buffered below (payload plus fixed overhead); otherwise the limit
        // loosens every time a message spans several reads.
        let mut current_message_length: usize = self
            .buffer
            .iter()
            .map(|frame| frame.data.len() + PER_DATAFRAME_OVERHEAD)
            .sum();
        while let Some(frame) = self.dataframe_codec.decode(src)? {
            let is_first = self.buffer.is_empty();
            let finished = frame.finished;
            match frame.opcode as u8 {
                // Continuation frame with no message in progress.
                0 if is_first => {
                    return Err(WebSocketError::ProtocolError(
                        "Unexpected continuation data frame opcode",
                    ));
                }
                // Control frame: delivered on its own, never buffered.
                8..=15 => {
                    return Ok(Some(OwnedMessage::from_dataframes(vec![frame])?));
                }
                // New data frame while a fragmented message is still open.
                1..=7 if !is_first => {
                    return Err(WebSocketError::ProtocolError(
                        "Unexpected data frame opcode",
                    ));
                }
                _ => {
                    current_message_length += frame.data.len() + PER_DATAFRAME_OVERHEAD;
                    self.buffer.push(frame);
                }
            };
            if finished {
                // The fragment that completes a fragmented message must also
                // respect the limit. A single unfragmented frame is already
                // bounded by the frame codec's own size limit.
                if self.buffer.len() > 1 && current_message_length > self.max_message_size as usize
                {
                    return Err(WebSocketError::ProtocolError(
                        "Exceeded maximum WebSocket message size",
                    ));
                }
                let buffer = mem::replace(&mut self.buffer, Vec::new());
                return Ok(Some(OwnedMessage::from_dataframes(buffer)?));
            }
            if self.buffer.len() >= MAX_DATAFRAMES_IN_ONE_MESSAGE {
                return Err(WebSocketError::ProtocolError(
                    "Exceeded count of data frames in one WebSocket message",
                ));
            }
            if current_message_length > self.max_message_size as usize {
                return Err(WebSocketError::ProtocolError(
                    "Exceeded maximum WebSocket message size",
                ));
            }
        }
        Ok(None)
    }
}
impl<M> Encoder for MessageCodec<M>
where
    M: MessageTrait,
{
    type Item = M;
    type Error = WebSocketError;

    /// Serializes a whole message into `dst`, masked when this codec acts as
    /// a client.
    ///
    /// `dst` is first grown to fit the size `message_size` predicts.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let mask = !self.dataframe_codec.is_server;
        let needed = item.message_size(mask);
        if dst.remaining_mut() < needed {
            dst.reserve(needed);
        }
        let mut writer = dst.writer();
        item.serialize(&mut writer, mask)
    }
}
#[cfg(test)]
mod tests {
// Tests for this module: size prediction of serialized messages and
// encode/decode round trips through `MessageCodec` over in-memory cursors.
// NOTE(review): these rely on the futures 0.1 combinator API and the tokio
// 0.1 runtime builder; `framed` presumably comes from tokio_codec's
// `Decoder` trait — confirm against the crate versions in use.
extern crate tokio;
use super::*;
use crate::message::CloseData;
use crate::message::Message;
use crate::stream::ReadWritePair;
use futures::{Future, Sink, Stream};
use std::io::Cursor;
// `message_size(masked)` must equal the number of bytes `serialize` writes,
// both masked and unmasked, for every kind of `OwnedMessage`.
// NOTE(review): the 256 / 65535 / 65555-byte payloads are presumably chosen to
// exercise the different payload-length encodings; confirm.
#[test]
fn owned_message_predicts_size() {
let messages = vec![
OwnedMessage::Text("nilbog".to_string()),
OwnedMessage::Binary(vec![1, 2, 3, 4]),
OwnedMessage::Binary(vec![42; 256]),
OwnedMessage::Binary(vec![42; 65535]),
OwnedMessage::Binary(vec![42; 65555]),
OwnedMessage::Ping("beep".to_string().into_bytes()),
OwnedMessage::Pong("boop".to_string().into_bytes()),
OwnedMessage::Close(None),
OwnedMessage::Close(Some(CloseData {
status_code: 64,
reason: "because".to_string(),
})),
];
for message in messages.into_iter() {
let masked_predicted = message.message_size(true);
let mut masked_buf = Vec::new();
message.serialize(&mut masked_buf, true).unwrap();
assert_eq!(masked_buf.len(), masked_predicted);
let unmasked_predicted = message.message_size(false);
let mut unmasked_buf = Vec::new();
message.serialize(&mut unmasked_buf, false).unwrap();
assert_eq!(unmasked_buf.len(), unmasked_predicted);
}
}
// Same size-prediction check as above, for the borrowed `Message` type.
#[test]
fn cow_message_predicts_size() {
let messages = vec![
Message::binary(vec![1, 2, 3, 4]),
Message::binary(vec![42; 256]),
Message::binary(vec![42; 65535]),
Message::binary(vec![42; 65555]),
Message::text("nilbog".to_string()),
Message::ping("beep".to_string().into_bytes()),
Message::pong("boop".to_string().into_bytes()),
Message::close(),
Message::close_because(64, "because"),
];
for message in messages.iter() {
let masked_predicted = message.message_size(true);
let mut masked_buf = Vec::new();
message.serialize(&mut masked_buf, true).unwrap();
assert_eq!(masked_buf.len(), masked_predicted);
let unmasked_predicted = message.message_size(false);
let mut unmasked_buf = Vec::new();
message.serialize(&mut unmasked_buf, false).unwrap();
assert_eq!(unmasked_buf.len(), unmasked_predicted);
}
}
// A client codec decodes an unmasked text message, sends a reply, and the
// bytes it wrote are then decoded back by a server codec.
#[test]
fn message_codec_client_send_receive() {
let mut input = Vec::new();
// Incoming data is serialized unmasked (`false`).
Message::text("50 schmeckels")
.serialize(&mut input, false)
.unwrap();
let f = MessageCodec::new(Context::Client)
.framed(ReadWritePair(Cursor::new(input), Cursor::new(vec![])))
.into_future()
.map_err(|e| e.0)
.map(|(m, s)| {
assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
s
})
.and_then(|s| s.send(Message::text("ethan bradberry")))
.and_then(|s| {
let mut stream = s.into_parts().io;
// Rewind the bytes written by the client so the server codec can read them.
stream.1.set_position(0);
// Debug output; the test harness captures stdout unless run with --nocapture.
println!("buffer: {:?}", stream.1);
MessageCodec::default(Context::Server)
.framed(ReadWritePair(stream.1, stream.0))
.into_future()
.map_err(|e| e.0)
.map(|(message, _)| {
assert_eq!(message, Some(Message::text("ethan bradberry").into()))
})
});
tokio::runtime::Builder::new()
.build()
.unwrap()
.block_on(f)
.unwrap();
}
// A server codec decodes a masked text message and sends a reply; the bytes
// it wrote must match an unmasked serialization of that reply exactly.
#[test]
fn message_codec_server_send_receive() {
let mut runtime = tokio::runtime::Builder::new().build().unwrap();
let mut input = Vec::new();
// Incoming data is serialized masked (`true`).
Message::text("50 schmeckels")
.serialize(&mut input, true)
.unwrap();
let f = MessageCodec::new(Context::Server)
.framed(ReadWritePair(Cursor::new(input), Cursor::new(vec![])))
.into_future()
.map_err(|e| e.0)
.map(|(m, s)| {
assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
s
})
.and_then(|s| s.send(Message::text("ethan bradberry")))
.map(|s| {
// Expected output: the reply serialized without masking.
let mut written = vec![];
Message::text("ethan bradberry")
.serialize(&mut written, false)
.unwrap();
assert_eq!(written, s.into_parts().io.1.into_inner());
});
runtime.block_on(f).unwrap();
}
}