use crate::result::{WebSocketError, WebSocketResult};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use std::io::{Read, Write};
bitflags! {
/// Flag bits stored in the high nibble (0xF0) of the first byte of a data frame header.
///
/// `read_header` builds this with `from_bits_truncate` on that byte, and `write_header`
/// ORs the bits back together with the 4-bit opcode.
pub struct DataFrameFlags: u8 {
/// FIN: set when this frame is the final fragment of a message. `read_header` rejects
/// control frames (opcode >= 8) that do not have this bit set.
const FIN = 0x80;
/// RSV1: reserved bit; this module passes it through without interpreting it.
const RSV1 = 0x40;
/// RSV2: reserved bit; this module passes it through without interpreting it.
const RSV2 = 0x20;
/// RSV3: reserved bit; this module passes it through without interpreting it.
const RSV3 = 0x10;
}
}
/// The decoded header of a WebSocket data frame (everything that precedes the payload).
///
/// Produced by `read_header` and serialized by `write_header`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DataFrameHeader {
/// FIN and RSV1-RSV3 bits taken from the first header byte.
pub flags: DataFrameFlags,
/// Frame opcode (low four bits of the first header byte). Values >= 8 are treated as
/// control frames by both `read_header` and `write_header`.
pub opcode: u8,
/// Masking key; present exactly when the header's mask bit is set.
pub mask: Option<[u8; 4]>,
/// Payload length, carried in the 7-bit, 16-bit or 64-bit length field.
pub len: u64,
}
/// Serializes a data frame header into `writer`.
///
/// The layout follows RFC 6455 section 5.2:
/// - first byte: flag bits | opcode;
/// - second byte: mask bit | 7-bit length;
/// - optional extended length (16- or 64-bit big endian);
/// - optional 4-byte masking key.
///
/// The shortest length encoding that fits `header.len` is always chosen.
///
/// # Errors
///
/// Returns `WebSocketError::DataFrameError` in these cases:
/// - the opcode does not fit in four bits;
/// - a control frame (opcode >= 8) has a length of 126 or more;
/// - the length has its most significant bit set, which the 64-bit length field may not carry.
///
/// Errors from the underlying writer are propagated.
pub fn write_header(writer: &mut dyn Write, header: DataFrameHeader) -> WebSocketResult<()> {
    if header.opcode > 0xF {
        return Err(WebSocketError::DataFrameError("Invalid data frame opcode"));
    }
    if header.opcode >= 8 && header.len >= 126 {
        return Err(WebSocketError::DataFrameError(
            "Control frame length too long",
        ));
    }
    // RFC 6455 §5.2: the most significant bit of the 64-bit payload length MUST be 0,
    // so such a length cannot be put on the wire.
    if (header.len & (1u64 << 63)) != 0 {
        return Err(WebSocketError::DataFrameError("Invalid data frame length"));
    }

    let mask_bit: u8 = if header.mask.is_some() { 0x80 } else { 0x00 };
    writer.write_u8(header.flags.bits() | header.opcode)?;

    // The second byte carries either the length itself (<= 125) or a marker (126/127)
    // announcing a 16-bit or 64-bit extended length.
    match header.len {
        0..=125 => writer.write_u8(mask_bit | header.len as u8)?,
        126..=65535 => {
            writer.write_u8(mask_bit | 126)?;
            writer.write_u16::<BigEndian>(header.len as u16)?;
        }
        _ => {
            writer.write_u8(mask_bit | 127)?;
            writer.write_u64::<BigEndian>(header.len)?;
        }
    }

    if let Some(mask) = header.mask {
        writer.write_all(&mask)?;
    }
    Ok(())
}
/// Parses a data frame header from `reader`.
///
/// Reads the two fixed bytes, any extended length field, and the masking key if the
/// mask bit is set. The payload itself is not consumed.
///
/// # Errors
///
/// - `WebSocketError::DataFrameError("Invalid data frame length")`: an extended length is
///   not minimally encoded, or the 64-bit length has its most significant bit set
///   (forbidden by RFC 6455 §5.2).
/// - `WebSocketError::DataFrameError("Control frame length too long")`: a control frame
///   (opcode >= 8) declares 126 or more bytes.
/// - `WebSocketError::ProtocolError("Illegal fragmented control frame")`: a control frame
///   lacks the FIN bit.
/// - I/O errors from `reader`, including those caused by truncated input, are propagated.
pub fn read_header<R>(reader: &mut R) -> WebSocketResult<DataFrameHeader>
where
    R: Read,
{
    let byte0 = reader.read_u8()?;
    let byte1 = reader.read_u8()?;
    let flags = DataFrameFlags::from_bits_truncate(byte0);
    let opcode = byte0 & 0x0F;

    let len = match byte1 & 0x7F {
        126 => {
            let len = u64::from(reader.read_u16::<BigEndian>()?);
            // A 16-bit length must not encode a value that fits in 7 bits.
            if len <= 125 {
                return Err(WebSocketError::DataFrameError("Invalid data frame length"));
            }
            len
        }
        127 => {
            let len = reader.read_u64::<BigEndian>()?;
            // A 64-bit length must not fit in 16 bits, and its MSB must be 0.
            if len <= 65535 || (len & (1u64 << 63)) != 0 {
                return Err(WebSocketError::DataFrameError("Invalid data frame length"));
            }
            len
        }
        // Masked with 0x7F and 126/127 handled above, so this is 0..=125.
        short => u64::from(short),
    };

    if opcode >= 8 {
        if len >= 126 {
            return Err(WebSocketError::DataFrameError(
                "Control frame length too long",
            ));
        }
        if !flags.contains(DataFrameFlags::FIN) {
            return Err(WebSocketError::ProtocolError(
                "Illegal fragmented control frame",
            ));
        }
    }

    let mask = if byte1 & 0x80 != 0 {
        let mut key = [0u8; 4];
        reader.read_exact(&mut key)?;
        Some(key)
    } else {
        None
    };

    Ok(DataFrameHeader {
        flags,
        opcode,
        mask,
        len,
    })
}
#[cfg(all(feature = "nightly", test))]
mod tests {
    use super::*;
    use test;

    #[test]
    fn test_read_header_simple() {
        let header = [0x81, 0x2B];
        let obtained = read_header(&mut &header[..]).unwrap();
        let expected = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 1,
            mask: None,
            len: 43,
        };
        assert_eq!(obtained, expected);
    }

    #[test]
    fn test_write_header_simple() {
        let header = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 1,
            mask: None,
            len: 43,
        };
        let expected = [0x81, 0x2B];
        let mut obtained = Vec::with_capacity(2);
        write_header(&mut obtained, header).unwrap();
        assert_eq!(&obtained[..], &expected[..]);
    }

    #[test]
    fn test_read_header_complex() {
        let header = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
        let obtained = read_header(&mut &header[..]).unwrap();
        let expected = DataFrameHeader {
            flags: DataFrameFlags::RSV1,
            opcode: 2,
            mask: Some([2, 4, 8, 16]),
            len: 512,
        };
        assert_eq!(obtained, expected);
    }

    #[test]
    fn test_write_header_complex() {
        let header = DataFrameHeader {
            flags: DataFrameFlags::RSV1,
            opcode: 2,
            mask: Some([2, 4, 8, 16]),
            len: 512,
        };
        let expected = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
        let mut obtained = Vec::with_capacity(8);
        write_header(&mut obtained, header).unwrap();
        assert_eq!(&obtained[..], &expected[..]);
    }

    #[test]
    fn test_header_roundtrip_64bit_length() {
        let header = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 2,
            mask: None,
            len: 65536,
        };
        let expected = [0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00];
        let mut obtained = Vec::<u8>::with_capacity(10);
        write_header(&mut obtained, header).unwrap();
        assert_eq!(&obtained[..], &expected[..]);
        assert_eq!(read_header(&mut &obtained[..]).unwrap(), header);
    }

    #[test]
    fn test_read_header_rejects_non_minimal_length() {
        // 125 encoded in the 16-bit extended field.
        let short_as_16 = [0x82, 0x7E, 0x00, 0x7D];
        assert!(read_header(&mut &short_as_16[..]).is_err());
        // 65535 encoded in the 64-bit extended field.
        let medium_as_64 = [0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF];
        assert!(read_header(&mut &medium_as_64[..]).is_err());
    }

    #[test]
    fn test_read_header_rejects_bad_control_frames() {
        // Ping declaring a 126-byte payload.
        let too_long = [0x89, 0x7E, 0x00, 0x7E];
        assert!(read_header(&mut &too_long[..]).is_err());
        // Ping without the FIN bit.
        let fragmented = [0x09, 0x00];
        assert!(read_header(&mut &fragmented[..]).is_err());
    }

    #[test]
    fn test_read_header_truncated() {
        let truncated = [0x81];
        assert!(read_header(&mut &truncated[..]).is_err());
        // Mask bit set, but the 4-byte masking key is cut short.
        let missing_mask = [0x81, 0x80, 0x01];
        assert!(read_header(&mut &missing_mask[..]).is_err());
    }

    #[test]
    fn test_write_header_rejects_invalid() {
        let bad_opcode = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 0x10,
            mask: None,
            len: 0,
        };
        assert!(write_header(&mut Vec::<u8>::new(), bad_opcode).is_err());
        let long_control = DataFrameHeader {
            flags: DataFrameFlags::FIN,
            opcode: 9,
            mask: None,
            len: 126,
        };
        assert!(write_header(&mut Vec::<u8>::new(), long_control).is_err());
    }

    #[bench]
    fn bench_read_header(b: &mut test::Bencher) {
        let header = vec![0x42u8, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10];
        // Returning the parsed header lets the harness black_box it so the parse isn't elided.
        b.iter(|| read_header(&mut &header[..]).unwrap());
    }

    #[bench]
    fn bench_write_header(b: &mut test::Bencher) {
        let header = DataFrameHeader {
            flags: DataFrameFlags::RSV1,
            opcode: 2,
            mask: Some([2, 4, 8, 16]),
            len: 512,
        };
        let mut writer = Vec::<u8>::with_capacity(8);
        b.iter(|| {
            // Reset the buffer each iteration. Otherwise the Vec grows without bound and
            // the benchmark measures reallocation instead of header encoding.
            writer.clear();
            write_header(&mut writer, header).unwrap();
            writer.len()
        });
    }
}