1use crate::result::{WebSocketError, WebSocketResult};
4use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
5use std::io::{Read, Write};
6
7bitflags! {
8 pub struct DataFrameFlags: u8 {
10 const FIN = 0x80;
12 const RSV1 = 0x40;
14 const RSV2 = 0x20;
16 const RSV3 = 0x10;
18 }
19}
20
21#[derive(Debug, Clone, Copy, PartialEq)]
23pub struct DataFrameHeader {
24 pub flags: DataFrameFlags,
26 pub opcode: u8,
28 pub mask: Option<[u8; 4]>,
30 pub len: u64,
32}
33
34pub fn write_header(writer: &mut dyn Write, header: DataFrameHeader) -> WebSocketResult<()> {
36 if header.opcode > 0xF {
37 return Err(WebSocketError::DataFrameError("Invalid data frame opcode"));
38 }
39 if header.opcode >= 8 && header.len >= 126 {
40 return Err(WebSocketError::DataFrameError(
41 "Control frame length too long",
42 ));
43 }
44
45 writer.write_u8((header.flags.bits) | header.opcode)?;
47
48 writer.write_u8(
49 if header.mask.is_some() { 0x80 } else { 0x00 } |
51 if header.len <= 125 { header.len as u8 }
53 else if header.len <= 65535 { 126 }
54 else { 127 },
55 )?;
56
57 if header.len >= 126 && header.len <= 65535 {
59 writer.write_u16::<BigEndian>(header.len as u16)?;
60 } else if header.len > 65535 {
61 writer.write_u64::<BigEndian>(header.len)?;
62 }
63
64 if let Some(mask) = header.mask {
66 writer.write_all(&mask)?
67 }
68
69 Ok(())
70}
71
72pub fn read_header<R>(reader: &mut R) -> WebSocketResult<DataFrameHeader>
74where
75 R: Read,
76{
77 let byte0 = reader.read_u8()?;
78 let byte1 = reader.read_u8()?;
79
80 let flags = DataFrameFlags::from_bits_truncate(byte0);
81 let opcode = byte0 & 0x0F;
82
83 let len = match byte1 & 0x7F {
84 0..=125 => u64::from(byte1 & 0x7F),
85 126 => {
86 let len = u64::from(reader.read_u16::<BigEndian>()?);
87 if len <= 125 {
88 return Err(WebSocketError::DataFrameError("Invalid data frame length"));
89 }
90 len
91 }
92 127 => {
93 let len = reader.read_u64::<BigEndian>()?;
94 if len <= 65535 {
95 return Err(WebSocketError::DataFrameError("Invalid data frame length"));
96 }
97 len
98 }
99 _ => unreachable!(),
100 };
101
102 if opcode >= 8 {
103 if len >= 126 {
104 return Err(WebSocketError::DataFrameError(
105 "Control frame length too long",
106 ));
107 }
108 if !flags.contains(DataFrameFlags::FIN) {
109 return Err(WebSocketError::ProtocolError(
110 "Illegal fragmented control frame",
111 ));
112 }
113 }
114
115 let mask = if byte1 & 0x80 == 0x80 {
116 Some([
117 reader.read_u8()?,
118 reader.read_u8()?,
119 reader.read_u8()?,
120 reader.read_u8()?,
121 ])
122 } else {
123 None
124 };
125
126 Ok(DataFrameHeader {
127 flags,
128 opcode,
129 mask,
130 len,
131 })
132}
133
// Unit tests run on any toolchain; only the benchmarks need the unstable
// `test` crate and are therefore gated behind the `nightly` feature.
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "nightly")]
    use test;

    /// Unmasked text frame with a 7-bit length.
    #[test]
    fn test_read_header_simple() {
        let header = [0x81, 0x2B];
        let obtained = read_header(&mut &header[..]).unwrap();
        let expected = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 1,
            mask: None,
            len: 43,
        };
        assert_eq!(obtained, expected);
    }

    /// Unmasked text frame with a 7-bit length.
    #[test]
    fn test_write_header_simple() {
        let header = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 1,
            mask: None,
            len: 43,
        };
        let expected = [0x81, 0x2B];
        let mut obtained = Vec::with_capacity(2);
        write_header(&mut obtained, header).unwrap();

        assert_eq!(&obtained[..], &expected[..]);
    }

    /// Masked binary frame with RSV1 set and a 16-bit extended length.
    #[test]
    fn test_read_header_complex() {
        let header = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
        let obtained = read_header(&mut &header[..]).unwrap();
        let expected = DataFrameHeader {
            flags: DataFrameFlags::RSV1,
            opcode: 2,
            mask: Some([2, 4, 8, 16]),
            len: 512,
        };
        assert_eq!(obtained, expected);
    }

    /// Masked binary frame with RSV1 set and a 16-bit extended length.
    #[test]
    fn test_write_header_complex() {
        let header = DataFrameHeader {
            flags: DataFrameFlags::RSV1,
            opcode: 2,
            mask: Some([2, 4, 8, 16]),
            len: 512,
        };
        let expected = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
        let mut obtained = Vec::with_capacity(8);
        write_header(&mut obtained, header).unwrap();

        assert_eq!(&obtained[..], &expected[..]);
    }

    /// A 16-bit extended length must not be used for a value that fits in 7 bits.
    #[test]
    fn test_read_header_rejects_non_minimal_length() {
        let header = [0x81u8, 0x7E, 0x00, 0x10];
        assert!(read_header(&mut &header[..]).is_err());
    }

    /// Control frames (opcode >= 8) may not announce 126 or more payload bytes.
    #[test]
    fn test_write_header_rejects_long_control_frame() {
        let header = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 9,
            mask: None,
            len: 126,
        };
        let mut sink = Vec::new();
        assert!(write_header(&mut sink, header).is_err());
    }

    #[cfg(feature = "nightly")]
    #[bench]
    fn bench_read_header(b: &mut test::Bencher) {
        let header = vec![0x42u8, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
        // Returning the parsed header hands it to the bencher, which keeps the
        // optimizer from discarding the work.
        b.iter(|| read_header(&mut &header[..]).unwrap());
    }

    #[cfg(feature = "nightly")]
    #[bench]
    fn bench_write_header(b: &mut test::Bencher) {
        let header = DataFrameHeader {
            flags: DataFrameFlags::RSV1,
            opcode: 2,
            mask: Some([2, 4, 8, 16]),
            len: 512,
        };
        let mut writer = Vec::with_capacity(8);
        b.iter(|| {
            // Reuse the same 8 bytes each iteration; without this the vector
            // grows without bound and the benchmark measures reallocation.
            writer.clear();
            write_header(&mut writer, header).unwrap();
        });
    }
}