prost_codec/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
2
3use asynchronous_codec::{Decoder, Encoder};
4use bytes::BytesMut;
5use prost::Message;
6use std::io::Cursor;
7use std::marker::PhantomData;
8use unsigned_varint::codec::UviBytes;
9
10/// [`Codec`] implements [`Encoder`] and [`Decoder`], uses [`unsigned_varint`]
11/// to prefix messages with their length and uses [`prost`] and a provided
12/// `struct` implementing [`Message`] to do the encoding.
13pub struct Codec<In, Out = In> {
14    uvi: UviBytes,
15    phantom: PhantomData<(In, Out)>,
16}
17
18impl<In, Out> Codec<In, Out> {
19    /// Create new [`Codec`].
20    ///
21    /// Parameter `max_message_len_bytes` determines the maximum length of the
22    /// Protobuf message. The limit does not include the bytes needed for the
23    /// [`unsigned_varint`].
24    pub fn new(max_message_len_bytes: usize) -> Self {
25        let mut uvi = UviBytes::default();
26        uvi.set_max_len(max_message_len_bytes);
27        Self {
28            uvi,
29            phantom: PhantomData::default(),
30        }
31    }
32}
33
34impl<In: Message, Out> Encoder for Codec<In, Out> {
35    type Item = In;
36    type Error = Error;
37
38    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
39        let mut encoded_msg = BytesMut::new();
40        item.encode(&mut encoded_msg)
41            .expect("BytesMut to have sufficient capacity.");
42        self.uvi.encode(encoded_msg.freeze(), dst)?;
43
44        Ok(())
45    }
46}
47
48impl<In, Out: Message + Default> Decoder for Codec<In, Out> {
49    type Item = Out;
50    type Error = Error;
51
52    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
53        let msg = match self.uvi.decode(src)? {
54            None => return Ok(None),
55            Some(msg) => msg,
56        };
57
58        let message = Message::decode(Cursor::new(msg))
59            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
60
61        Ok(Some(message))
62    }
63}
64
65#[derive(thiserror::Error, Debug)]
66#[error("Failed to encode/decode message")]
67pub struct Error(#[from] std::io::Error);
68
69impl From<Error> for std::io::Error {
70    fn from(e: Error) -> Self {
71        e.0
72    }
73}