1use std::fmt::Debug;
4use std::io::{Read, Write};
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::{Buf, BufMut, BytesMut};
10use fedimint_logging::LOG_NET_PEER;
11use futures::{Sink, Stream};
12use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
13use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
14use tokio_util::codec::{FramedRead, FramedWrite};
15use tracing::{error, trace};
16
/// Owned, boxed [`FramedTransport`] trait object that is `Send + Unpin + 'static`.
pub type AnyFramedTransport<M> = Box<dyn FramedTransport<M> + Send + Unpin + 'static>;

/// A bidirectional message transport: a [`Sink`] accepting `T` and a [`Stream`] yielding
/// `Result<T, anyhow::Error>`, whose two directions can also be borrowed separately.
pub trait FramedTransport<T>:
    Sink<T, Error = anyhow::Error> + Stream<Item = Result<T, anyhow::Error>>
{
    /// Mutably borrows the sending half and the receiving half at the same time, as trait
    /// objects, so both directions can be driven independently.
    fn borrow_split(
        &mut self,
    ) -> (
        &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin),
        &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin),
    );

    /// Boxes this concrete transport into an [`AnyFramedTransport`] trait object.
    fn into_dyn(self) -> AnyFramedTransport<T>
    where
        Self: Sized + Send + Unpin + 'static,
    {
        Box::new(self)
    }
}
41
/// [`BidiFramed`] over the owned write/read halves of a TCP connection.
pub type TcpBidiFramed<T> = BidiFramed<T, OwnedWriteHalf, OwnedReadHalf>;

/// Writing side of a [`BidiFramed`]: encodes `T` with [`BincodeCodec`] into the writer `S`.
pub type FramedSink<S, T> = FramedWrite<S, BincodeCodec<T>>;
/// Reading side of a [`BidiFramed`]: decodes `T` with [`BincodeCodec`] from the reader `S`.
pub type FramedStream<S, T> = FramedRead<S, BincodeCodec<T>>;
50
#[derive(Debug)]
/// A framed transport built from a separate write half `WH` and read half `RH`.
///
/// Outgoing messages go through `sink`, incoming ones come from `stream`; both use
/// [`BincodeCodec`]. Note: `sink` is declared first and is therefore dropped before `stream`.
pub struct BidiFramed<T, WH, RH> {
    sink: FramedSink<WH, T>,
    stream: FramedStream<RH, T>,
}

#[derive(Debug)]
/// Codec for frames made of an 8-byte big-endian length prefix followed by a bincode payload.
///
/// Holds no data; the `PhantomData` only ties the codec to the message type `T`.
pub struct BincodeCodec<T> {
    _pd: PhantomData<T>,
}
68
69impl<T, WH, RH> BidiFramed<T, WH, RH>
70where
71 WH: AsyncWrite,
72 RH: AsyncRead,
73 T: serde::Serialize + serde::de::DeserializeOwned,
74{
75 pub fn new<S>(stream: S) -> BidiFramed<T, WriteHalf<S>, ReadHalf<S>>
80 where
81 S: AsyncRead + AsyncWrite,
82 {
83 let (read, write) = tokio::io::split(stream);
84 BidiFramed {
85 sink: FramedSink::new(write, BincodeCodec::new()),
86 stream: FramedStream::new(read, BincodeCodec::new()),
87 }
88 }
89
90 pub fn borrow_parts(&mut self) -> (&mut FramedSink<WH, T>, &mut FramedStream<RH, T>) {
96 (&mut self.sink, &mut self.stream)
97 }
98}
99
100impl<T> TcpBidiFramed<T>
101where
102 T: serde::Serialize + serde::de::DeserializeOwned,
103{
104 pub fn new_from_tcp(stream: tokio::net::TcpStream) -> TcpBidiFramed<T> {
110 let (read, write) = stream.into_split();
111 BidiFramed {
112 sink: FramedSink::new(write, BincodeCodec::new()),
113 stream: FramedStream::new(read, BincodeCodec::new()),
114 }
115 }
116}
117
118impl<T, WH, RH> Sink<T> for BidiFramed<T, WH, RH>
119where
120 WH: tokio::io::AsyncWrite + Unpin,
121 RH: Unpin,
122 T: Debug + serde::Serialize,
123{
124 type Error = anyhow::Error;
125
126 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127 Sink::poll_ready(Pin::new(&mut self.sink), cx)
128 }
129
130 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
131 Sink::start_send(Pin::new(&mut self.sink), item)
132 }
133
134 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135 Sink::poll_flush(Pin::new(&mut self.sink), cx)
136 }
137
138 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139 Sink::poll_close(Pin::new(&mut self.sink), cx)
140 }
141}
142
143impl<T, WH, RH> Stream for BidiFramed<T, WH, RH>
144where
145 T: serde::de::DeserializeOwned,
146 WH: Unpin,
147 RH: tokio::io::AsyncRead + Unpin,
148{
149 type Item = Result<T, anyhow::Error>;
150
151 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152 Stream::poll_next(Pin::new(&mut self.stream), cx)
153 }
154}
155
156impl<T, WH, RH> FramedTransport<T> for BidiFramed<T, WH, RH>
157where
158 T: Debug + serde::Serialize + serde::de::DeserializeOwned + Send,
159 WH: tokio::io::AsyncWrite + Send + Unpin,
160 RH: tokio::io::AsyncRead + Send + Unpin,
161{
162 fn borrow_split(
163 &mut self,
164 ) -> (
165 &'_ mut (dyn Sink<T, Error = anyhow::Error> + Send + Unpin),
166 &'_ mut (dyn Stream<Item = Result<T, anyhow::Error>> + Send + Unpin),
167 ) {
168 let (sink, stream) = self.borrow_parts();
169 (&mut *sink, &mut *stream)
170 }
171}
172
173impl<T> BincodeCodec<T> {
174 fn new() -> BincodeCodec<T> {
175 BincodeCodec { _pd: PhantomData }
176 }
177}
178
179impl<T> tokio_util::codec::Encoder<T> for BincodeCodec<T>
180where
181 T: serde::Serialize + Debug,
182{
183 type Error = anyhow::Error;
184
185 fn encode(&mut self, item: T, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
186 let old_len = dst.len();
188 dst.writer().write_all(&[0u8; 8]).unwrap();
189 assert_eq!(dst.len(), old_len + 8);
190
191 bincode::serialize_into(dst.writer(), &item).inspect_err(|_e| {
193 error!(
194 target: LOG_NET_PEER,
195 "Serializing message failed: {:?}", item
196 );
197 })?;
198
199 let new_len = dst.len();
202 let encoded_len = new_len - old_len - 8;
203 dst[old_len..old_len + 8].copy_from_slice(&encoded_len.to_be_bytes()[..]);
204
205 Ok(())
206 }
207}
208
209impl<T> tokio_util::codec::Decoder for BincodeCodec<T>
210where
211 T: serde::de::DeserializeOwned,
212{
213 type Item = T;
214 type Error = anyhow::Error;
215
216 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
217 if src.len() < 8 {
218 return Ok(None);
219 }
220
221 let length = u64::from_be_bytes(src[0..8].try_into().expect("correct length"));
222 if src.len() < (length as usize) + 8 {
223 trace!(length, buffern_len = src.len(), "Received partial message");
224 return Ok(None);
225 }
226 trace!(length, "Received full message");
227
228 src.reader()
229 .read_exact(&mut [0u8; 8][..])
230 .expect("minimum length checked");
231
232 Ok(bincode::deserialize_from(src.reader()).map(Option::Some)?)
233 }
234}
235
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{SinkExt, StreamExt};
    use serde::{Deserialize, Serialize};
    use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf};

    use crate::net::framed::BidiFramed;

    /// Message type exchanged by the tests in this module.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    enum TestEnum {
        Foo,
        Bar(u64),
    }

    /// Framed transport over one end of an in-memory duplex pipe.
    type DuplexFramed = BidiFramed<TestEnum, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>;

    /// Messages sent through an in-memory pipe arrive unchanged and in order, and the stream
    /// ends once the sending side is dropped.
    #[tokio::test]
    async fn test_roundtrip() {
        let messages = [TestEnum::Foo, TestEnum::Bar(42), TestEnum::Foo];
        let (left_end, right_end) = tokio::io::duplex(1024);

        let mut tx = DuplexFramed::new(left_end);
        let mut rx = DuplexFramed::new(right_end);

        for msg in messages.iter().cloned() {
            tx.send(msg).await.unwrap();
        }

        for expected in &messages {
            let got = rx.next().await.unwrap().unwrap();
            assert_eq!(&got, expected);
        }
        drop(tx);

        assert!(rx.next().await.is_none());
    }

    /// When only the first 3 bytes of a frame reach the receiver, it keeps waiting for more
    /// data instead of yielding a message.
    #[tokio::test]
    async fn test_not_try_parse_partial() {
        // Pipe 1: the framed sender writes into it, the test reads raw bytes out.
        let (sender_end, mut raw_out) = tokio::io::duplex(1024);
        // Pipe 2: the test writes raw bytes in, the framed receiver reads them.
        let (mut raw_in, receiver_end) = tokio::io::duplex(1024);

        let mut tx = DuplexFramed::new(sender_end);
        let mut rx = DuplexFramed::new(receiver_end);

        tx.send(TestEnum::Bar(0x4242_4242_4242_4242)).await.unwrap();

        // Forward only a truncated prefix of the encoded frame.
        let mut prefix = [0u8; 3];
        raw_out.read_exact(&mut prefix).await.unwrap();
        raw_in.write_all(&prefix).await.unwrap();

        let outcome = tokio::time::timeout(Duration::from_secs(1), rx.next()).await;
        assert!(outcome.is_err());
    }
}