async_codec_lite/codec/
limit.rs1use super::{Decoder, Encoder};
2use bytes::{Buf, BytesMut};
3
4#[allow(missing_docs)]
5pub trait SkipAheadHandler: Sized + std::fmt::Debug {
6 fn continue_skipping(self, src: &[u8]) -> anyhow::Result<(usize, Option<Self>)>;
7}
8
9impl SkipAheadHandler for () {
10 fn continue_skipping(self, _: &[u8]) -> anyhow::Result<(usize, Option<Self>)> {
11 Ok((0, None))
12 }
13}
14
15#[allow(missing_docs)]
16pub trait DecoderWithSkipAhead: Decoder {
17 type Handler: SkipAheadHandler;
18
19 fn prepare_skip_ahead(&mut self, src: &mut BytesMut) -> Self::Handler;
20}
21
22#[derive(Debug)]
23pub struct LimitCodec<C: DecoderWithSkipAhead> {
24 inner: C,
25 max_frame_size: usize,
26 skip_ahead_state: Option<<C as DecoderWithSkipAhead>::Handler>,
27 decoder_defunct: bool,
28}
29
30impl<C> LimitCodec<C>
31where
32 C: DecoderWithSkipAhead,
33{
34 #[allow(missing_docs)]
35 pub fn new(inner: C, max_frame_size: usize) -> Self {
36 Self {
37 inner,
38 max_frame_size,
39 skip_ahead_state: None,
40 decoder_defunct: false,
41 }
42 }
43}
44
45#[derive(Debug, thiserror::Error)]
46pub enum LimitError<E: std::error::Error + 'static> {
47 #[error("frame size limit exceeded (detected at {0} bytes)")]
48 LimitExceeded(usize),
49 #[error("codec couldn't recover from invalid or too big frame")]
50 Defunct,
51 #[error(transparent)]
52 Inner(#[from] E),
53}
54
55impl<C> Encoder for LimitCodec<C>
56where
57 C: Encoder + DecoderWithSkipAhead,
58{
59 type Error = LimitError<<C as Encoder>::Error>;
60 type Item = <C as Encoder>::Item;
61
62 fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
63 let mut tmp_dst = dst.split_off(dst.len());
64 self.inner.encode(src, &mut tmp_dst)?;
65
66 if tmp_dst.len() > self.max_frame_size {
67 return Err(LimitError::LimitExceeded(tmp_dst.len()));
68 }
69
70 dst.unsplit(tmp_dst);
71 Ok(())
72 }
73}
74
75impl<C> Decoder for LimitCodec<C>
76where
77 C: DecoderWithSkipAhead,
78{
79 type Error = LimitError<<C as Decoder>::Error>;
80 type Item = <C as Decoder>::Item;
81
82 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
83 while let Some(sas) = self.skip_ahead_state.take() {
84 match sas.continue_skipping(src) {
85 Ok((amount, next)) => {
86 self.skip_ahead_state = next;
87 debug_assert!(amount <= src.len());
88 src.advance(amount);
89 debug_assert!(amount != 0 || self.skip_ahead_state.is_none());
90 if src.is_empty() {
91 return Ok(None);
92 }
93 },
94 Err(_err) => {
95 self.decoder_defunct = true;
96 },
97 }
98 }
99
100 if self.decoder_defunct {
101 src.clear();
102 return Err(LimitError::Defunct);
103 }
104 match self.inner.decode(src) {
105 Ok(None) if src.len() > self.max_frame_size => {
106 self.skip_ahead_state = Some(self.inner.prepare_skip_ahead(src));
107 Err(LimitError::LimitExceeded(src.len()))
108 },
109 Ok(x) => Ok(x),
110 Err(x) => Err(LimitError::Inner(x)),
111 }
112 }
113}