async_codec_lite/codec/limit.rs

1use super::{Decoder, Encoder};
2use bytes::{Buf, BytesMut};
3
4#[allow(missing_docs)]
5pub trait SkipAheadHandler: Sized + std::fmt::Debug {
6    fn continue_skipping(self, src: &[u8]) -> anyhow::Result<(usize, Option<Self>)>;
7}
8
9impl SkipAheadHandler for () {
10    fn continue_skipping(self, _: &[u8]) -> anyhow::Result<(usize, Option<Self>)> {
11        Ok((0, None))
12    }
13}
14
15#[allow(missing_docs)]
16pub trait DecoderWithSkipAhead: Decoder {
17    type Handler: SkipAheadHandler;
18
19    fn prepare_skip_ahead(&mut self, src: &mut BytesMut) -> Self::Handler;
20}
21
22#[derive(Debug)]
23pub struct LimitCodec<C: DecoderWithSkipAhead> {
24    inner: C,
25    max_frame_size: usize,
26    skip_ahead_state: Option<<C as DecoderWithSkipAhead>::Handler>,
27    decoder_defunct: bool,
28}
29
30impl<C> LimitCodec<C>
31where
32    C: DecoderWithSkipAhead,
33{
34    #[allow(missing_docs)]
35    pub fn new(inner: C, max_frame_size: usize) -> Self {
36        Self {
37            inner,
38            max_frame_size,
39            skip_ahead_state: None,
40            decoder_defunct: false,
41        }
42    }
43}
44
45#[derive(Debug, thiserror::Error)]
46pub enum LimitError<E: std::error::Error + 'static> {
47    #[error("frame size limit exceeded (detected at {0} bytes)")]
48    LimitExceeded(usize),
49    #[error("codec couldn't recover from invalid or too big frame")]
50    Defunct,
51    #[error(transparent)]
52    Inner(#[from] E),
53}
54
55impl<C> Encoder for LimitCodec<C>
56where
57    C: Encoder + DecoderWithSkipAhead,
58{
59    type Error = LimitError<<C as Encoder>::Error>;
60    type Item = <C as Encoder>::Item;
61
62    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
63        let mut tmp_dst = dst.split_off(dst.len());
64        self.inner.encode(src, &mut tmp_dst)?;
65
66        if tmp_dst.len() > self.max_frame_size {
67            return Err(LimitError::LimitExceeded(tmp_dst.len()));
68        }
69
70        dst.unsplit(tmp_dst);
71        Ok(())
72    }
73}
74
75impl<C> Decoder for LimitCodec<C>
76where
77    C: DecoderWithSkipAhead,
78{
79    type Error = LimitError<<C as Decoder>::Error>;
80    type Item = <C as Decoder>::Item;
81
82    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
83        while let Some(sas) = self.skip_ahead_state.take() {
84            match sas.continue_skipping(src) {
85                Ok((amount, next)) => {
86                    self.skip_ahead_state = next;
87                    debug_assert!(amount <= src.len());
88                    src.advance(amount);
89                    debug_assert!(amount != 0 || self.skip_ahead_state.is_none());
90                    if src.is_empty() {
91                        return Ok(None);
92                    }
93                },
94                Err(_err) => {
95                    self.decoder_defunct = true;
96                },
97            }
98        }
99
100        if self.decoder_defunct {
101            src.clear();
102            return Err(LimitError::Defunct);
103        }
104        match self.inner.decode(src) {
105            Ok(None) if src.len() > self.max_frame_size => {
106                self.skip_ahead_state = Some(self.inner.prepare_skip_ahead(src));
107                Err(LimitError::LimitExceeded(src.len()))
108            },
109            Ok(x) => Ok(x),
110            Err(x) => Err(LimitError::Inner(x)),
111        }
112    }
113}