websocket_sans_io/
frame_decoding.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
use crate::{PayloadLength, Opcode, FrameInfo, masking};

use nonmax::NonMaxU8;

/// Error type of [`WebsocketFrameDecoder::add_data`].
///
/// When the `large_frames` crate feature is on (the default), any byte sequence can be decoded,
/// so no error is possible.
#[cfg(feature="large_frames")]
pub type FrameDecoderError = core::convert::Infallible;

/// Error type of [`WebsocketFrameDecoder::add_data`].
///
/// When the `large_frames` crate feature is off, WebSocket frame headers that use the
/// 64-bit extended payload length encoding produce this error.
#[cfg(not(feature="large_frames"))]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum FrameDecoderError {
    /// The frame header's 7-bit length field is `0x7F` (64-bit extended payload length),
    /// which is not supported without the `large_frames` crate feature.
    ExceededFrameSize,
}

/// Fixed-capacity buffer of `C` bytes that is filled incrementally.
///
/// The decoder uses it to assemble multi-byte header fields that may arrive split
/// across several input chunks.
///
/// The fill level is stored as a `u8`, so `C` must not exceed 255.
#[derive(Clone, Copy, Debug)]
struct SmallBufWithLen<const C: usize> {
    /// Number of leading bytes of `data` filled so far (at most `C`).
    len: u8,
    /// Storage; only `data[..len]` holds received bytes until the buffer is full.
    data: [u8; C],
}

impl<const C: usize> SmallBufWithLen<C> {
    /// Copy as many bytes as still fit from the front of `data` into the buffer
    /// and return the unconsumed tail of `data`.
    fn slurp<'a>(&mut self, data: &'a mut [u8]) -> &'a mut [u8] {
        let offset = usize::from(self.len);
        let take = (C - offset).min(data.len());
        self.data[offset..offset + take].copy_from_slice(&data[..take]);
        // `take <= C <= 255`, so this cast does not truncate.
        self.len += take as u8;
        &mut data[take..]
    }

    /// Whether all `C` bytes have been received.
    fn is_full(&self) -> bool {
        usize::from(self.len) == C
    }

    /// Create an empty buffer.
    const fn new() -> SmallBufWithLen<C> {
        SmallBufWithLen {
            len: 0,
            data: [0u8; C],
        }
    }
}

/// What part of a frame the decoder expects to receive next.
///
/// Header-field variants carry a partially filled buffer, so a field whose bytes
/// arrive in several input chunks is assembled incrementally.
#[derive(Clone, Copy, Debug)]
enum FrameDecodingState {
    /// The first two header bytes (FIN/RSV/opcode byte and MASK/7-bit length byte).
    HeaderBeginning(SmallBufWithLen<2>),
    /// The 16-bit extended payload length.
    PayloadLength16(SmallBufWithLen<2>),
    /// The 64-bit extended payload length.
    #[cfg(feature="large_frames")]
    PayloadLength64(SmallBufWithLen<8>),
    /// The 4-byte masking key.
    MaskingKey(SmallBufWithLen<4>),
    /// Payload bytes of the current frame.
    PayloadData {
        /// Position within the masking key; `None` for unmasked frames.
        phase: Option<NonMaxU8>,
        /// Number of payload bytes not yet delivered.
        remaining: PayloadLength,
    },
}

impl Default for FrameDecodingState {
    /// Start out waiting for the beginning of a frame header.
    fn default() -> Self {
        Self::HeaderBeginning(SmallBufWithLen::new())
    }
}

/// A low-level WebSocket frame decoder.
///
/// It is a push parser: you offer it bytes that come from a socket, and it emits events.
///
/// You typically need two loops to process incoming data: the outer loop reads chunks of data
/// from the socket, and the inner loop supplies each chunk to the decoder instance until no more
/// events get emitted.
///
/// Example usage:
///
/// ```
#[doc=include_str!("../examples/decode_frame.rs")]
/// ```
///
/// Any sequence of bytes results in some (sensible or not) [`WebsocketFrameEvent`]
/// sequence (exception: when the `large_frames` crate feature is disabled, see
/// [`FrameDecoderError`]).
///
/// You may want to validate the decoded frames (e.g. using the [`FrameInfo::is_reasonable`]
/// method) before using them.
#[derive(Clone, Copy, Debug, Default)]
pub struct WebsocketFrameDecoder {
    /// Which part of the frame is expected next.
    state: FrameDecodingState,
    /// Masking key of the current frame (meaningful only for masked frames).
    mask: [u8; 4],
    /// The first two header bytes of the current frame.
    basic_header: [u8; 2],
    /// Payload length of the current frame.
    payload_length: PayloadLength,
    /// Opcode of the last non-continuation data frame, reported for continuation frames.
    /// Reset to `Opcode::Continuation` when a final (FIN) data frame ends.
    original_opcode: Opcode,
}

/// Return value of a [`WebsocketFrameDecoder::add_data`] call.
#[derive(Debug,Clone)]
pub struct WebsocketFrameDecoderAddDataResult {
    /// Indicates how many bytes were consumed and should not be supplied again to
    /// the subsequent invocation of [`WebsocketFrameDecoder::add_data`].
    ///
    /// When `add_data` produces [`WebsocketFrameEvent::PayloadChunk`], it also indicates how many
    /// of the bytes in the buffer (starting from index 0) should be used as a part of the payload.
    pub consumed_bytes: usize,
    /// Emitted event, if any.
    pub event: Option<WebsocketFrameEvent>,
}

/// Information that [`WebsocketFrameDecoder`] gives in return to bytes being fed to it.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum WebsocketFrameEvent {
    /// Indicates that a frame has started.
    ///
    /// `original_opcode` is the same as `frame_info.opcode`, except for
    /// [`Opcode::Continuation`] frames, for which it refers to the data frame that
    /// started the message (normally [`Opcode::Text`] or [`Opcode::Binary`]).
    Start {
        /// Header information of the frame that has just started.
        frame_info: FrameInfo,
        /// Opcode of the message this frame belongs to.
        original_opcode: Opcode,
    },

    /// The bytes that were supplied to [`WebsocketFrameDecoder::add_data`] are payload bytes,
    /// already transformed (unmasked in place) for use as a part of the payload.
    ///
    /// Use [`WebsocketFrameDecoderAddDataResult::consumed_bytes`] to get the actual
    /// part of the buffer to be handled as content coming from the WebSocket.
    ///
    /// Mind the `original_opcode` to avoid mixing content of control frames and data frames.
    PayloadChunk {
        /// Opcode of the message this payload belongs to.
        original_opcode: Opcode,
    },

    /// Indicates that all `PayloadChunk`s for the given frame have been delivered and the frame
    /// has ended.
    ///
    /// You can watch `frame_info.fin` together with `original_opcode` to know
    /// when a WebSocket **message** (not just a frame) ends.
    ///
    /// `frame_info` is the same as in [`WebsocketFrameEvent::Start`]'s `frame_info`.
    End {
        /// Header information of the frame that has just ended.
        frame_info: FrameInfo,
        /// Opcode of the message this frame belongs to.
        original_opcode: Opcode,
    },
}

impl WebsocketFrameDecoder {
    /// Decode the opcode stored in the low nibble of the first header byte.
    fn get_opcode(&self) -> Opcode {
        use Opcode::*;
        match self.basic_header[0] & 0xF {
            0 => Continuation,
            1 => Text,
            2 => Binary,
            3 => ReservedData3,
            4 => ReservedData4,
            5 => ReservedData5,
            6 => ReservedData6,
            7 => ReservedData7,
            8 => ConnectionClose,
            9 => Ping,
            0xA => Pong,
            0xB => ReservedControlB,
            0xC => ReservedControlC,
            0xD => ReservedControlD,
            0xE => ReservedControlE,
            0xF => ReservedControlF,
            // `& 0xF` limits the value to 0..=15, all of which are listed above.
            _ => unreachable!(),
        }
    }

    /// Return the opcode of the message that a frame with `opcode` belongs to.
    ///
    /// Continuation frames report the opcode remembered from the data frame that
    /// started the message; every other opcode is returned unchanged.
    fn resolve_original_opcode(&self, opcode: Opcode) -> Opcode {
        if opcode == Opcode::Continuation {
            self.original_opcode
        } else {
            opcode
        }
    }

    /// Assemble the [`FrameInfo`] of the current frame together with its original opcode.
    ///
    /// `masked` selects whether the stored masking key is reported in `FrameInfo::mask`.
    fn get_frame_info(&self, masked: bool) -> (FrameInfo, Opcode) {
        let fi = FrameInfo {
            opcode: self.get_opcode(),
            payload_length: self.payload_length,
            mask: if masked { Some(self.mask) } else { None },
            fin: self.basic_header[0] & 0x80 == 0x80,
            reserved: (self.basic_header[0] & 0x70) >> 4,
        };
        let original_opcode = self.resolve_original_opcode(fi.opcode);
        (fi, original_opcode)
    }

    /// Add some bytes to the decoder and return events, if any.
    ///
    /// Call this function again if any of the following conditions are met:
    ///
    /// * When new incoming data is available on the socket
    /// * When previous invocation of `add_data` returned nonzero [`WebsocketFrameDecoderAddDataResult::consumed_bytes`].
    /// * When previous invocation of `add_data` returned non-`None` [`WebsocketFrameDecoderAddDataResult::event`].
    ///
    /// You may need to call it with an empty `data` buffer to get the final [`WebsocketFrameEvent::End`].
    ///
    /// The input buffer needs to be mutable because it is also used to transform (unmask)
    /// payload content chunks in place.
    ///
    /// # Errors
    ///
    /// Without the `large_frames` crate feature, returns `FrameDecoderError::ExceededFrameSize`
    /// when a frame header uses the 64-bit extended payload length. The failing call does not
    /// report how many bytes it consumed. The decoder stays on the offending header, so later
    /// calls with non-empty data return the same error again.
    pub fn add_data<'a, 'b>(
        &'a mut self,
        mut data: &'b mut [u8],
    ) -> Result<WebsocketFrameDecoderAddDataResult, FrameDecoderError> {
        let original_data_len = data.len();
        loop {
            // Return "no event yet", reporting everything consumed so far in this call.
            macro_rules! return_dummy {
                () => {
                    return Ok(WebsocketFrameDecoderAddDataResult {
                        consumed_bytes: original_data_len - data.len(),
                        event: None,
                    });
                };
            }
            // With no input left, only a finished payload can still make progress,
            // because emitting its `End` event needs no further bytes.
            if data.is_empty() && !matches!(self.state, FrameDecodingState::PayloadData{remaining: 0, ..}) {
                return_dummy!();
            }
            // Feed input into the partially filled header buffer `$v`. If it is still not full,
            // all input has been used up, so return and wait for more. Otherwise rebind `$v` to
            // the completed byte array.
            macro_rules! try_to_fill_buffer_or_return {
                ($v:ident) => {
                    data = $v.slurp(data);
                    if !$v.is_full() {
                        assert!(data.is_empty());
                        return_dummy!();
                    }
                    let $v = $v.data;
                };
            }
            // Set once the payload length of the current frame is known.
            let mut length_is_ready = false;
            match self.state {
                FrameDecodingState::HeaderBeginning(ref mut v) => {
                    try_to_fill_buffer_or_return!(v);
                    self.basic_header = v;
                    let opcode = self.get_opcode();
                    // Remember the opcode that starts a (possibly fragmented) data message so
                    // that later continuation frames can report it.
                    if opcode.is_data() && opcode != Opcode::Continuation {
                        self.original_opcode = opcode;
                    }
                    match self.basic_header[1] & 0x7F {
                        0x7E => {
                            self.state = FrameDecodingState::PayloadLength16(SmallBufWithLen::new())
                        }
                        #[cfg(feature="large_frames")]
                        0x7F => {
                            self.state = FrameDecodingState::PayloadLength64(SmallBufWithLen::new())
                        }
                        #[cfg(not(feature="large_frames"))] 0x7F => {
                            return Err(FrameDecoderError::ExceededFrameSize);
                        }
                        x => {
                            self.payload_length = x.into();
                            length_is_ready = true;
                        }
                    };
                }
                FrameDecodingState::PayloadLength16(ref mut v) => {
                    try_to_fill_buffer_or_return!(v);
                    self.payload_length = u16::from_be_bytes(v).into();
                    length_is_ready = true;
                }
                #[cfg(feature="large_frames")]
                FrameDecodingState::PayloadLength64(ref mut v) => {
                    try_to_fill_buffer_or_return!(v);
                    self.payload_length = u64::from_be_bytes(v);
                    length_is_ready = true;
                }
                FrameDecodingState::MaskingKey(ref mut v) => {
                    try_to_fill_buffer_or_return!(v);
                    self.mask = v;
                    self.state = FrameDecodingState::PayloadData {
                        phase: Some(NonMaxU8::default()),
                        remaining: self.payload_length,
                    };
                    let (frame_info, original_opcode) = self.get_frame_info(true);
                    return Ok(WebsocketFrameDecoderAddDataResult {
                        consumed_bytes: original_data_len - data.len(),
                        event: Some(WebsocketFrameEvent::Start{frame_info, original_opcode}),
                    });
                }
                FrameDecodingState::PayloadData {
                    phase,
                    remaining: 0,
                } => {
                    self.state = FrameDecodingState::HeaderBeginning(SmallBufWithLen::new());
                    let (fi, original_opcode) = self.get_frame_info(phase.is_some());
                    // A final data frame ends the message; forget its opcode.
                    if fi.opcode.is_data() && fi.fin {
                        self.original_opcode = Opcode::Continuation;
                    }
                    return Ok(WebsocketFrameDecoderAddDataResult {
                        consumed_bytes: original_data_len - data.len(),
                        event: Some(WebsocketFrameEvent::End{frame_info: fi, original_opcode}),
                    });
                }
                FrameDecodingState::PayloadData {
                    ref mut phase,
                    ref mut remaining,
                } => {
                    // This state is only entered right after a `Start` or `PayloadChunk` event
                    // has been returned, so payload bytes always begin at index 0 of `data`.
                    // Callers rely on this when they slice `data[..consumed_bytes]`.
                    let start_offset = original_data_len - data.len();
                    let mut max_len = data.len();
                    if let Ok(remaining_usize) = usize::try_from(*remaining) {
                        max_len = max_len.min(remaining_usize);
                    }
                    let payload_chunk = &mut data[..max_len];

                    if let Some(phase) = phase {
                        let ph = phase.get();
                        masking::apply_mask(self.mask, payload_chunk, ph);
                        // The key repeats every 4 bytes. Advance the phase in `usize`: adding the
                        // chunk length to a `u8` overflows (and panics in debug builds) whenever
                        // `ph + (len % 256) >= 256`.
                        let next_phase = (usize::from(ph) + payload_chunk.len()) % 4;
                        *phase = NonMaxU8::new(next_phase as u8)
                            .expect("a value below 4 is never u8::MAX");
                    }

                    *remaining -= max_len as PayloadLength;
                    let original_opcode = self.resolve_original_opcode(self.get_opcode());
                    assert_eq!(start_offset, 0);
                    return Ok(WebsocketFrameDecoderAddDataResult {
                        consumed_bytes: max_len,
                        event: Some(WebsocketFrameEvent::PayloadChunk{original_opcode}),
                    });
                }
            }
            if length_is_ready {
                // MASK bit set: a 4-byte masking key precedes the payload.
                if self.basic_header[1] & 0x80 == 0x80 {
                    self.state = FrameDecodingState::MaskingKey(SmallBufWithLen::new());
                } else {
                    self.state = FrameDecodingState::PayloadData {
                        phase: None,
                        remaining: self.payload_length,
                    };
                    let (frame_info, original_opcode) = self.get_frame_info(false);
                    return Ok(WebsocketFrameDecoderAddDataResult {
                        consumed_bytes: original_data_len - data.len(),
                        event: Some(WebsocketFrameEvent::Start{frame_info, original_opcode}),
                    });
                }
            }
        }
    }

    /// There is no incomplete WebSocket frame at this moment and EOF is valid here.
    ///
    /// This method is not related to [`Opcode::ConnectionClose`] in any way.
    #[inline]
    pub fn eof_valid(&self) -> bool {
        matches!(self.state, FrameDecodingState::HeaderBeginning(..))
    }

    /// Create a new instance, waiting for the beginning of a frame header.
    #[inline]
    pub const fn new() -> Self {
        WebsocketFrameDecoder {
            state: FrameDecodingState::HeaderBeginning(SmallBufWithLen::new()),
            mask: [0; 4],
            basic_header: [0; 2],
            payload_length: 0,
            original_opcode: Opcode::Continuation,
        }
    }
}