1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use super::{Decoder, Encoder};
use bytes::{Buf, BytesMut};

/// Incremental discarder for the remainder of an oversized frame.
///
/// A handler is produced by [`DecoderWithSkipAhead::prepare_skip_ahead`] once
/// [`LimitCodec`] has rejected a frame for exceeding its size limit. It is then fed the
/// buffered input on subsequent decode calls until it reports that skipping is finished.
pub trait SkipAheadHandler: Sized + std::fmt::Debug {
    /// Examines the buffered input `src` and reports how much of it to discard.
    ///
    /// Returns `(amount, next)`:
    /// - `amount` is the number of leading bytes of `src` to drop; it must not exceed
    ///   `src.len()`.
    /// - `next` is the handler state to use on the next call, or `None` once the oversized
    ///   frame has been skipped entirely and normal decoding may resume.
    ///
    /// A handler that returns `Some(next)` must consume at least one byte; returning
    /// `(0, Some(_))` is a contract violation because no progress would be made.
    ///
    /// # Errors
    ///
    /// Returning an error means the stream cannot be resynchronized; [`LimitCodec`] then
    /// treats its decoder as permanently defunct.
    fn continue_skipping(self, src: &[u8]) -> anyhow::Result<(usize, Option<Self>)>;
}

/// The trivial handler: it discards nothing and ends skip-ahead mode on its very first
/// call, leaving any buffered bytes in place.
impl SkipAheadHandler for () {
    fn continue_skipping(self, _src: &[u8]) -> anyhow::Result<(usize, Option<Self>)> {
        let bytes_to_skip = 0;
        let next_state: Option<Self> = None;
        Ok((bytes_to_skip, next_state))
    }
}

/// A [`Decoder`] that can recover from an oversized frame by skipping over it.
///
/// [`LimitCodec`] requires this of its inner codec so that, after rejecting a frame that
/// exceeds the size limit, it has a way to discard the rest of that frame.
pub trait DecoderWithSkipAhead: Decoder {
    /// Handler type that discards the remainder of an oversized frame.
    type Handler: SkipAheadHandler;

    /// Starts skipping the current, oversized and still incomplete frame.
    ///
    /// [`LimitCodec`] calls this when the inner decoder needed more data but more than the
    /// configured maximum was already buffered in `src`. The implementation gets mutable
    /// access to the buffer, so it may already consume bytes from it. The returned handler
    /// is then fed subsequent input until it signals that skipping is complete.
    fn prepare_skip_ahead(&mut self, src: &mut BytesMut) -> Self::Handler;
}

/// Codec wrapper that enforces a maximum frame size on an inner codec.
///
/// When `C` also implements `Encoder`, encoding rejects frames whose encoded form is longer
/// than `max_frame_size` bytes. Decoding rejects an incomplete frame once more than
/// `max_frame_size` bytes are buffered, then uses the inner codec's skip-ahead handler to
/// discard the rest of that frame.
#[derive(Debug)]
pub struct LimitCodec<C: DecoderWithSkipAhead> {
    // The wrapped codec that performs the actual encoding and decoding.
    inner: C,
    // Size limit in bytes, applied to encoded frames and to buffered incomplete frames.
    max_frame_size: usize,
    // Pending skip-ahead handler while an oversized frame is being discarded;
    // `None` during normal decoding.
    skip_ahead_state: Option<<C as DecoderWithSkipAhead>::Handler>,
    // Set once skip-ahead recovery has failed; every later decode clears the buffer and
    // returns `LimitError::Defunct`.
    decoder_defunct: bool,
}

impl<C> LimitCodec<C>
where
    C: DecoderWithSkipAhead,
{
    #[allow(missing_docs)]
    pub fn new(inner: C, max_frame_size: usize) -> Self {
        Self {
            inner,
            max_frame_size,
            skip_ahead_state: None,
            decoder_defunct: false,
        }
    }
}

/// Error returned by [`LimitCodec`], generic over the inner codec's error type `E`.
#[derive(Debug, thiserror::Error)]
pub enum LimitError<E: std::error::Error + 'static> {
    /// A frame exceeded the configured size limit. The payload is a byte count: the
    /// encoded frame length when encoding, or the decode buffer's length when decoding.
    #[error("frame size limit exceeded (detected at {0} bytes)")]
    LimitExceeded(usize),
    /// Skip-ahead recovery failed; the decoder gives up and clears all further input.
    #[error("codec couldn't recover from invalid or too big frame")]
    Defunct,
    /// An error reported by the inner codec, passed through transparently.
    #[error(transparent)]
    Inner(#[from] E),
}

impl<C> Encoder for LimitCodec<C>
where
    C: Encoder + DecoderWithSkipAhead,
{
    type Error = LimitError<<C as Encoder>::Error>;
    type Item = <C as Encoder>::Item;

    /// Encodes `src` and appends it to `dst`, unless the encoded frame is longer than
    /// `max_frame_size` bytes.
    ///
    /// The frame is first encoded into a detached, initially empty tail split off `dst`.
    /// It is joined back onto `dst` only when it fits the limit, so on any error the
    /// existing contents of `dst` are left unchanged.
    ///
    /// # Errors
    ///
    /// [`LimitError::Inner`] if the inner encoder fails, or [`LimitError::LimitExceeded`]
    /// (carrying the encoded length) if the frame is too large.
    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let mut frame = dst.split_off(dst.len());
        if let Err(inner_err) = self.inner.encode(src, &mut frame) {
            return Err(LimitError::Inner(inner_err));
        }

        let encoded_len = frame.len();
        if encoded_len <= self.max_frame_size {
            dst.unsplit(frame);
            Ok(())
        } else {
            Err(LimitError::LimitExceeded(encoded_len))
        }
    }
}

impl<C> Decoder for LimitCodec<C>
where
    C: DecoderWithSkipAhead,
{
    type Error = LimitError<<C as Decoder>::Error>;
    type Item = <C as Decoder>::Item;

    /// Decodes the next frame from `src`, enforcing `max_frame_size` on incomplete frames.
    ///
    /// The call proceeds in three phases:
    ///
    /// 1. If an earlier call rejected an oversized frame, the pending skip-ahead handler
    ///    is driven over `src`, and the bytes it reports are discarded. If `src` runs dry
    ///    while the handler still wants more, `Ok(None)` is returned and skipping resumes
    ///    on the next call.
    /// 2. If the codec is defunct, `src` is cleared and [`LimitError::Defunct`] is
    ///    returned. This happens on this call and every later one. A codec becomes
    ///    defunct when a handler failed or broke its contract.
    /// 3. Otherwise the inner decoder runs. If it needs more data but more than
    ///    `max_frame_size` bytes are already buffered, the codec enters skip-ahead mode
    ///    via [`DecoderWithSkipAhead::prepare_skip_ahead`]. It then returns
    ///    [`LimitError::LimitExceeded`] with the buffered length at the moment of
    ///    detection.
    ///
    /// # Errors
    ///
    /// - [`LimitError::LimitExceeded`] when an incomplete frame outgrows the limit.
    /// - [`LimitError::Defunct`] once skip-ahead recovery has failed.
    /// - [`LimitError::Inner`] for errors reported by the inner decoder.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // `decode` may be invoked with an empty buffer (presumably e.g. right after a frame
        // was returned). A handler cannot consume anything from an empty slice, and it is
        // not allowed to stay in skip mode without consuming. So keep the state and wait
        // for more input instead of calling the handler.
        if self.skip_ahead_state.is_some() && src.is_empty() {
            return Ok(None);
        }

        while let Some(handler) = self.skip_ahead_state.take() {
            match handler.continue_skipping(&src[..]) {
                Ok((amount, next)) => {
                    // Handler contract: never consume more than is buffered, and never ask
                    // to keep skipping without consuming anything (that would never end).
                    let overran = amount > src.len();
                    let stalled = amount == 0 && next.is_some();
                    debug_assert!(
                        !overran,
                        "skip-ahead handler consumed more bytes than were buffered"
                    );
                    debug_assert!(
                        !stalled,
                        "skip-ahead handler made no progress but asked to keep skipping"
                    );
                    if overran || stalled {
                        // In release builds, treat a contract violation like a handler
                        // error rather than panicking in `advance` or looping forever.
                        self.decoder_defunct = true;
                        break;
                    }
                    self.skip_ahead_state = next;
                    src.advance(amount);
                    if src.is_empty() {
                        return Ok(None);
                    }
                },
                Err(_err) => {
                    // Recovery failed: the codec cannot resynchronize with the stream.
                    self.decoder_defunct = true;
                },
            }
        }

        if self.decoder_defunct {
            src.clear();
            return Err(LimitError::Defunct);
        }

        match self.inner.decode(src) {
            Ok(None) if src.len() > self.max_frame_size => {
                // Capture the size first: `prepare_skip_ahead` may consume from `src`,
                // which would otherwise make the reported size wrong.
                let detected_at = src.len();
                self.skip_ahead_state = Some(self.inner.prepare_skip_ahead(src));
                Err(LimitError::LimitExceeded(detected_at))
            },
            Ok(frame) => Ok(frame),
            Err(err) => Err(LimitError::Inner(err)),
        }
    }
}