1extern crate bytes;
11extern crate tokio_codec;
12
13use std::borrow::Borrow;
14use std::io::Cursor;
15use std::marker::PhantomData;
16use std::mem;
17
18use self::bytes::BufMut;
19use self::bytes::BytesMut;
20use self::tokio_codec::Decoder;
21use self::tokio_codec::Encoder;
22
23use crate::dataframe::DataFrame;
24use crate::message::OwnedMessage;
25use crate::result::WebSocketError;
26use crate::ws::dataframe::DataFrame as DataFrameTrait;
27use crate::ws::message::Message as MessageTrait;
28use crate::ws::util::header::read_header;
29
/// Default cap, in bytes, on the payload of one incoming data frame (100 MiB).
const DEFAULT_MAX_DATAFRAME_SIZE: usize = 100 * 1024 * 1024;
/// Default cap, in bytes, on the accumulated size of an incoming fragmented
/// message (200 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 200 * 1024 * 1024;
/// Number of fragments at which an unfinished message is rejected.
const MAX_DATAFRAMES_IN_ONE_MESSAGE: usize = 1024 * 1024;
/// Bytes added per buffered fragment to the running message length that is
/// compared against the message size limit (presumably to account for
/// per-frame bookkeeping — confirm).
const PER_DATAFRAME_OVERHEAD: usize = 64;
34
/// Which end of a WebSocket connection a codec works for.
///
/// Encoders built for `Client` mask the frames they write and encoders built
/// for `Server` do not; decoders forward the server flag to
/// `DataFrame::read_dataframe_body`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Context {
    /// The codec is used on the server side of the connection.
    Server,
    /// The codec is used on the client side of the connection.
    Client,
}
51
/// A codec that decodes incoming bytes into `DataFrame`s and encodes any `D`
/// that can be borrowed as a data-frame trait object.
///
/// No value of type `D` is stored; it only fixes the encoder's `Item` type.
pub struct DataFrameCodec<D> {
    /// `true` when built for `Context::Server`. Outgoing frames are masked
    /// only when this is `false`.
    is_server: bool,
    /// Marker binding the codec to its encodable item type.
    frame_type: PhantomData<D>,
    /// Largest payload length the decoder accepts, saturated to `u32::MAX`
    /// by the constructor.
    max_dataframe_size: u32,
}
72
73impl DataFrameCodec<DataFrame> {
74 pub fn default(context: Context) -> Self {
80 DataFrameCodec::new(context)
81 }
82}
83
84impl<D> DataFrameCodec<D> {
85 pub fn new(context: Context) -> DataFrameCodec<D> {
94 DataFrameCodec::new_with_limits(context, DEFAULT_MAX_DATAFRAME_SIZE)
95 }
96
97 pub fn new_with_limits(context: Context, max_dataframe_size: usize) -> DataFrameCodec<D> {
98 let max_dataframe_size: u32 = max_dataframe_size.min(u32::MAX as usize) as u32;
99 DataFrameCodec {
100 is_server: context == Context::Server,
101 frame_type: PhantomData,
102 max_dataframe_size,
103 }
104 }
105}
106
107impl<D> Decoder for DataFrameCodec<D> {
108 type Item = DataFrame;
109 type Error = WebSocketError;
110
111 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
113 let (header, bytes_read) = {
114 let mut reader = Cursor::new(src.as_ref());
116
117 let header = match read_header(&mut reader) {
119 Ok(head) => head,
120 Err(WebSocketError::NoDataAvailable) => return Ok(None),
121 Err(e) => return Err(e),
122 };
123
124 (header, reader.position())
125 };
126
127 if header.len > self.max_dataframe_size as u64 {
128 return Err(WebSocketError::ProtocolError(
129 "Exceeded maximum incoming DataFrame size",
130 ));
131 }
132
133 if header.len + bytes_read > src.len() as u64 {
135 return Ok(None);
136 }
137
138 let _ = src.split_to(bytes_read as usize);
140 let body = src.split_to(header.len as usize).to_vec();
141
142 Ok(Some(DataFrame::read_dataframe_body(
144 header,
145 body,
146 self.is_server,
147 )?))
148 }
149}
150
151impl<D> Encoder for DataFrameCodec<D>
152where
153 D: Borrow<dyn DataFrameTrait>,
154{
155 type Item = D;
156 type Error = WebSocketError;
157
158 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
159 let masked = !self.is_server;
160 let frame_size = item.borrow().frame_size(masked);
161 if frame_size > dst.remaining_mut() {
162 dst.reserve(frame_size);
163 }
164 item.borrow().write_to(&mut dst.writer(), masked)
165 }
166}
167
168pub struct MessageCodec<M>
220where
221 M: MessageTrait,
222{
223 buffer: Vec<DataFrame>,
224 dataframe_codec: DataFrameCodec<DataFrame>,
225 message_type: PhantomData<fn(M)>,
226 max_message_size: u32,
227}
228
229impl MessageCodec<OwnedMessage> {
230 pub fn default(context: Context) -> Self {
240 Self::new(context)
241 }
242}
243
244impl<M> MessageCodec<M>
245where
246 M: MessageTrait,
247{
248 pub fn new(context: Context) -> MessageCodec<M> {
257 MessageCodec::new_with_limits(context, DEFAULT_MAX_DATAFRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE)
258 }
259
260 pub fn new_with_limits(context: Context, max_dataframe_size: usize, max_message_size: usize) -> MessageCodec<M> {
261 let max_message_size: u32 = max_message_size.min(u32::MAX as usize) as u32;
262 MessageCodec {
263 buffer: Vec::new(),
264 dataframe_codec: DataFrameCodec::new_with_limits(context, max_dataframe_size),
265 message_type: PhantomData,
266 max_message_size,
267 }
268 }
269}
270
271impl<M> Decoder for MessageCodec<M>
272where
273 M: MessageTrait,
274{
275 type Item = OwnedMessage;
276 type Error = WebSocketError;
277
278 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
279 let mut current_message_length : usize = self.buffer.iter().map(|x|x.data.len()).sum();
280 while let Some(frame) = self.dataframe_codec.decode(src)? {
281 let is_first = self.buffer.is_empty();
282 let finished = frame.finished;
283
284 match frame.opcode as u8 {
285 0 if is_first => {
287 return Err(WebSocketError::ProtocolError(
288 "Unexpected continuation data frame opcode",
289 ));
290 }
291 8..=15 => {
293 return Ok(Some(OwnedMessage::from_dataframes(vec![frame])?));
294 }
295 1..=7 if !is_first => {
297 return Err(WebSocketError::ProtocolError(
298 "Unexpected data frame opcode",
299 ));
300 }
301 _ => {
303 current_message_length += frame.data.len() + PER_DATAFRAME_OVERHEAD;
304 self.buffer.push(frame);
305 }
306 };
307
308 if finished {
309 let buffer = mem::replace(&mut self.buffer, Vec::new());
310 return Ok(Some(OwnedMessage::from_dataframes(buffer)?));
311 } else {
312 if self.buffer.len() >= MAX_DATAFRAMES_IN_ONE_MESSAGE {
313 return Err(WebSocketError::ProtocolError(
314 "Exceeded count of data frames in one WebSocket message",
315 ));
316 }
317 if current_message_length > self.max_message_size as usize {
318 return Err(WebSocketError::ProtocolError(
319 "Exceeded maximum WebSocket message size",
320 ));
321 }
322 }
323 }
324
325 Ok(None)
326 }
327}
328
329impl<M> Encoder for MessageCodec<M>
330where
331 M: MessageTrait,
332{
333 type Item = M;
334 type Error = WebSocketError;
335
336 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
337 let masked = !self.dataframe_codec.is_server;
338 let frame_size = item.message_size(masked);
339 if frame_size > dst.remaining_mut() {
340 dst.reserve(frame_size);
341 }
342 item.serialize(&mut dst.writer(), masked)
343 }
344}
345
#[cfg(test)]
mod tests {
    extern crate tokio;
    use super::*;
    use crate::message::CloseData;
    use crate::message::Message;
    use crate::stream::ReadWritePair;
    use futures::{Future, Sink, Stream};
    use std::io::Cursor;

    // Checks that `OwnedMessage::message_size` matches the number of bytes
    // `serialize` actually writes, both masked and unmasked. Covers text,
    // binary (payload lengths 4, 256, 65535 and 65555), ping, pong, and close
    // with and without close data.
    #[test]
    fn owned_message_predicts_size() {
        let messages = vec![
            OwnedMessage::Text("nilbog".to_string()),
            OwnedMessage::Binary(vec![1, 2, 3, 4]),
            OwnedMessage::Binary(vec![42; 256]),
            OwnedMessage::Binary(vec![42; 65535]),
            OwnedMessage::Binary(vec![42; 65555]),
            OwnedMessage::Ping("beep".to_string().into_bytes()),
            OwnedMessage::Pong("boop".to_string().into_bytes()),
            OwnedMessage::Close(None),
            OwnedMessage::Close(Some(CloseData {
                status_code: 64,
                reason: "because".to_string(),
            })),
        ];

        for message in messages.into_iter() {
            let masked_predicted = message.message_size(true);
            let mut masked_buf = Vec::new();
            message.serialize(&mut masked_buf, true).unwrap();
            assert_eq!(masked_buf.len(), masked_predicted);

            let unmasked_predicted = message.message_size(false);
            let mut unmasked_buf = Vec::new();
            message.serialize(&mut unmasked_buf, false).unwrap();
            assert_eq!(unmasked_buf.len(), unmasked_predicted);
        }
    }

    // Same size-prediction check as above, for the borrowed `Message` type
    // built through its constructor helpers.
    #[test]
    fn cow_message_predicts_size() {
        let messages = vec![
            Message::binary(vec![1, 2, 3, 4]),
            Message::binary(vec![42; 256]),
            Message::binary(vec![42; 65535]),
            Message::binary(vec![42; 65555]),
            Message::text("nilbog".to_string()),
            Message::ping("beep".to_string().into_bytes()),
            Message::pong("boop".to_string().into_bytes()),
            Message::close(),
            Message::close_because(64, "because"),
        ];

        for message in messages.iter() {
            let masked_predicted = message.message_size(true);
            let mut masked_buf = Vec::new();
            message.serialize(&mut masked_buf, true).unwrap();
            assert_eq!(masked_buf.len(), masked_predicted);

            let unmasked_predicted = message.message_size(false);
            let mut unmasked_buf = Vec::new();
            message.serialize(&mut unmasked_buf, false).unwrap();
            assert_eq!(unmasked_buf.len(), unmasked_predicted);
        }
    }

    // Round trip from the client's point of view.
    // 1. A client-side codec decodes an unmasked text message.
    // 2. It sends a reply into an in-memory buffer.
    // 3. That buffer is rewound and fed to a server-side codec, which must
    //    decode the same reply.
    #[test]
    fn message_codec_client_send_receive() {
        let mut input = Vec::new();
        Message::text("50 schmeckels")
            .serialize(&mut input, false)
            .unwrap();

        let f = MessageCodec::new(Context::Client)
            .framed(ReadWritePair(Cursor::new(input), Cursor::new(vec![])))
            .into_future()
            .map_err(|e| e.0)
            .map(|(m, s)| {
                assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
                s
            })
            .and_then(|s| s.send(Message::text("ethan bradberry")))
            .and_then(|s| {
                let mut stream = s.into_parts().io;
                // Rewind the written half so it can be read back.
                stream.1.set_position(0);
                println!("buffer: {:?}", stream.1);
                MessageCodec::default(Context::Server)
                    .framed(ReadWritePair(stream.1, stream.0))
                    .into_future()
                    .map_err(|e| e.0)
                    .map(|(message, _)| {
                        assert_eq!(message, Some(Message::text("ethan bradberry").into()))
                    })
            });

        tokio::runtime::Builder::new()
            .build()
            .unwrap()
            .block_on(f)
            .unwrap();
    }

    // Server side.
    // 1. A server-side codec decodes a masked text message.
    // 2. It sends a reply.
    // 3. The bytes it wrote must equal an unmasked serialization of that reply.
    #[test]
    fn message_codec_server_send_receive() {
        let mut runtime = tokio::runtime::Builder::new().build().unwrap();
        let mut input = Vec::new();
        Message::text("50 schmeckels")
            .serialize(&mut input, true)
            .unwrap();

        let f = MessageCodec::new(Context::Server)
            .framed(ReadWritePair(Cursor::new(input), Cursor::new(vec![])))
            .into_future()
            .map_err(|e| e.0)
            .map(|(m, s)| {
                assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
                s
            })
            .and_then(|s| s.send(Message::text("ethan bradberry")))
            .map(|s| {
                let mut written = vec![];
                Message::text("ethan bradberry")
                    .serialize(&mut written, false)
                    .unwrap();
                assert_eq!(written, s.into_parts().io.1.into_inner());
            });

        runtime.block_on(f).unwrap();
    }
}