quick_protobuf_codec/
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
2
3use asynchronous_codec::{Decoder, Encoder};
4use bytes::{Buf, BufMut, BytesMut};
5use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer, WriterBackend};
6use std::io;
7use std::marker::PhantomData;
8
9mod generated;
10
11#[doc(hidden)] pub use generated::test as proto;
13
/// Codec for unsigned-varint length-prefixed Protobuf messages.
///
/// Each frame is the message length, encoded as an unsigned varint, followed by the
/// Protobuf encoding of the message. `In` is the type accepted by the encoder and `Out`
/// the type produced by the decoder; by default both are the same type.
pub struct Codec<In, Out = In> {
    /// Ties the codec to its message types without storing any values of them.
    phantom: PhantomData<(In, Out)>,
    /// Largest announced message length (in bytes) the decoder accepts.
    max_message_len_bytes: usize,
}

impl<In, Out> Codec<In, Out> {
    /// Creates a codec whose decoder rejects messages longer than `max_message_len_bytes`.
    ///
    /// The limit counts only the Protobuf message body; the bytes of the unsigned-varint
    /// length prefix are not included. The limit is not checked when encoding.
    pub fn new(max_message_len_bytes: usize) -> Self {
        Codec {
            phantom: PhantomData,
            max_message_len_bytes,
        }
    }
}
35
36impl<In: MessageWrite, Out> Encoder for Codec<In, Out> {
37 type Item<'a> = In;
38 type Error = Error;
39
40 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
41 write_length(&item, dst);
42 write_message(&item, dst)?;
43
44 Ok(())
45 }
46}
47
48fn write_length(message: &impl MessageWrite, dst: &mut BytesMut) {
50 let message_length = message.get_size();
51
52 let mut uvi_buf = unsigned_varint::encode::usize_buffer();
53 let encoded_length = unsigned_varint::encode::usize(message_length, &mut uvi_buf);
54
55 dst.extend_from_slice(encoded_length);
56}
57
58fn write_message(item: &impl MessageWrite, dst: &mut BytesMut) -> io::Result<()> {
60 let mut writer = Writer::new(BytesMutWriterBackend::new(dst));
61 item.write_message(&mut writer)
62 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
63
64 Ok(())
65}
66
67impl<In, Out> Decoder for Codec<In, Out>
68where
69 Out: for<'a> MessageRead<'a>,
70{
71 type Item = Out;
72 type Error = Error;
73
74 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
75 let (message_length, remaining) = match unsigned_varint::decode::usize(src) {
76 Ok((len, remaining)) => (len, remaining),
77 Err(unsigned_varint::decode::Error::Insufficient) => return Ok(None),
78 Err(e) => return Err(Error(io::Error::new(io::ErrorKind::InvalidData, e))),
79 };
80
81 if message_length > self.max_message_len_bytes {
82 return Err(Error(io::Error::new(
83 io::ErrorKind::PermissionDenied,
84 format!(
85 "message with {message_length}b exceeds maximum of {}b",
86 self.max_message_len_bytes
87 ),
88 )));
89 }
90
91 let varint_length = src.len() - remaining.len();
93
94 if src.len() < (message_length + varint_length) {
96 return Ok(None);
97 }
98
99 src.advance(varint_length);
101
102 let message = src.split_to(message_length);
103
104 let mut reader = BytesReader::from_bytes(&message);
105 let message = Self::Item::from_reader(&mut reader, &message)
106 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
107
108 Ok(Some(message))
109 }
110}
111
112struct BytesMutWriterBackend<'a> {
113 dst: &'a mut BytesMut,
114}
115
116impl<'a> BytesMutWriterBackend<'a> {
117 fn new(dst: &'a mut BytesMut) -> Self {
118 Self { dst }
119 }
120}
121
122impl<'a> WriterBackend for BytesMutWriterBackend<'a> {
123 fn pb_write_u8(&mut self, x: u8) -> quick_protobuf::Result<()> {
124 self.dst.put_u8(x);
125
126 Ok(())
127 }
128
129 fn pb_write_u32(&mut self, x: u32) -> quick_protobuf::Result<()> {
130 self.dst.put_u32_le(x);
131
132 Ok(())
133 }
134
135 fn pb_write_i32(&mut self, x: i32) -> quick_protobuf::Result<()> {
136 self.dst.put_i32_le(x);
137
138 Ok(())
139 }
140
141 fn pb_write_f32(&mut self, x: f32) -> quick_protobuf::Result<()> {
142 self.dst.put_f32_le(x);
143
144 Ok(())
145 }
146
147 fn pb_write_u64(&mut self, x: u64) -> quick_protobuf::Result<()> {
148 self.dst.put_u64_le(x);
149
150 Ok(())
151 }
152
153 fn pb_write_i64(&mut self, x: i64) -> quick_protobuf::Result<()> {
154 self.dst.put_i64_le(x);
155
156 Ok(())
157 }
158
159 fn pb_write_f64(&mut self, x: f64) -> quick_protobuf::Result<()> {
160 self.dst.put_f64_le(x);
161
162 Ok(())
163 }
164
165 fn pb_write_all(&mut self, buf: &[u8]) -> quick_protobuf::Result<()> {
166 self.dst.put_slice(buf);
167
168 Ok(())
169 }
170}
171
172#[derive(thiserror::Error, Debug)]
173#[error("Failed to encode/decode message")]
174pub struct Error(#[from] io::Error);
175
176impl From<Error> for io::Error {
177 fn from(e: Error) -> Self {
178 e.0
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto;
    use asynchronous_codec::FramedRead;
    use futures::io::Cursor;
    use futures::{FutureExt, StreamExt};
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    // Imported anonymously only for `.source()`, so that `Error` in this module still
    // names the crate's own error type brought in by `super::*`.
    use std::error::Error as _;

    /// A frame whose announced length exceeds the limit is rejected with a descriptive error.
    #[test]
    fn honors_max_message_length() {
        let codec = Codec::<Dummy>::new(1);
        let mut src = varint_zeroes(100);

        let mut read = FramedRead::new(Cursor::new(&mut src), codec);
        let err = read.next().now_or_never().unwrap().unwrap().unwrap_err();

        assert_eq!(
            err.source().unwrap().to_string(),
            "message with 100b exceeds maximum of 1b"
        )
    }

    /// A length prefix announcing more bytes than are buffered yields `Ok(None)` and
    /// leaves the buffer untouched.
    #[test]
    fn only_partial_message_in_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let mut src = varint_zeroes(100);
        src.truncate(50);

        let result = codec.decode(&mut src);

        assert!(result.unwrap().is_none());
        assert_eq!(
            src.len(),
            50,
            "to not modify `src` if we cannot read a full message"
        )
    }

    /// Decoding from an empty buffer asks for more data instead of failing.
    #[test]
    fn empty_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let result = codec.decode(&mut BytesMut::new());

        assert!(result.unwrap().is_none());
    }

    /// Encoding followed by decoding round-trips regardless of the buffer's initial capacity.
    #[test]
    fn handles_arbitrary_initial_capacity() {
        fn prop(message: proto::Message, initial_capacity: u16) {
            let mut buffer = BytesMut::with_capacity(initial_capacity as usize);
            let mut codec = Codec::<proto::Message>::new(u32::MAX as usize);

            codec.encode(message.clone(), &mut buffer).unwrap();
            let decoded = codec.decode(&mut buffer).unwrap().unwrap();

            assert_eq!(message, decoded);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _) -> _)
    }

    /// Builds a frame whose varint prefix announces `length` bytes, followed by `length` zeroes.
    fn varint_zeroes(length: usize) -> BytesMut {
        let mut buf = unsigned_varint::encode::usize_buffer();
        let encoded_length = unsigned_varint::encode::usize(length, &mut buf);

        let mut src = BytesMut::new();
        src.extend_from_slice(encoded_length);
        src.extend(std::iter::repeat(0).take(length));
        src
    }

    impl Arbitrary for proto::Message {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                data: Vec::arbitrary(g),
            }
        }
    }

    /// Stand-in message type for tests that never get as far as parsing a message body.
    #[derive(Debug)]
    struct Dummy;

    impl<'a> MessageRead<'a> for Dummy {
        fn from_reader(_: &mut BytesReader, _: &'a [u8]) -> quick_protobuf::Result<Self> {
            unreachable!("tests using `Dummy` never supply a complete message body")
        }
    }
}