websocket_base/ws/util/
header.rs

1//! Utility functions for reading and writing data frame headers.
2
3use crate::result::{WebSocketError, WebSocketResult};
4use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
5use std::io::{Read, Write};
6
7bitflags! {
8	/// Flags relevant to a WebSocket data frame.
9	pub struct DataFrameFlags: u8 {
10		/// Marks this dataframe as the last dataframe
11		const FIN = 0x80;
12		/// First reserved bit
13		const RSV1 = 0x40;
14		/// Second reserved bit
15		const RSV2 = 0x20;
16		/// Third reserved bit
17		const RSV3 = 0x10;
18	}
19}
20
21/// Represents a data frame header.
22#[derive(Debug, Clone, Copy, PartialEq)]
23pub struct DataFrameHeader {
24	/// The bit flags for the first byte of the header.
25	pub flags: DataFrameFlags,
26	/// The opcode of the header - must be <= 16.
27	pub opcode: u8,
28	/// The masking key, if any.
29	pub mask: Option<[u8; 4]>,
30	/// The length of the payload.
31	pub len: u64,
32}
33
34/// Writes a data frame header.
35pub fn write_header(writer: &mut dyn Write, header: DataFrameHeader) -> WebSocketResult<()> {
36	if header.opcode > 0xF {
37		return Err(WebSocketError::DataFrameError("Invalid data frame opcode"));
38	}
39	if header.opcode >= 8 && header.len >= 126 {
40		return Err(WebSocketError::DataFrameError(
41			"Control frame length too long",
42		));
43	}
44
45	// Write 'FIN', 'RSV1', 'RSV2', 'RSV3' and 'opcode'
46	writer.write_u8((header.flags.bits) | header.opcode)?;
47
48	writer.write_u8(
49		// Write the 'MASK'
50		if header.mask.is_some() { 0x80 } else { 0x00 } |
51		// Write the 'Payload len'
52		if header.len <= 125 { header.len as u8 }
53		else if header.len <= 65535 { 126 }
54		else { 127 },
55	)?;
56
57	// Write 'Extended payload length'
58	if header.len >= 126 && header.len <= 65535 {
59		writer.write_u16::<BigEndian>(header.len as u16)?;
60	} else if header.len > 65535 {
61		writer.write_u64::<BigEndian>(header.len)?;
62	}
63
64	// Write 'Masking-key'
65	if let Some(mask) = header.mask {
66		writer.write_all(&mask)?
67	}
68
69	Ok(())
70}
71
72/// Reads a data frame header.
73pub fn read_header<R>(reader: &mut R) -> WebSocketResult<DataFrameHeader>
74where
75	R: Read,
76{
77	let byte0 = reader.read_u8()?;
78	let byte1 = reader.read_u8()?;
79
80	let flags = DataFrameFlags::from_bits_truncate(byte0);
81	let opcode = byte0 & 0x0F;
82
83	let len = match byte1 & 0x7F {
84		0..=125 => u64::from(byte1 & 0x7F),
85		126 => {
86			let len = u64::from(reader.read_u16::<BigEndian>()?);
87			if len <= 125 {
88				return Err(WebSocketError::DataFrameError("Invalid data frame length"));
89			}
90			len
91		}
92		127 => {
93			let len = reader.read_u64::<BigEndian>()?;
94			if len <= 65535 {
95				return Err(WebSocketError::DataFrameError("Invalid data frame length"));
96			}
97			len
98		}
99		_ => unreachable!(),
100	};
101
102	if opcode >= 8 {
103		if len >= 126 {
104			return Err(WebSocketError::DataFrameError(
105				"Control frame length too long",
106			));
107		}
108		if !flags.contains(DataFrameFlags::FIN) {
109			return Err(WebSocketError::ProtocolError(
110				"Illegal fragmented control frame",
111			));
112		}
113	}
114
115	let mask = if byte1 & 0x80 == 0x80 {
116		Some([
117			reader.read_u8()?,
118			reader.read_u8()?,
119			reader.read_u8()?,
120			reader.read_u8()?,
121		])
122	} else {
123		None
124	};
125
126	Ok(DataFrameHeader {
127		flags,
128		opcode,
129		mask,
130		len,
131	})
132}
133
/// Unit tests; these run on every `cargo test`, not only with the nightly feature.
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_read_header_simple() {
		let header = [0x81, 0x2B];
		let obtained = read_header(&mut &header[..]).unwrap();
		let expected = DataFrameHeader {
			flags: DataFrameFlags::FIN,
			opcode: 1,
			mask: None,
			len: 43,
		};
		assert_eq!(obtained, expected);
	}

	#[test]
	fn test_write_header_simple() {
		let header = DataFrameHeader {
			flags: DataFrameFlags::FIN,
			opcode: 1,
			mask: None,
			len: 43,
		};
		let expected = [0x81, 0x2B];
		let mut obtained: Vec<u8> = Vec::with_capacity(2);
		write_header(&mut obtained, header).unwrap();

		assert_eq!(&obtained[..], &expected[..]);
	}

	#[test]
	fn test_read_header_complex() {
		let header = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
		let obtained = read_header(&mut &header[..]).unwrap();
		let expected = DataFrameHeader {
			flags: DataFrameFlags::RSV1,
			opcode: 2,
			mask: Some([2, 4, 8, 16]),
			len: 512,
		};
		assert_eq!(obtained, expected);
	}

	#[test]
	fn test_write_header_complex() {
		let header = DataFrameHeader {
			flags: DataFrameFlags::RSV1,
			opcode: 2,
			mask: Some([2, 4, 8, 16]),
			len: 512,
		};
		let expected = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
		let mut obtained: Vec<u8> = Vec::with_capacity(8);
		write_header(&mut obtained, header).unwrap();

		assert_eq!(&obtained[..], &expected[..]);
	}

	#[test]
	fn test_header_roundtrip_16bit_length() {
		// 126 is the smallest length that needs the 16-bit extended form.
		let header = DataFrameHeader {
			flags: DataFrameFlags::FIN,
			opcode: 2,
			mask: None,
			len: 126,
		};
		let mut encoded: Vec<u8> = Vec::new();
		write_header(&mut encoded, header).unwrap();
		assert_eq!(&encoded[..], &[0x82, 0x7E, 0x00, 0x7E][..]);
		assert_eq!(read_header(&mut &encoded[..]).unwrap(), header);
	}

	#[test]
	fn test_header_roundtrip_64bit_length() {
		// 65536 is the smallest length that needs the 64-bit extended form.
		let header = DataFrameHeader {
			flags: DataFrameFlags::FIN,
			opcode: 2,
			mask: None,
			len: 65536,
		};
		let mut encoded: Vec<u8> = Vec::new();
		write_header(&mut encoded, header).unwrap();
		assert_eq!(
			&encoded[..],
			&[0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00][..]
		);
		assert_eq!(read_header(&mut &encoded[..]).unwrap(), header);
	}

	#[test]
	fn test_read_header_rejects_non_minimal_length() {
		// Length 125 encoded with the 16-bit extended form.
		let header = [0x82, 0x7E, 0x00, 0x7D];
		assert!(read_header(&mut &header[..]).is_err());
	}

	#[test]
	fn test_read_header_rejects_fragmented_control_frame() {
		// Ping (opcode 9) without FIN.
		let header = [0x09, 0x00];
		assert!(read_header(&mut &header[..]).is_err());
	}

	#[test]
	fn test_write_header_rejects_invalid_headers() {
		let bad_opcode = DataFrameHeader {
			flags: DataFrameFlags::FIN,
			opcode: 0x10,
			mask: None,
			len: 0,
		};
		assert!(write_header(&mut Vec::<u8>::new(), bad_opcode).is_err());

		let long_control = DataFrameHeader {
			flags: DataFrameFlags::FIN,
			opcode: 0x9,
			mask: None,
			len: 126,
		};
		assert!(write_header(&mut Vec::<u8>::new(), long_control).is_err());
	}
}

/// Benchmarks; these need the unstable `test` crate and so only build with the nightly feature.
#[cfg(all(feature = "nightly", test))]
mod benches {
	use super::*;
	use test;

	#[bench]
	fn bench_read_header(b: &mut test::Bencher) {
		let header = vec![0x42u8, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
		b.iter(|| {
			read_header(&mut &header[..]).unwrap();
		});
	}

	#[bench]
	fn bench_write_header(b: &mut test::Bencher) {
		let header = DataFrameHeader {
			flags: DataFrameFlags::RSV1,
			opcode: 2,
			mask: Some([2, 4, 8, 16]),
			len: 512,
		};
		let mut writer: Vec<u8> = Vec::with_capacity(8);
		b.iter(|| {
			// Reuse the same 8-byte buffer each iteration; otherwise the Vec
			// grows without bound and the benchmark measures reallocation.
			writer.clear();
			write_header(&mut writer, header).unwrap();
		});
	}
}