fedimint_server/net/
framed.rs

1//! Adapter that implements a message based protocol on top of a stream based
2//! one
3use std::fmt::Debug;
4use std::io::{Read, Write};
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::{Buf, BufMut, BytesMut};
10use fedimint_logging::LOG_NET_PEER;
11use futures::{Sink, Stream};
12use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
13use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
14use tokio_util::codec::{FramedRead, FramedWrite};
15use tracing::{error, trace};
16
17/// Owned [`FramedTransport`] trait object
18pub type AnyFramedTransport<M> = Box<dyn FramedTransport<M> + Send + Unpin + 'static>;
19
20/// A bidirectional framed transport adapter that can be split into its read and
21/// write half
22pub trait FramedTransport<T>:
23    Sink<T, Error = anyhow::Error> + Stream<Item = Result<T, anyhow::Error>>
24{
25    /// Split the framed transport into read and write half
26    fn borrow_split(
27        &mut self,
28    ) -> (
29        &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin),
30        &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin),
31    );
32
33    /// Transforms concrete `FramedTransport` object into an owned trait object
34    fn into_dyn(self) -> AnyFramedTransport<T>
35    where
36        Self: Sized + Send + Unpin + 'static,
37    {
38        Box::new(self)
39    }
40}
41
42/// Special case for tokio [`TcpStream`](tokio::net::TcpStream) based
43/// [`BidiFramed`] instances
44pub type TcpBidiFramed<T> = BidiFramed<T, OwnedWriteHalf, OwnedReadHalf>;
45
46/// Sink (sending) half of [`BidiFramed`]
47pub type FramedSink<S, T> = FramedWrite<S, BincodeCodec<T>>;
48/// Stream (receiving) half of [`BidiFramed`]
49pub type FramedStream<S, T> = FramedRead<S, BincodeCodec<T>>;
50
51/// Framed transport codec for streams
52///
53/// Wraps a stream `S` and allows sending packetized data of type `T` over it.
54/// Data items are encoded using [`bincode`] and the bytes are sent over the
55/// stream prepended with a length field. `BidiFramed` implements `Sink<T>` and
56/// `Stream<Item=Result<T, _>>`.
57#[derive(Debug)]
58pub struct BidiFramed<T, WH, RH> {
59    sink: FramedSink<WH, T>,
60    stream: FramedStream<RH, T>,
61}
62
/// Length-prefixed frame codec that (de)serializes [`serde`]-capable values of
/// type `T` with [`bincode`].
///
/// The codec holds no state; `T` is only recorded at the type level.
#[derive(Debug)]
pub struct BincodeCodec<T> {
    /// Ties the codec to its message type without storing a value.
    _pd: PhantomData<T>,
}
68
69impl<T, WH, RH> BidiFramed<T, WH, RH>
70where
71    WH: AsyncWrite,
72    RH: AsyncRead,
73    T: serde::Serialize + serde::de::DeserializeOwned,
74{
75    /// Builds a new `BidiFramed` codec around a stream `stream`.
76    ///
77    /// See [`TcpBidiFramed::new_from_tcp`] for a more efficient version in case
78    /// the stream is a tokio TCP stream.
79    pub fn new<S>(stream: S) -> BidiFramed<T, WriteHalf<S>, ReadHalf<S>>
80    where
81        S: AsyncRead + AsyncWrite,
82    {
83        let (read, write) = tokio::io::split(stream);
84        BidiFramed {
85            sink: FramedSink::new(write, BincodeCodec::new()),
86            stream: FramedStream::new(read, BincodeCodec::new()),
87        }
88    }
89
90    /// Splits the codec in its sending and receiving parts
91    ///
92    /// This can be useful in cases where potentially simultaneous read and
93    /// write operations are required. Otherwise a we would need a mutex to
94    /// guard access.
95    pub fn borrow_parts(&mut self) -> (&mut FramedSink<WH, T>, &mut FramedStream<RH, T>) {
96        (&mut self.sink, &mut self.stream)
97    }
98}
99
100impl<T> TcpBidiFramed<T>
101where
102    T: serde::Serialize + serde::de::DeserializeOwned,
103{
104    /// Special constructor for tokio TCP connections.
105    ///
106    /// Tokio [`TcpStream`](tokio::net::TcpStream) implements an efficient
107    /// method of splitting the stream into a read and a write half this
108    /// constructor takes advantage of.
109    pub fn new_from_tcp(stream: tokio::net::TcpStream) -> TcpBidiFramed<T> {
110        let (read, write) = stream.into_split();
111        BidiFramed {
112            sink: FramedSink::new(write, BincodeCodec::new()),
113            stream: FramedStream::new(read, BincodeCodec::new()),
114        }
115    }
116}
117
118impl<T, WH, RH> Sink<T> for BidiFramed<T, WH, RH>
119where
120    WH: tokio::io::AsyncWrite + Unpin,
121    RH: Unpin,
122    T: Debug + serde::Serialize,
123{
124    type Error = anyhow::Error;
125
126    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127        Sink::poll_ready(Pin::new(&mut self.sink), cx)
128    }
129
130    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
131        Sink::start_send(Pin::new(&mut self.sink), item)
132    }
133
134    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135        Sink::poll_flush(Pin::new(&mut self.sink), cx)
136    }
137
138    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139        Sink::poll_close(Pin::new(&mut self.sink), cx)
140    }
141}
142
143impl<T, WH, RH> Stream for BidiFramed<T, WH, RH>
144where
145    T: serde::de::DeserializeOwned,
146    WH: Unpin,
147    RH: tokio::io::AsyncRead + Unpin,
148{
149    type Item = Result<T, anyhow::Error>;
150
151    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152        Stream::poll_next(Pin::new(&mut self.stream), cx)
153    }
154}
155
156impl<T, WH, RH> FramedTransport<T> for BidiFramed<T, WH, RH>
157where
158    T: Debug + serde::Serialize + serde::de::DeserializeOwned + Send,
159    WH: tokio::io::AsyncWrite + Send + Unpin,
160    RH: tokio::io::AsyncRead + Send + Unpin,
161{
162    fn borrow_split(
163        &mut self,
164    ) -> (
165        &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin),
166        &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin),
167    ) {
168        let (sink, stream) = self.borrow_parts();
169        (&mut *sink, &mut *stream)
170    }
171}
172
173impl<T> BincodeCodec<T> {
174    fn new() -> BincodeCodec<T> {
175        BincodeCodec { _pd: PhantomData }
176    }
177}
178
179impl<T> tokio_util::codec::Encoder<T> for BincodeCodec<T>
180where
181    T: serde::Serialize + Debug,
182{
183    type Error = anyhow::Error;
184
185    fn encode(&mut self, item: T, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
186        // First, write a dummy length field and remember its position
187        let old_len = dst.len();
188        dst.writer().write_all(&[0u8; 8]).unwrap();
189        assert_eq!(dst.len(), old_len + 8);
190
191        // Then we serialize the message into the buffer
192        bincode::serialize_into(dst.writer(), &item).inspect_err(|_e| {
193            error!(
194                target: LOG_NET_PEER,
195                "Serializing message failed: {:?}", item
196            );
197        })?;
198
199        // Lastly we update the length field by counting how many bytes have been
200        // written
201        let new_len = dst.len();
202        let encoded_len = new_len - old_len - 8;
203        dst[old_len..old_len + 8].copy_from_slice(&encoded_len.to_be_bytes()[..]);
204
205        Ok(())
206    }
207}
208
209impl<T> tokio_util::codec::Decoder for BincodeCodec<T>
210where
211    T: serde::de::DeserializeOwned,
212{
213    type Item = T;
214    type Error = anyhow::Error;
215
216    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
217        if src.len() < 8 {
218            return Ok(None);
219        }
220
221        let length = u64::from_be_bytes(src[0..8].try_into().expect("correct length"));
222        if src.len() < (length as usize) + 8 {
223            trace!(length, buffern_len = src.len(), "Received partial message");
224            return Ok(None);
225        }
226        trace!(length, "Received full message");
227
228        src.reader()
229            .read_exact(&mut [0u8; 8][..])
230            .expect("minimum length checked");
231
232        Ok(bincode::deserialize_from(src.reader()).map(Option::Some)?)
233    }
234}
235
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{SinkExt, StreamExt};
    use serde::{Deserialize, Serialize};
    use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf};

    use crate::net::framed::BidiFramed;

    /// Message type exchanged in the tests below.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    enum TestEnum {
        Foo,
        Bar(u64),
    }

    /// `BidiFramed` over one end of an in-memory duplex pipe.
    type DuplexFramed = BidiFramed<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>;

    /// Messages arrive in order and unchanged; dropping the sender ends the
    /// receiving stream.
    #[tokio::test]
    async fn test_roundtrip() {
        let messages = vec![TestEnum::Foo, TestEnum::Bar(42), TestEnum::Foo];

        let (sender_end, recipient_end) = tokio::io::duplex(1024);
        let mut framed_sender = DuplexFramed::new(sender_end);
        let mut framed_recipient = DuplexFramed::new(recipient_end);

        for message in messages.iter().cloned() {
            framed_sender.send(message).await.unwrap();
        }

        for expected in &messages {
            let received = framed_recipient.next().await.unwrap().unwrap();
            assert_eq!(&received, expected);
        }

        drop(framed_sender);
        assert!(framed_recipient.next().await.is_none());
    }

    /// A frame of which only the first few bytes have arrived must make the
    /// receiver wait instead of producing an item or an error.
    #[tokio::test]
    async fn test_not_try_parse_partial() {
        let (sender_src, mut recipient_src) = tokio::io::duplex(1024);
        let (mut sender_dst, recipient_dst) = tokio::io::duplex(1024);

        let mut framed_sender = DuplexFramed::new(sender_src);
        let mut framed_recipient = DuplexFramed::new(recipient_dst);

        framed_sender
            .send(TestEnum::Bar(0x4242_4242_4242_4242))
            .await
            .unwrap();

        // Simulate a partial send: forward only the first three bytes
        let mut partial = [0u8; 3];
        recipient_src.read_exact(&mut partial).await.unwrap();
        sender_dst.write_all(&partial).await.unwrap();

        // Try to read, should not return an error but block
        let outcome = tokio::time::timeout(Duration::from_secs(1), framed_recipient.next()).await;

        assert!(outcome.is_err());
    }
}