x11rb_protocol/
packet_reader.rs

1//! Collects X11 data into "packets" to be parsed by a display.
2
3use core::fmt;
4use core::mem::replace;
5
6use alloc::{vec, vec::Vec};
7
/// Size in bytes of the smallest possible X11 packet.
///
/// Every packet starts with this many bytes; longer packets carry additional data after them.
const MINIMAL_PACKET_LENGTH: usize = 32;
10
/// Incrementally assembles X11 packets from chunks of incoming bytes.
///
/// The caller writes received data into [`PacketReader::buffer`] and then reports how many
/// bytes were written through [`PacketReader::advance`], which hands out a packet as soon as
/// it is complete.
pub struct PacketReader {
    /// Storage for the packet that is currently being received.
    ///
    /// Its length is the packet length as far as it is known so far.
    pending_packet: Vec<u8>,

    /// Number of bytes at the start of `pending_packet` that have already been filled in.
    already_read: usize,
}
19
20impl fmt::Debug for PacketReader {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        f.debug_tuple("PacketReader")
23            .field(&format_args!(
24                "{}/{}",
25                self.already_read,
26                self.pending_packet.len()
27            ))
28            .finish()
29    }
30}
31
32impl Default for PacketReader {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl PacketReader {
39    /// Create a new, empty `PacketReader`.
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// # use x11rb_protocol::packet_reader::PacketReader;
45    /// let reader = PacketReader::new();
46    /// ```
47    pub fn new() -> Self {
48        Self {
49            pending_packet: vec![0; MINIMAL_PACKET_LENGTH],
50            already_read: 0,
51        }
52    }
53
54    /// Get the buffer that the reader should fill with data.
55    ///
56    /// # Example
57    ///
58    /// ```rust
59    /// # use x11rb_protocol::packet_reader::PacketReader;
60    /// # use x11rb_protocol::protocol::xproto::{GetInputFocusReply, InputFocus, Window};
61    /// let mut reader = PacketReader::new();
62    /// let buffer: [u8; 32] = read_in_buffer();
63    ///
64    /// reader.buffer().copy_from_slice(&buffer);
65    ///
66    /// # fn read_in_buffer() -> [u8; 32] { [0; 32] }
67    /// ```
68    pub fn buffer(&mut self) -> &mut [u8] {
69        &mut self.pending_packet[self.already_read..]
70    }
71
72    /// The remaining capacity that needs to be filled.
73    pub fn remaining_capacity(&self) -> usize {
74        self.pending_packet.len() - self.already_read
75    }
76
77    /// Advance this buffer by the given amount.
78    ///
79    /// This will return the packet that was read, if enough bytes were read in order
80    /// to form a complete packet.
81    pub fn advance(&mut self, amount: usize) -> Option<Vec<u8>> {
82        self.already_read += amount;
83        debug_assert!(self.already_read <= self.pending_packet.len());
84
85        if self.already_read == MINIMAL_PACKET_LENGTH {
86            // we've read in the minimal packet, compute the amount of data we need to read
87            // to form a complete packet
88            let extra_length = extra_length(&self.pending_packet);
89
90            // tell if we need to read more
91            if extra_length > 0 {
92                let total_length = MINIMAL_PACKET_LENGTH + extra_length;
93                self.pending_packet.resize(total_length, 0);
94                return None;
95            }
96        } else if self.already_read != self.pending_packet.len() {
97            // we haven't read the full packet yet, return
98            return None;
99        }
100
101        // we've read in the full packet, return it
102        self.already_read = 0;
103        Some(replace(
104            &mut self.pending_packet,
105            vec![0; MINIMAL_PACKET_LENGTH],
106        ))
107    }
108}
109
110/// Compute the length of the data we need to read, beyond the `MINIMAL_PACKET_LENGTH`.
111fn extra_length(buffer: &[u8]) -> usize {
112    use crate::protocol::xproto::GE_GENERIC_EVENT;
113    const REPLY: u8 = 1;
114
115    let response_type = buffer[0];
116
117    if response_type == REPLY || response_type & 0x7f == GE_GENERIC_EVENT {
118        let length_field = buffer[4..8].try_into().unwrap();
119        let length_field = u32::from_ne_bytes(length_field) as usize;
120        4 * length_field
121    } else {
122        // Fixed size packet: error or event that is not GE_GENERIC_EVENT
123        0
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::PacketReader;
    use alloc::{vec, vec::Vec};

    /// Feed the concatenation of `packets` through a fresh `PacketReader` and check that it
    /// hands back exactly the original packets, in order, consuming all of the input.
    fn test_packets(packets: Vec<Vec<u8>>) {
        // Combine all packet data into one big chunk and test that the packet reader splits things
        let all_data = packets.iter().flatten().copied().collect::<Vec<u8>>();
        let mut remaining = all_data.as_slice();

        let mut reader = PacketReader::default();
        for (i, packet) in packets.into_iter().enumerate() {
            std::println!("Checking packet {i}");
            loop {
                let buffer = reader.buffer();
                let amount = std::cmp::min(buffer.len(), remaining.len());
                // Without this check the loop would spin forever once the input runs dry.
                assert!(amount > 0, "ran out of data while reading packet {i}");

                // Only fill as much of the buffer as there is data for; `copy_from_slice`
                // panics when the two slices differ in length.
                let (chunk, rest) = remaining.split_at(amount);
                buffer[..amount].copy_from_slice(chunk);
                remaining = rest;

                if let Some(read_packet) = reader.advance(amount) {
                    assert_eq!(read_packet, packet);
                    break;
                }
            }
        }

        assert!(
            remaining.is_empty(),
            "{} bytes were left over after reading all packets",
            remaining.len()
        );
    }

    /// Build a zero-filled reply packet that is `len` bytes long in total.
    ///
    /// `len` must be at least 32 and a multiple of 4.
    fn make_reply_with_length(len: usize) -> Vec<u8> {
        assert!(len >= 32, "a packet is at least 32 bytes long");
        assert_eq!(0, len % 4);

        let mut packet = vec![0; len];
        let extra_words = (len - 32) / 4;

        // write the number of extra 4-byte units to bytes 4..8 in the packet
        let len_bytes = (extra_words as u32).to_ne_bytes();
        packet[4..8].copy_from_slice(&len_bytes);
        // mark the packet as a reply
        packet[0] = 1;

        packet
    }

    /// Total length of the `i`-th variable-size test packet.
    fn varying_length(i: i32) -> usize {
        // for maximum variation, change the packet size along a parabola
        // defined by -1/25 (x - 50)^2 + 100
        let variation = ((i - 50) * (i - 50)) as f32;
        let variation = -1.0 / 25.0 * variation + 100.0;
        let variation = variation as usize;
        // round down to a multiple of 4
        let variation = variation / 4 * 4;

        1200 + variation
    }

    #[test]
    fn fixed_size_packet() {
        // packet with a fixed size
        let packet = vec![0; 32];
        test_packets(vec![packet]);
    }

    #[test]
    fn variable_size_packet() {
        // packet with a variable size
        let packet = make_reply_with_length(1200);
        test_packets(vec![packet]);
    }

    #[test]
    fn test_many_fixed_size_packets() {
        let packets = (0..100).map(|_| vec![0; 32]).collect();
        test_packets(packets);
    }

    #[test]
    fn test_many_variable_size_packets() {
        let packets = (0..100)
            .map(|i| make_reply_with_length(varying_length(i)))
            .collect();
        test_packets(packets);
    }

    #[test]
    fn test_many_size_packets_mixed() {
        let packets = (0..100)
            .map(|i| {
                // on odds, do a varsize packet; on evens, a reply without any extra data
                let len = if i % 2 == 1 { varying_length(i) } else { 32 };
                make_reply_with_length(len)
            })
            .collect();
        test_packets(packets);
    }

    #[test]
    fn test_debug_fixed_size_packet() {
        // The debug output includes the length of the packet and how much of it was
        // already read
        let mut reader = PacketReader::new();
        assert_eq!(std::format!("{:?}", reader), "PacketReader(0/32)");

        let _ = reader.advance(15);
        assert_eq!(std::format!("{:?}", reader), "PacketReader(15/32)");

        let _ = reader.advance(15);
        assert_eq!(std::format!("{:?}", reader), "PacketReader(30/32)");

        // The all-zero packet is a fixed-size one, so it is complete now and the reader resets
        let _ = reader.advance(2);
        assert_eq!(std::format!("{:?}", reader), "PacketReader(0/32)");
    }

    #[test]
    fn test_debug_variable_size_packet() {
        let packet = make_reply_with_length(1200);
        let mut reader = PacketReader::new();

        let first_len = 32;
        let second_len = 3;

        reader.buffer()[..first_len].copy_from_slice(&packet[..first_len]);
        let _ = reader.advance(first_len);

        // continue with the bytes that follow the part that was already fed in
        reader.buffer()[..second_len]
            .copy_from_slice(&packet[first_len..first_len + second_len]);
        let _ = reader.advance(second_len);

        assert_eq!(std::format!("{:?}", reader), "PacketReader(35/1200)");
    }
}