tokio_util/codec/
any_delimiter_codec.rs

1use crate::codec::decoder::Decoder;
2use crate::codec::encoder::Encoder;
3
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use std::{cmp, fmt, io, str};
6
/// Delimiter bytes used by `AnyDelimiterCodec::default`: comma, semicolon,
/// line feed and carriage return.
const DEFAULT_SEEK_DELIMITERS: &[u8] = &[b',', b';', b'\n', b'\r'];
/// Bytes appended after every encoded chunk by `AnyDelimiterCodec::default`.
const DEFAULT_SEQUENCE_WRITER: &[u8] = &[b','];

/// A simple [`Decoder`] and [`Encoder`] implementation that splits data into
/// chunks wherever any one of a configurable set of delimiter bytes appears.
///
/// Decoded chunks never include the delimiter byte that ended them. When
/// encoding, every chunk is followed by the configured `sequence_writer` bytes.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
///
/// # Example
///
/// Decode a buffer that mixes several different delimiters. The trailing
/// `\n\r` pair counts as two delimiters, so it yields one empty chunk.
///
/// ```
/// use tokio_util::codec::{AnyDelimiterCodec, Decoder};
/// use bytes::{BufMut, BytesMut};
///
/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(), b";".to_vec());
/// let buf = &mut BytesMut::new();
/// buf.reserve(200);
/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r");
/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("", codec.decode(buf).unwrap().unwrap());
/// assert_eq!(None, codec.decode(buf).unwrap());
/// ```
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct AnyDelimiterCodec {
    // Index of the next byte to examine for a delimiter. This lets `decode`
    // resume where it left off instead of rescanning bytes it has already seen.
    // For example, if `decode` was called with `abc` and the delimiter set is
    // `{}`, this would hold `3`; the next call with `abcde}` then only looks at
    // `de}` before returning.
    next_index: usize,

    /// The maximum length for a given chunk. If `usize::MAX`, chunks will be
    /// read until a delimiter byte is reached.
    max_length: usize,

    /// Are we currently discarding the remainder of a chunk which was over
    /// the length limit?
    is_discarding: bool,

    /// Bytes that each mark the end of a chunk when decoding.
    seek_delimiters: Vec<u8>,

    /// Bytes appended after every chunk when encoding.
    sequence_writer: Vec<u8>,
}
61    /// The bytes that are using for encoding
62    sequence_writer: Vec<u8>,
63}
64
65impl AnyDelimiterCodec {
66    /// Returns a `AnyDelimiterCodec` for splitting up data into chunks.
67    ///
68    /// # Note
69    ///
70    /// The returned `AnyDelimiterCodec` will not have an upper bound on the length
71    /// of a buffered chunk. See the documentation for [`new_with_max_length`]
72    /// for information on why this could be a potential security risk.
73    ///
74    /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length()
75    pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec {
76        AnyDelimiterCodec {
77            next_index: 0,
78            max_length: usize::MAX,
79            is_discarding: false,
80            seek_delimiters,
81            sequence_writer,
82        }
83    }
84
85    /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit.
86    ///
87    /// If this is set, calls to `AnyDelimiterCodec::decode` will return a
88    /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls
89    /// will discard up to `limit` bytes from that chunk until a delimiter
90    /// character is reached, returning `None` until the delimiter over the limit
91    /// has been fully discarded. After that point, calls to `decode` will
92    /// function as normal.
93    ///
94    /// # Note
95    ///
96    /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which
97    /// will be exposed to untrusted input. Otherwise, the size of the buffer
98    /// that holds the chunk currently being read is unbounded. An attacker could
99    /// exploit this unbounded buffer by sending an unbounded amount of input
100    /// without any delimiter characters, causing unbounded memory consumption.
101    ///
102    /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError
103    pub fn new_with_max_length(
104        seek_delimiters: Vec<u8>,
105        sequence_writer: Vec<u8>,
106        max_length: usize,
107    ) -> Self {
108        AnyDelimiterCodec {
109            max_length,
110            ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer)
111        }
112    }
113
114    /// Returns the maximum chunk length when decoding.
115    ///
116    /// ```
117    /// use std::usize;
118    /// use tokio_util::codec::AnyDelimiterCodec;
119    ///
120    /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec());
121    /// assert_eq!(codec.max_length(), usize::MAX);
122    /// ```
123    /// ```
124    /// use tokio_util::codec::AnyDelimiterCodec;
125    ///
126    /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256);
127    /// assert_eq!(codec.max_length(), 256);
128    /// ```
129    pub fn max_length(&self) -> usize {
130        self.max_length
131    }
132}
133
134impl Decoder for AnyDelimiterCodec {
135    type Item = Bytes;
136    type Error = AnyDelimiterCodecError;
137
138    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
139        loop {
140            // Determine how far into the buffer we'll search for a delimiter. If
141            // there's no max_length set, we'll read to the end of the buffer.
142            let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
143
144            let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| {
145                self.seek_delimiters
146                    .iter()
147                    .any(|delimiter| *b == *delimiter)
148            });
149
150            match (self.is_discarding, new_chunk_offset) {
151                (true, Some(offset)) => {
152                    // If we found a new chunk, discard up to that offset and
153                    // then stop discarding. On the next iteration, we'll try
154                    // to read a chunk normally.
155                    buf.advance(offset + self.next_index + 1);
156                    self.is_discarding = false;
157                    self.next_index = 0;
158                }
159                (true, None) => {
160                    // Otherwise, we didn't find a new chunk, so we'll discard
161                    // everything we read. On the next iteration, we'll continue
162                    // discarding up to max_len bytes unless we find a new chunk.
163                    buf.advance(read_to);
164                    self.next_index = 0;
165                    if buf.is_empty() {
166                        return Ok(None);
167                    }
168                }
169                (false, Some(offset)) => {
170                    // Found a chunk!
171                    let new_chunk_index = offset + self.next_index;
172                    self.next_index = 0;
173                    let mut chunk = buf.split_to(new_chunk_index + 1);
174                    chunk.truncate(chunk.len() - 1);
175                    let chunk = chunk.freeze();
176                    return Ok(Some(chunk));
177                }
178                (false, None) if buf.len() > self.max_length => {
179                    // Reached the maximum length without finding a
180                    // new chunk, return an error and start discarding on the
181                    // next call.
182                    self.is_discarding = true;
183                    return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded);
184                }
185                (false, None) => {
186                    // We didn't find a chunk or reach the length limit, so the next
187                    // call will resume searching at the current offset.
188                    self.next_index = read_to;
189                    return Ok(None);
190                }
191            }
192        }
193    }
194
195    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
196        Ok(match self.decode(buf)? {
197            Some(frame) => Some(frame),
198            None => {
199                // return remaining data, if any
200                if buf.is_empty() {
201                    None
202                } else {
203                    let chunk = buf.split_to(buf.len());
204                    self.next_index = 0;
205                    Some(chunk.freeze())
206                }
207            }
208        })
209    }
210}
211
212impl<T> Encoder<T> for AnyDelimiterCodec
213where
214    T: AsRef<str>,
215{
216    type Error = AnyDelimiterCodecError;
217
218    fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> {
219        let chunk = chunk.as_ref();
220        buf.reserve(chunk.len() + 1);
221        buf.put(chunk.as_bytes());
222        buf.put(self.sequence_writer.as_ref());
223
224        Ok(())
225    }
226}
227
228impl Default for AnyDelimiterCodec {
229    fn default() -> Self {
230        Self::new(
231            DEFAULT_SEEK_DELIMITERS.to_vec(),
232            DEFAULT_SEQUENCE_WRITER.to_vec(),
233        )
234    }
235}
236
237/// An error occurred while encoding or decoding a chunk.
238#[derive(Debug)]
239pub enum AnyDelimiterCodecError {
240    /// The maximum chunk length was exceeded.
241    MaxChunkLengthExceeded,
242    /// An IO error occurred.
243    Io(io::Error),
244}
245
246impl fmt::Display for AnyDelimiterCodecError {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        match self {
249            AnyDelimiterCodecError::MaxChunkLengthExceeded => {
250                write!(f, "max chunk length exceeded")
251            }
252            AnyDelimiterCodecError::Io(e) => write!(f, "{e}"),
253        }
254    }
255}
256
257impl From<io::Error> for AnyDelimiterCodecError {
258    fn from(e: io::Error) -> AnyDelimiterCodecError {
259        AnyDelimiterCodecError::Io(e)
260    }
261}
262
263impl std::error::Error for AnyDelimiterCodecError {}