quick_protobuf_codec/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
2
3use asynchronous_codec::{Decoder, Encoder};
4use bytes::{Buf, BufMut, BytesMut};
5use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer, WriterBackend};
6use std::io;
7use std::marker::PhantomData;
8
9mod generated;
10
11#[doc(hidden)] // NOT public API. Do not use.
12pub use generated::test as proto;
13
/// Length-delimited Protobuf codec.
///
/// Every frame consists of an [`unsigned_varint`] carrying the payload length, followed by the
/// payload itself, serialised with [`quick_protobuf`]. `In` is the [`MessageWrite`] type accepted
/// by the [`Encoder`] half and `Out` is the [`MessageRead`] type produced by the [`Decoder`]
/// half; `Out` defaults to `In` for symmetric protocols.
pub struct Codec<In, Out = In> {
    /// Upper bound on the payload length accepted when decoding (varint prefix excluded).
    max_message_len_bytes: usize,
    /// Ties the otherwise unused `In`/`Out` type parameters to the struct.
    phantom: PhantomData<(In, Out)>,
}
21
22impl<In, Out> Codec<In, Out> {
23    /// Create new [`Codec`].
24    ///
25    /// Parameter `max_message_len_bytes` determines the maximum length of the
26    /// Protobuf message. The limit does not include the bytes needed for the
27    /// [`unsigned_varint`].
28    pub fn new(max_message_len_bytes: usize) -> Self {
29        Self {
30            max_message_len_bytes,
31            phantom: PhantomData,
32        }
33    }
34}
35
36impl<In: MessageWrite, Out> Encoder for Codec<In, Out> {
37    type Item<'a> = In;
38    type Error = Error;
39
40    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
41        write_length(&item, dst);
42        write_message(&item, dst)?;
43
44        Ok(())
45    }
46}
47
48/// Write the message's length (i.e. `size`) to `dst` as a variable-length integer.
49fn write_length(message: &impl MessageWrite, dst: &mut BytesMut) {
50    let message_length = message.get_size();
51
52    let mut uvi_buf = unsigned_varint::encode::usize_buffer();
53    let encoded_length = unsigned_varint::encode::usize(message_length, &mut uvi_buf);
54
55    dst.extend_from_slice(encoded_length);
56}
57
58/// Write the message itself to `dst`.
59fn write_message(item: &impl MessageWrite, dst: &mut BytesMut) -> io::Result<()> {
60    let mut writer = Writer::new(BytesMutWriterBackend::new(dst));
61    item.write_message(&mut writer)
62        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
63
64    Ok(())
65}
66
67impl<In, Out> Decoder for Codec<In, Out>
68where
69    Out: for<'a> MessageRead<'a>,
70{
71    type Item = Out;
72    type Error = Error;
73
74    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
75        let (message_length, remaining) = match unsigned_varint::decode::usize(src) {
76            Ok((len, remaining)) => (len, remaining),
77            Err(unsigned_varint::decode::Error::Insufficient) => return Ok(None),
78            Err(e) => return Err(Error(io::Error::new(io::ErrorKind::InvalidData, e))),
79        };
80
81        if message_length > self.max_message_len_bytes {
82            return Err(Error(io::Error::new(
83                io::ErrorKind::PermissionDenied,
84                format!(
85                    "message with {message_length}b exceeds maximum of {}b",
86                    self.max_message_len_bytes
87                ),
88            )));
89        }
90
91        // Compute how many bytes the varint itself consumed.
92        let varint_length = src.len() - remaining.len();
93
94        // Ensure we can read an entire message.
95        if src.len() < (message_length + varint_length) {
96            return Ok(None);
97        }
98
99        // Safe to advance buffer now.
100        src.advance(varint_length);
101
102        let message = src.split_to(message_length);
103
104        let mut reader = BytesReader::from_bytes(&message);
105        let message = Self::Item::from_reader(&mut reader, &message)
106            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
107
108        Ok(Some(message))
109    }
110}
111
112struct BytesMutWriterBackend<'a> {
113    dst: &'a mut BytesMut,
114}
115
116impl<'a> BytesMutWriterBackend<'a> {
117    fn new(dst: &'a mut BytesMut) -> Self {
118        Self { dst }
119    }
120}
121
122impl<'a> WriterBackend for BytesMutWriterBackend<'a> {
123    fn pb_write_u8(&mut self, x: u8) -> quick_protobuf::Result<()> {
124        self.dst.put_u8(x);
125
126        Ok(())
127    }
128
129    fn pb_write_u32(&mut self, x: u32) -> quick_protobuf::Result<()> {
130        self.dst.put_u32_le(x);
131
132        Ok(())
133    }
134
135    fn pb_write_i32(&mut self, x: i32) -> quick_protobuf::Result<()> {
136        self.dst.put_i32_le(x);
137
138        Ok(())
139    }
140
141    fn pb_write_f32(&mut self, x: f32) -> quick_protobuf::Result<()> {
142        self.dst.put_f32_le(x);
143
144        Ok(())
145    }
146
147    fn pb_write_u64(&mut self, x: u64) -> quick_protobuf::Result<()> {
148        self.dst.put_u64_le(x);
149
150        Ok(())
151    }
152
153    fn pb_write_i64(&mut self, x: i64) -> quick_protobuf::Result<()> {
154        self.dst.put_i64_le(x);
155
156        Ok(())
157    }
158
159    fn pb_write_f64(&mut self, x: f64) -> quick_protobuf::Result<()> {
160        self.dst.put_f64_le(x);
161
162        Ok(())
163    }
164
165    fn pb_write_all(&mut self, buf: &[u8]) -> quick_protobuf::Result<()> {
166        self.dst.put_slice(buf);
167
168        Ok(())
169    }
170}
171
172#[derive(thiserror::Error, Debug)]
173#[error("Failed to encode/decode message")]
174pub struct Error(#[from] io::Error);
175
176impl From<Error> for io::Error {
177    fn from(e: Error) -> Self {
178        e.0
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto;
    use asynchronous_codec::FramedRead;
    use futures::io::Cursor;
    use futures::{FutureExt, StreamExt};
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use std::error::Error;

    /// A frame announcing more bytes than the configured maximum is rejected.
    #[test]
    fn honors_max_message_length() {
        let codec = Codec::<Dummy>::new(1);
        let mut src = varint_zeroes(100);

        let mut read = FramedRead::new(Cursor::new(&mut src), codec);
        let err = read.next().now_or_never().unwrap().unwrap().unwrap_err();

        assert_eq!(
            err.source().unwrap().to_string(),
            "message with 100b exceeds maximum of 1b"
        )
    }

    /// Only part of the announced body is buffered: decoding yields `None` and leaves `src`
    /// untouched.
    #[test]
    fn only_partial_message_in_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let mut src = varint_zeroes(100);
        src.truncate(50);

        let result = codec.decode(&mut src);

        assert!(result.unwrap().is_none());
        assert_eq!(
            src.len(),
            50,
            "to not modify `src` if we cannot read a full message"
        )
    }

    /// An empty buffer (not even a length prefix) yields `None`.
    #[test]
    fn empty_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let result = codec.decode(&mut BytesMut::new());

        assert!(result.unwrap().is_none());
    }

    /// Round-trips arbitrary messages regardless of the buffer's initial capacity.
    #[test]
    fn handles_arbitrary_initial_capacity() {
        fn prop(message: proto::Message, initial_capacity: u16) {
            let mut buffer = BytesMut::with_capacity(initial_capacity as usize);
            let mut codec = Codec::<proto::Message>::new(u32::MAX as usize);

            codec.encode(message.clone(), &mut buffer).unwrap();
            let decoded = codec.decode(&mut buffer).unwrap().unwrap();

            assert_eq!(message, decoded);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _) -> _)
    }

    /// Constructs a [`BytesMut`] of the provided length where the message is all zeros.
    fn varint_zeroes(length: usize) -> BytesMut {
        let mut buf = unsigned_varint::encode::usize_buffer();
        let encoded_length = unsigned_varint::encode::usize(length, &mut buf);

        let mut src = BytesMut::new();
        src.extend_from_slice(encoded_length);
        src.extend(std::iter::repeat(0).take(length));
        src
    }

    impl Arbitrary for proto::Message {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                data: Vec::arbitrary(g),
            }
        }
    }

    /// Placeholder output type for tests that must fail or stall before the body is parsed.
    #[derive(Debug)]
    struct Dummy;

    impl<'a> MessageRead<'a> for Dummy {
        fn from_reader(_: &mut BytesReader, _: &'a [u8]) -> quick_protobuf::Result<Self> {
            todo!()
        }
    }
}