1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use super::{Decoder, Encoder};
use bytes::{Buf, BytesMut};

/// Incremental state machine that discards the remainder of an oversized frame.
///
/// [`LimitCodec`] obtains a handler from [`DecoderWithSkipAhead::prepare_skip_ahead`]
/// when a frame exceeds its size limit. On every following `decode` call it feeds the
/// buffered bytes to [`continue_skipping`](Self::continue_skipping) until the handler
/// reports that skipping is finished.
pub trait SkipAheadHandler: Sized + std::fmt::Debug {
    /// Consumes as much of `src` as still belongs to the frame being skipped.
    ///
    /// Returns `(amount, next)`:
    ///
    /// * the first `amount` bytes of `src` are discarded by [`LimitCodec`];
    /// * `next` is the handler to resume with on the next chunk of input, or `None`
    ///   once the end of the skipped frame has been reached.
    ///
    /// [`LimitCodec`] checks the following in debug builds, so implementations must
    /// uphold them:
    ///
    /// * `amount <= src.len()`;
    /// * if `amount == 0` then `next` is `None`. A handler that wants to keep skipping
    ///   must make progress.
    ///
    /// # Errors
    ///
    /// An error means the stream cannot be resynchronised. [`LimitCodec`] then marks
    /// its decoder as permanently defunct.
    fn continue_skipping(self, src: &[u8]) -> anyhow::Result<(usize, Option<Self>)>;
}

/// The unit type is a no-op handler: it discards nothing and reports that
/// skipping has already finished.
impl SkipAheadHandler for () {
    fn continue_skipping(self, _src: &[u8]) -> anyhow::Result<(usize, Option<Self>)> {
        let discarded = 0;
        let resume_with: Option<Self> = None;
        Ok((discarded, resume_with))
    }
}

/// A [`Decoder`] that can abandon an oversized frame and resynchronise afterwards.
///
/// This is the extension [`LimitCodec`] needs from the codec it wraps.
pub trait DecoderWithSkipAhead: Decoder {
    /// State machine used to discard the rest of an abandoned frame.
    type Handler: SkipAheadHandler;

    /// Starts skipping the frame currently buffered in `src`.
    ///
    /// [`LimitCodec`] calls this when [`Decoder::decode`] returned `Ok(None)` while
    /// more than the configured maximum number of bytes was buffered. The returned
    /// handler is stored, and on the next `decode` call it is given that call's `src`
    /// before the inner decoder runs again.
    ///
    /// The method receives both `&mut self` and `&mut src`. It can therefore reset any
    /// partial-frame state kept by the decoder and also consume bytes from `src`
    /// directly.
    fn prepare_skip_ahead(&mut self, src: &mut BytesMut) -> Self::Handler;
}

/// Wraps a codec and enforces a maximum frame size.
///
/// When the inner codec is also an [`Encoder`], encoded frames longer than
/// `max_frame_size` bytes are rejected.
///
/// On the decode side, once more than `max_frame_size` bytes are buffered without a
/// complete frame, the inner codec's skip-ahead handler discards the rest of that
/// frame so decoding can resume with the following one.
#[derive(Debug)]
pub struct LimitCodec<C: DecoderWithSkipAhead> {
    // The bound on the struct itself is required: the `skip_ahead_state` field
    // names `<C as DecoderWithSkipAhead>::Handler`.
    /// The wrapped codec that does the actual framing.
    inner: C,
    /// Largest allowed frame in bytes; frames of exactly this size are accepted.
    max_frame_size: usize,
    /// Handler discarding the remainder of an oversized frame, while one is in progress.
    skip_ahead_state: Option<<C as DecoderWithSkipAhead>::Handler>,
    /// Set once skipping cannot recover (e.g. a skip-ahead handler returned an
    /// error). It is never cleared, so decoding fails from then on.
    decoder_defunct: bool,
}

impl<C> LimitCodec<C>
where
    C: DecoderWithSkipAhead,
{
    /// Wraps `inner`, limiting frames to at most `max_frame_size` bytes.
    ///
    /// The returned codec has no skip in progress and is not defunct.
    pub fn new(inner: C, max_frame_size: usize) -> Self {
        Self {
            inner,
            max_frame_size,
            skip_ahead_state: None,
            decoder_defunct: false,
        }
    }
}

/// Errors produced by [`LimitCodec`].
#[derive(Debug, thiserror::Error)]
pub enum LimitError<E: std::error::Error + 'static> {
    /// A frame exceeded the configured maximum size.
    ///
    /// The payload is the byte count observed when the violation was detected.
    #[error("frame size limit exceeded (detected at {0} bytes)")]
    LimitExceeded(usize),
    /// The decoder could not skip past an invalid or oversized frame and is no
    /// longer usable.
    #[error("codec couldn't recover from invalid or too big frame")]
    Defunct,
    /// An error reported by the wrapped codec.
    #[error(transparent)]
    Inner(#[from] E),
}

impl<C> Encoder for LimitCodec<C>
where
    C: Encoder + DecoderWithSkipAhead,
{
    type Item = <C as Encoder>::Item;
    type Error = LimitError<<C as Encoder>::Error>;

    /// Encodes `src` with the inner codec and appends the result to `dst`, but only if
    /// the encoded frame is at most `max_frame_size` bytes long.
    ///
    /// # Errors
    ///
    /// * [`LimitError::Inner`] if the inner encoder fails.
    /// * [`LimitError::LimitExceeded`], carrying the encoded length, if the frame is
    ///   too big.
    ///
    /// In both cases the existing contents of `dst` are left untouched.
    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // Encode into a detached, initially empty tail of `dst`. A failed or
        // oversized frame then simply gets dropped instead of leaking into the
        // caller's buffer.
        let mut frame = dst.split_off(dst.len());
        self.inner.encode(src, &mut frame)?;

        let frame_len = frame.len();
        if frame_len <= self.max_frame_size {
            dst.unsplit(frame);
            Ok(())
        } else {
            Err(LimitError::LimitExceeded(frame_len))
        }
    }
}

impl<C> Decoder for LimitCodec<C>
where
    C: DecoderWithSkipAhead,
{
    type Error = LimitError<<C as Decoder>::Error>;
    type Item = <C as Decoder>::Item;

    /// Decodes the next frame from `src`, enforcing `max_frame_size`.
    ///
    /// Processing happens in three stages:
    ///
    /// 1. If an earlier call hit the limit, the stored skip-ahead handler is fed `src`
    ///    first and the bytes it reports are discarded. `Ok(None)` is returned
    ///    whenever this drains `src` completely.
    /// 2. If the codec is defunct (a handler returned an error or stopped making
    ///    progress), `src` is cleared and [`LimitError::Defunct`] is returned on this
    ///    and every later call.
    /// 3. Otherwise the inner decoder runs. The frame is rejected with
    ///    [`LimitError::LimitExceeded`] in either of two cases:
    ///    * the inner decoder consumed more than `max_frame_size` bytes to produce it;
    ///    * the inner decoder needs more data while more than `max_frame_size` bytes
    ///      are already buffered. In this case the inner codec's skip-ahead handler is
    ///      installed to discard the rest of the frame on later calls.
    ///
    /// # Errors
    ///
    /// * [`LimitError::LimitExceeded`]: the current frame is over the limit.
    /// * [`LimitError::Defunct`]: the codec cannot resynchronise and is unusable.
    /// * [`LimitError::Inner`]: the inner decoder failed.
    ///
    /// # Panics
    ///
    /// Panics if a skip-ahead handler reports consuming more bytes than `src` holds.
    ///
    /// Debug builds also panic when a handler consumes nothing yet asks to keep
    /// skipping. Release builds treat that case as the codec becoming defunct when
    /// input remains.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        while let Some(handler) = self.skip_ahead_state.take() {
            match handler.continue_skipping(src) {
                Ok((amount, next)) => {
                    debug_assert!(amount <= src.len());
                    debug_assert!(amount != 0 || next.is_none());
                    // A handler that consumed nothing but still wants to skip can
                    // never make progress. Without this check, release builds would
                    // loop here forever on a non-empty buffer.
                    let stalled = amount == 0 && next.is_some();
                    src.advance(amount);
                    self.skip_ahead_state = next;
                    if src.is_empty() {
                        return Ok(None);
                    }
                    if stalled {
                        self.skip_ahead_state = None;
                        self.decoder_defunct = true;
                    }
                },
                Err(_) => {
                    // `LimitError::Defunct` carries no payload, so the handler's
                    // error cannot be surfaced. The codec gives up permanently.
                    self.decoder_defunct = true;
                },
            }
        }

        if self.decoder_defunct {
            src.clear();
            return Err(LimitError::Defunct);
        }

        let buffered = src.len();
        match self.inner.decode(src) {
            Ok(Some(item)) => {
                // A whole frame can arrive in a single read. Measuring what the inner
                // decoder consumed makes the limit independent of how the input was
                // chunked, and consistent with `encode`.
                let consumed = buffered.saturating_sub(src.len());
                if consumed > self.max_frame_size {
                    Err(LimitError::LimitExceeded(consumed))
                } else {
                    Ok(Some(item))
                }
            },
            Ok(None) if src.len() > self.max_frame_size => {
                // Capture the size before `prepare_skip_ahead`, which may consume
                // bytes from `src` and would otherwise distort the reported value.
                let detected_at = src.len();
                self.skip_ahead_state = Some(self.inner.prepare_skip_ahead(src));
                Err(LimitError::LimitExceeded(detected_at))
            },
            Ok(None) => Ok(None),
            Err(err) => Err(LimitError::Inner(err)),
        }
    }
}