1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use super::{Decoder, Encoder};
use bytes::{Buf, BytesMut};
/// State of an in-progress "skip ahead" over buffered bytes that are to be discarded.
///
/// [`LimitCodec`] keeps one of these after it reported an oversized frame and drives it at
/// the start of every subsequent `decode` call until no follow-up state is returned.
#[allow(missing_docs)]
pub trait SkipAheadHandler: std::fmt::Debug + Sized {
    /// Looks at the currently buffered bytes and decides how many of them to skip.
    ///
    /// On success the result is `(amount, next)`:
    /// * `amount` is the number of bytes the caller advances its buffer by; `LimitCodec`
    ///   asserts (in debug builds) that it does not exceed `src.len()`.
    /// * `next` is the handler to use for the following round, or `None` once skipping is
    ///   finished. `LimitCodec` asserts (in debug builds) that `amount` is non-zero
    ///   whenever `next` is `Some`.
    ///
    /// # Errors
    ///
    /// Any error makes `LimitCodec` mark its decoder as defunct.
    fn continue_skipping(self, src: &[u8]) -> anyhow::Result<(usize, Option<Self>)>;
}
/// Trivial handler: skips nothing and reports that skipping is finished right away.
impl SkipAheadHandler for () {
    fn continue_skipping(self, _src: &[u8]) -> anyhow::Result<(usize, Option<Self>)> {
        // Zero bytes consumed, no follow-up state.
        let nothing_to_skip = (0, None);
        Ok(nothing_to_skip)
    }
}
/// A [`Decoder`] that can produce a [`SkipAheadHandler`] for discarding an oversized frame.
#[allow(missing_docs)]
pub trait DecoderWithSkipAhead: Decoder {
    /// The handler type used to skip over the remainder of a rejected frame.
    type Handler: SkipAheadHandler;

    /// Creates the handler that will skip the rejected frame.
    ///
    /// [`LimitCodec`] calls this when the inner `decode` returned `Ok(None)` although the
    /// buffer already holds more than the configured maximum frame size. `src` is passed
    /// mutably, so an implementation may modify the buffer here.
    fn prepare_skip_ahead(&mut self, src: &mut BytesMut) -> Self::Handler;
}
/// Codec wrapper that enforces a maximum frame size on top of an inner codec.
#[derive(Debug)]
pub struct LimitCodec<C>
where
    C: DecoderWithSkipAhead,
{
    /// The wrapped codec doing the actual encoding/decoding.
    inner: C,
    /// Upper bound (in bytes) compared against encoded frame lengths and against the
    /// amount of buffered, not-yet-decodable input.
    max_frame_size: usize,
    /// Pending skip-ahead state; `Some` while an oversized frame is still being discarded.
    skip_ahead_state: Option<<C as DecoderWithSkipAhead>::Handler>,
    /// Set once a skip-ahead handler failed; from then on decoding only yields
    /// `LimitError::Defunct`.
    decoder_defunct: bool,
}
impl<C> LimitCodec<C>
where
C: DecoderWithSkipAhead,
{
#[allow(missing_docs)]
pub fn new(inner: C, max_frame_size: usize) -> Self {
Self {
inner,
max_frame_size,
skip_ahead_state: None,
decoder_defunct: false,
}
}
}
/// Errors produced by [`LimitCodec`], wrapping the inner codec's error type `E`.
#[derive(Debug, thiserror::Error)]
pub enum LimitError<E>
where
    E: std::error::Error + 'static,
{
    /// A frame was larger than the configured maximum; the payload is the byte count at
    /// which the violation was noticed.
    #[error("frame size limit exceeded (detected at {0} bytes)")]
    LimitExceeded(usize),
    /// Skipping past a rejected frame failed, so the decoder can no longer be used.
    #[error("codec couldn't recover from invalid or too big frame")]
    Defunct,
    /// An error reported by the wrapped codec, passed through unchanged.
    #[error(transparent)]
    Inner(#[from] E),
}
impl<C> Encoder for LimitCodec<C>
where
    C: Encoder + DecoderWithSkipAhead,
{
    type Error = LimitError<<C as Encoder>::Error>;
    type Item = <C as Encoder>::Item;

    /// Encodes `src` with the inner codec and appends the result to `dst`.
    ///
    /// The frame is first written into a separate tail buffer; it is only joined back onto
    /// `dst` when it fits the size limit, so the bytes already in `dst` are never mixed
    /// with a rejected or partially written frame.
    ///
    /// # Errors
    ///
    /// * [`LimitError::Inner`] if the inner encoder fails.
    /// * [`LimitError::LimitExceeded`] (carrying the encoded length) if the encoded frame
    ///   is longer than `max_frame_size`.
    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // Empty tail starting right after the current contents of `dst`.
        let mut scratch = dst.split_off(dst.len());
        self.inner
            .encode(src, &mut scratch)
            .map_err(LimitError::Inner)?;

        let encoded_len = scratch.len();
        if encoded_len <= self.max_frame_size {
            dst.unsplit(scratch);
            Ok(())
        } else {
            Err(LimitError::LimitExceeded(encoded_len))
        }
    }
}
impl<C> Decoder for LimitCodec<C>
where
    C: DecoderWithSkipAhead,
{
    type Error = LimitError<<C as Decoder>::Error>;
    type Item = <C as Decoder>::Item;

    /// Decodes the next frame from `src`, enforcing the maximum frame size.
    ///
    /// The call proceeds in three steps:
    /// 1. While a skip-ahead handler is pending (left over from a previously rejected
    ///    frame), it is asked how many bytes to discard and `src` is advanced accordingly.
    ///    If the buffer runs empty, or the handler makes no progress, `Ok(None)` is
    ///    returned so the caller can supply more data.
    /// 2. If a handler failed (now or earlier), the buffer is cleared and
    ///    [`LimitError::Defunct`] is returned.
    /// 3. Otherwise the inner decoder runs. If it needs more data (`Ok(None)`) although
    ///    `src` already holds more than `max_frame_size` bytes, a skip-ahead handler is
    ///    prepared and [`LimitError::LimitExceeded`] is returned.
    ///
    /// # Errors
    ///
    /// * [`LimitError::LimitExceeded`] with the buffer length at the moment the violation
    ///   was detected.
    /// * [`LimitError::Defunct`] once a skip-ahead handler has returned an error.
    /// * [`LimitError::Inner`] for errors of the inner decoder.
    ///
    /// # Panics
    ///
    /// In debug builds, if a handler reports more bytes than are buffered or reports zero
    /// bytes while also returning a follow-up handler.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        while let Some(sas) = self.skip_ahead_state.take() {
            match sas.continue_skipping(src) {
                Ok((amount, next)) => {
                    debug_assert!(amount <= src.len());
                    // A handler that wants to be called again must consume something.
                    debug_assert!(amount != 0 || next.is_some() == false);
                    let stalled = amount == 0 && next.is_some();

                    self.skip_ahead_state = next;
                    src.advance(amount);

                    // Without the `stalled` check a handler violating the contract above
                    // would make this loop spin forever in release builds (where the
                    // debug assertion is compiled out). Wait for more input instead.
                    if stalled || src.is_empty() {
                        return Ok(None);
                    }
                }
                Err(_err) => {
                    // The handler's error detail is dropped; only the defunct flag is kept.
                    // The handler was taken out of `skip_ahead_state`, so the loop ends.
                    self.decoder_defunct = true;
                }
            }
        }

        if self.decoder_defunct {
            // Nothing buffered can be interpreted anymore; discard it.
            src.clear();
            return Err(LimitError::Defunct);
        }

        match self.inner.decode(src) {
            Ok(None) if src.len() > self.max_frame_size => {
                // Record the length *before* `prepare_skip_ahead`, which receives the
                // buffer mutably and may already consume bytes from it.
                let detected_at = src.len();
                self.skip_ahead_state = Some(self.inner.prepare_skip_ahead(src));
                Err(LimitError::LimitExceeded(detected_at))
            }
            Ok(x) => Ok(x),
            Err(x) => Err(LimitError::Inner(x)),
        }
    }
}