websocket_base/codec/
ws.rs

1//! Send websocket messages and dataframes asynchronously.
2//!
3//! This module provides codecs that can be used with `tokio` to create
4//! asynchronous streams that can serialize/deserialize websocket messages
5//! (and dataframes for users that want low level control).
6//!
7//! For websocket messages, see the documentation for `MessageCodec`, for
8//! dataframes see the documentation for `DataFrameCodec`
9
10extern crate bytes;
11extern crate tokio_codec;
12
13use std::borrow::Borrow;
14use std::io::Cursor;
15use std::marker::PhantomData;
16use std::mem;
17
18use self::bytes::BufMut;
19use self::bytes::BytesMut;
20use self::tokio_codec::Decoder;
21use self::tokio_codec::Encoder;
22
23use crate::dataframe::DataFrame;
24use crate::message::OwnedMessage;
25use crate::result::WebSocketError;
26use crate::ws::dataframe::DataFrame as DataFrameTrait;
27use crate::ws::message::Message as MessageTrait;
28use crate::ws::util::header::read_header;
29
/// Default cap, in bytes, on the payload length a single incoming dataframe
/// may declare in its header (100 MiB).
const DEFAULT_MAX_DATAFRAME_SIZE: usize = 100 * 1024 * 1024;
/// Default cap, in bytes, on the accumulated size of one incoming message
/// (200 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 200 * 1024 * 1024;
/// Cap on the number of dataframes buffered for one unfinished message.
const MAX_DATAFRAMES_IN_ONE_MESSAGE: usize = 1024 * 1024;
/// Bytes charged against the message size for each buffered dataframe, in
/// addition to the length of its payload.
const PER_DATAFRAME_OVERHEAD: usize = 64;
34
/// The role a codec plays on a websocket connection.
///
/// The two ends of a websocket connection look symmetrical but differ in a
/// few small ways, so every codec has to be told which end it sits on
/// (`Client` or `Server`).
///
/// In protocol terms, this is what decides whether the data should be masked
/// or not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Context {
	/// Run the codec in `Server` mode; pick this when implementing a
	/// websocket server.
	Server,
	/// Run the codec in `Client` mode; pick this when implementing a
	/// websocket client.
	Client,
}
51
52/**************
53 * Dataframes *
54 **************/
55
/// Codec that turns raw bytes into websocket dataframes and back.
///
/// Decoding always yields the crate's own `DataFrame` type. Encoding accepts
/// any value of type `D`, which lets callers plug in their own implementation
/// of the `ws::Dataframe` trait.
///
/// Working at the dataframe level is intended for users who need low-level
/// control over the connection. If you do not, prefer `MessageCodec`, or
/// better still build clients with `ClientBuilder` and servers with `Server`.
pub struct DataFrameCodec<D> {
	/// `true` when the codec was created with `Context::Server`.
	is_server: bool,
	/// Records the dataframe type `D` this codec encodes; holds no data.
	frame_type: PhantomData<D>,
	/// Largest payload length, in bytes, an incoming dataframe header may
	/// declare before decoding fails.
	max_dataframe_size: u32,
}
72
73impl DataFrameCodec<DataFrame> {
74	/// Create a new `DataFrameCodec` struct using the crate's implementation
75	/// of dataframes for reading and writing dataframes.
76	///
77	/// Use this method if you don't want to provide a custom implementation
78	/// for your dataframes.
79	pub fn default(context: Context) -> Self {
80		DataFrameCodec::new(context)
81	}
82}
83
84impl<D> DataFrameCodec<D> {
85	/// Create a new `DataFrameCodec` struct using any implementation of
86	/// `ws::Dataframe` you want. This is useful if you want to manipulate
87	/// the websocket layer very specifically.
88	///
89	/// If you only want to be able to send and receive the crate's
90	/// `DataFrame` struct use `.default(Context)` instead.
91	/// 
92	/// There is a default dataframe size limit imposed. Use `new_with_limits` to override it
93	pub fn new(context: Context) -> DataFrameCodec<D> {
94		DataFrameCodec::new_with_limits(context, DEFAULT_MAX_DATAFRAME_SIZE)
95	}
96
97	pub fn new_with_limits(context: Context, max_dataframe_size: usize) ->  DataFrameCodec<D> {
98		let max_dataframe_size: u32 = max_dataframe_size.min(u32::MAX as usize) as u32;
99		DataFrameCodec {
100			is_server: context == Context::Server,
101			frame_type: PhantomData,
102			max_dataframe_size,
103		}
104	}
105}
106
107impl<D> Decoder for DataFrameCodec<D> {
108	type Item = DataFrame;
109	type Error = WebSocketError;
110
111	// TODO: do not retry to read the header on each new data (keep a buffer)
112	fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
113		let (header, bytes_read) = {
114			// we'll make a fake reader and keep track of the bytes read
115			let mut reader = Cursor::new(src.as_ref());
116
117			// read header to get the size, bail if not enough
118			let header = match read_header(&mut reader) {
119				Ok(head) => head,
120				Err(WebSocketError::NoDataAvailable) => return Ok(None),
121				Err(e) => return Err(e),
122			};
123
124			(header, reader.position())
125		};
126
127		if header.len > self.max_dataframe_size as u64 {
128			return Err(WebSocketError::ProtocolError(
129				"Exceeded maximum incoming DataFrame size",
130			));
131		}
132
133		// check if we have enough bytes to continue
134		if header.len + bytes_read > src.len() as u64 {
135			return Ok(None);
136		}
137
138		// TODO: using usize is not the right thing here (can be larger)
139		let _ = src.split_to(bytes_read as usize);
140		let body = src.split_to(header.len as usize).to_vec();
141
142		// construct a dataframe
143		Ok(Some(DataFrame::read_dataframe_body(
144			header,
145			body,
146			self.is_server,
147		)?))
148	}
149}
150
151impl<D> Encoder for DataFrameCodec<D>
152where
153	D: Borrow<dyn DataFrameTrait>,
154{
155	type Item = D;
156	type Error = WebSocketError;
157
158	fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
159		let masked = !self.is_server;
160		let frame_size = item.borrow().frame_size(masked);
161		if frame_size > dst.remaining_mut() {
162			dst.reserve(frame_size);
163		}
164		item.borrow().write_to(&mut dst.writer(), masked)
165	}
166}
167
168/************
169 * Messages *
170 ************/
171
172/// A codec for asynchronously decoding and encoding websocket messages.
173///
174/// This codec decodes messages into the `OwnedMessage` struct, so using this
175/// the user will receive messages as `OwnedMessage`s. However it can encode
176/// any type of message that implements the `ws::Message` trait (that type is
177/// decided by the `M` type parameter) like `OwnedMessage` and `Message`.
178///
179/// Warning: if you don't know what your doing or want a simple websocket connection
180/// please use the `ClientBuilder` or the `Server` structs. You should only use this
181/// after a websocket handshake has already been completed on the stream you are
182/// using.
183///
184///# Example (for the high-level `websocket` crate)
185///
186///```rust,ignore
187///# extern crate tokio;
188///# extern crate websocket;
189///# extern crate hyper;
190///# use std::io::{self, Cursor};
191///use websocket::async::{MessageCodec, MsgCodecCtx};
192///# use websocket::{Message, OwnedMessage};
193///# use websocket::ws::Message as MessageTrait;
194///# use websocket::stream::ReadWritePair;
195///# use websocket::async::futures::{Future, Sink, Stream};
196///# use hyper::http::h1::Incoming;
197///# use hyper::version::HttpVersion;
198///# use hyper::header::Headers;
199///# use hyper::method::Method;
200///# use hyper::uri::RequestUri;
201///# use hyper::status::StatusCode;
202///# use tokio::codec::Decoder;
203///# fn main() {
204///
205///let mut runtime = tokio::runtime::Builder::new().build().unwrap();
206///let mut input = Vec::new();
207///Message::text("50 schmeckels").serialize(&mut input, false);
208///
209///let f = MessageCodec::default(MsgCodecCtx::Client)
210///    .framed(ReadWritePair(Cursor::new(input), Cursor::new(vec![])))
211///    .into_future()
212///    .map_err(|e| e.0)
213///    .map(|(m, _)| {
214///        assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
215///    });
216///
217///runtime.block_on(f).unwrap();
218///# }
219pub struct MessageCodec<M>
220where
221	M: MessageTrait,
222{
223	buffer: Vec<DataFrame>,
224	dataframe_codec: DataFrameCodec<DataFrame>,
225	message_type: PhantomData<fn(M)>,
226	max_message_size: u32,
227}
228
229impl MessageCodec<OwnedMessage> {
230	/// Create a new `MessageCodec` with a role of `context` (either `Client`
231	/// or `Server`) to read and write messages asynchronously.
232	///
233	/// This will create the crate's default codec which sends and receives
234	/// `OwnedMessage` structs. The message data has to be sent to an intermediate
235	/// buffer anyway so sending owned data is preferable.
236	///
237	/// If you have your own implementation of websocket messages, you can
238	/// use the `new` method to create a codec for that implementation.
239	pub fn default(context: Context) -> Self {
240		Self::new(context)
241	}
242}
243
244impl<M> MessageCodec<M>
245where
246	M: MessageTrait,
247{
248	/// Creates a codec that can encode a custom implementation of a websocket
249	/// message.
250	///
251	/// If you just want to use a normal codec without a specific implementation
252	/// of a websocket message, take a look at `MessageCodec::default`.
253	/// 
254	/// The codec automatically imposes default limits on message and data frame size.
255	/// Use `new_with_limits` to override them.
256	pub fn new(context: Context) -> MessageCodec<M> {
257		MessageCodec::new_with_limits(context, DEFAULT_MAX_DATAFRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE)
258	}
259
260	pub fn new_with_limits(context: Context, max_dataframe_size: usize, max_message_size: usize) -> MessageCodec<M> {
261		let max_message_size: u32 = max_message_size.min(u32::MAX as usize) as u32;
262		MessageCodec {
263			buffer: Vec::new(),
264			dataframe_codec: DataFrameCodec::new_with_limits(context, max_dataframe_size),
265			message_type: PhantomData,
266			max_message_size,
267		}
268	}
269}
270
271impl<M> Decoder for MessageCodec<M>
272where
273	M: MessageTrait,
274{
275	type Item = OwnedMessage;
276	type Error = WebSocketError;
277
278	fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
279		let mut current_message_length : usize = self.buffer.iter().map(|x|x.data.len()).sum();
280		while let Some(frame) = self.dataframe_codec.decode(src)? {
281			let is_first = self.buffer.is_empty();
282			let finished = frame.finished;
283
284			match frame.opcode as u8 {
285				// continuation code
286				0 if is_first => {
287					return Err(WebSocketError::ProtocolError(
288						"Unexpected continuation data frame opcode",
289					));
290				}
291				// control frame
292				8..=15 => {
293					return Ok(Some(OwnedMessage::from_dataframes(vec![frame])?));
294				}
295				// data frame
296				1..=7 if !is_first => {
297					return Err(WebSocketError::ProtocolError(
298						"Unexpected data frame opcode",
299					));
300				}
301				// its good
302				_ => {
303					current_message_length += frame.data.len() + PER_DATAFRAME_OVERHEAD;
304					self.buffer.push(frame);
305				}
306			};
307
308			if finished {
309				let buffer = mem::replace(&mut self.buffer, Vec::new());
310				return Ok(Some(OwnedMessage::from_dataframes(buffer)?));
311			} else {
312				if self.buffer.len() >= MAX_DATAFRAMES_IN_ONE_MESSAGE {
313					return Err(WebSocketError::ProtocolError(
314						"Exceeded count of data frames in one WebSocket message",
315					));
316				}
317				if current_message_length > self.max_message_size as usize {
318					return Err(WebSocketError::ProtocolError(
319						"Exceeded maximum WebSocket message size",
320					));
321				}
322			}
323		}
324
325		Ok(None)
326	}
327}
328
329impl<M> Encoder for MessageCodec<M>
330where
331	M: MessageTrait,
332{
333	type Item = M;
334	type Error = WebSocketError;
335
336	fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
337		let masked = !self.dataframe_codec.is_server;
338		let frame_size = item.message_size(masked);
339		if frame_size > dst.remaining_mut() {
340			dst.reserve(frame_size);
341		}
342		item.serialize(&mut dst.writer(), masked)
343	}
344}
345
#[cfg(test)]
mod tests {
	extern crate tokio;
	use super::*;
	use crate::message::CloseData;
	use crate::message::Message;
	use crate::stream::ReadWritePair;
	use futures::{Future, Sink, Stream};
	use std::io::Cursor;

	/// Serialize `message` masked and then unmasked, checking each time that
	/// `message_size` predicted the exact number of bytes written.
	fn assert_size_prediction<T: MessageTrait>(message: &T) {
		for &masked in &[true, false] {
			let predicted = message.message_size(masked);
			let mut serialized = Vec::new();
			message.serialize(&mut serialized, masked).unwrap();
			assert_eq!(serialized.len(), predicted);
		}
	}

	/// `OwnedMessage::message_size` matches the serialized length for every
	/// message kind and across the payload-length encoding boundaries.
	#[test]
	fn owned_message_predicts_size() {
		let messages = vec![
			OwnedMessage::Text("nilbog".to_string()),
			OwnedMessage::Binary(vec![1, 2, 3, 4]),
			OwnedMessage::Binary(vec![42; 256]),
			OwnedMessage::Binary(vec![42; 65535]),
			OwnedMessage::Binary(vec![42; 65555]),
			OwnedMessage::Ping("beep".to_string().into_bytes()),
			OwnedMessage::Pong("boop".to_string().into_bytes()),
			OwnedMessage::Close(None),
			OwnedMessage::Close(Some(CloseData {
				status_code: 64,
				reason: "because".to_string(),
			})),
		];

		for message in &messages {
			assert_size_prediction(message);
		}
	}

	/// Same check as above for the borrowed `Message` type.
	#[test]
	fn cow_message_predicts_size() {
		let messages = vec![
			Message::binary(vec![1, 2, 3, 4]),
			Message::binary(vec![42; 256]),
			Message::binary(vec![42; 65535]),
			Message::binary(vec![42; 65555]),
			Message::text("nilbog".to_string()),
			Message::ping("beep".to_string().into_bytes()),
			Message::pong("boop".to_string().into_bytes()),
			Message::close(),
			Message::close_because(64, "because"),
		];

		for message in &messages {
			assert_size_prediction(message);
		}
	}

	/// A client codec reads an unmasked message, then sends a message that a
	/// server codec can read back.
	#[test]
	fn message_codec_client_send_receive() {
		let mut incoming = Vec::new();
		Message::text("50 schmeckels")
			.serialize(&mut incoming, false)
			.unwrap();

		let io = ReadWritePair(Cursor::new(incoming), Cursor::new(vec![]));
		let exchange = MessageCodec::new(Context::Client)
			.framed(io)
			.into_future()
			.map_err(|(err, _)| err)
			.map(|(received, framed)| {
				assert_eq!(
					received,
					Some(OwnedMessage::Text("50 schmeckels".to_string()))
				);
				framed
			})
			.and_then(|framed| framed.send(Message::text("ethan bradberry")))
			.and_then(|framed| {
				// Swap the halves: what the client wrote becomes the
				// server's input, rewound to the start.
				let ReadWritePair(consumed, mut written) = framed.into_parts().io;
				written.set_position(0);
				println!("buffer: {:?}", written);
				MessageCodec::default(Context::Server)
					.framed(ReadWritePair(written, consumed))
					.into_future()
					.map_err(|(err, _)| err)
					.map(|(echoed, _)| {
						assert_eq!(echoed, Some(Message::text("ethan bradberry").into()))
					})
			});

		let mut runtime = tokio::runtime::Builder::new().build().unwrap();
		runtime.block_on(exchange).unwrap();
	}

	/// A server codec reads a masked message, then writes an unmasked reply
	/// identical to what direct serialization produces.
	#[test]
	fn message_codec_server_send_receive() {
		let mut runtime = tokio::runtime::Builder::new().build().unwrap();

		let mut incoming = Vec::new();
		Message::text("50 schmeckels")
			.serialize(&mut incoming, true)
			.unwrap();

		let io = ReadWritePair(Cursor::new(incoming), Cursor::new(vec![]));
		let exchange = MessageCodec::new(Context::Server)
			.framed(io)
			.into_future()
			.map_err(|(err, _)| err)
			.map(|(received, framed)| {
				assert_eq!(
					received,
					Some(OwnedMessage::Text("50 schmeckels".to_string()))
				);
				framed
			})
			.and_then(|framed| framed.send(Message::text("ethan bradberry")))
			.map(|framed| {
				let mut expected = vec![];
				Message::text("ethan bradberry")
					.serialize(&mut expected, false)
					.unwrap();
				assert_eq!(expected, framed.into_parts().io.1.into_inner());
			});

		runtime.block_on(exchange).unwrap();
	}
}
476}