fedimint_server/net/
connect.rs

1//! Provides an abstract network connection interface and multiple
2//! implementations
3
4use std::collections::BTreeMap;
5use std::fmt::Debug;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use anyhow::format_err;
11use async_trait::async_trait;
12use fedimint_core::util::SafeUrl;
13use fedimint_core::PeerId;
14use futures::Stream;
15use tokio::io::{ReadHalf, WriteHalf};
16use tokio::net::{TcpListener, TcpStream};
17use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
18use tokio_rustls::rustls::RootCertStore;
19use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
20
21use crate::net::framed::{AnyFramedTransport, BidiFramed, FramedTransport};
22
/// Shared [`Connector`] trait object
///
/// Reference counted through [`Arc`], so the same connector can be held in
/// several places at once.
pub type SharedAnyConnector<M> = Arc<dyn Connector<M> + Send + Sync + Unpin + 'static>;

/// Owned [`Connector`] trait object
///
/// This is the type produced by [`Connector::into_dyn`].
pub type AnyConnector<M> = Box<dyn Connector<M> + Send + Sync + Unpin + 'static>;

/// Result of a connection opening future
///
/// On success it carries the [`PeerId`] associated with the connection
/// together with the message transport for it.
pub type ConnectResult<M> = Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>;

/// Owned trait object type for incoming connection listeners
///
/// A stream whose items are [`ConnectResult`]s, i.e. each item is either an
/// established connection or the error that prevented establishing it.
pub type ConnectionListener<M> =
    Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>;
35
/// Allows to connect to peers and to listen for incoming connections
///
/// Connections are message based ([`FramedTransport`]) and should be
/// authenticated and encrypted for production deployments.
///
/// `M` is the message type carried by the transports this connector produces.
#[async_trait]
pub trait Connector<M> {
    /// Connect to a `destination`
    ///
    /// `peer` is the peer the caller expects to find at `destination`. On
    /// success the result holds a peer id and the framed transport for the
    /// new connection; any failure is reported as an `anyhow::Error`.
    async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M>;

    /// Listen for incoming connections on `bind_addr`
    ///
    /// Returns a stream of connection results; setting up the listener
    /// itself may fail, which is reported through the outer `Result`.
    async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error>;

    /// Transform this concrete `Connector` into an owned trait object version
    /// of itself
    ///
    /// Only available for sized connectors that satisfy the bounds of
    /// [`AnyConnector`]; the default implementation simply boxes `self`.
    fn into_dyn(self) -> AnyConnector<M>
    where
        Self: Sized + Send + Sync + Unpin + 'static,
    {
        Box::new(self)
    }
}
57
/// TCP connector with encryption and authentication
///
/// Both sides of a connection present a certificate; the certificate received
/// from the remote side is looked up in `peer_certs` to determine which peer
/// it belongs to.
#[derive(Debug)]
pub struct TlsTcpConnector {
    /// Certificate we present, both as TLS client and as TLS server
    our_certificate: rustls::Certificate,
    /// Private key belonging to `our_certificate`
    our_private_key: rustls::PrivateKey,
    /// Maps certificates received from remote sides back to peer ids
    peer_certs: Arc<PeerCertStore>,
    /// Copy of the certs from `peer_certs`, but in a format that `tokio_rustls`
    /// understands
    cert_store: RootCertStore,
    /// Per-peer names; the name of the target peer is run through
    /// `dns_sanitize` to form the TLS server name when connecting
    peer_names: BTreeMap<PeerId, String>,
}
69
/// TLS material needed to construct a [`TlsTcpConnector`]
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Our own private key
    pub our_private_key: rustls::PrivateKey,
    /// Certificates of all peers. Must also contain an entry for our own peer
    /// id: `TlsTcpConnector::new` takes our certificate from this map and
    /// panics if it is missing.
    pub peer_certs: BTreeMap<PeerId, rustls::Certificate>,
    /// Names of all peers, used to derive the TLS server name when connecting
    pub peer_names: BTreeMap<PeerId, String>,
}
76
/// Association between peers and their certificates, used to find out which
/// peer presented a given certificate
#[derive(Debug, Clone)]
pub struct PeerCertStore {
    /// `(peer, certificate)` pairs; searched linearly by certificate
    peer_certificates: Vec<(PeerId, rustls::Certificate)>,
}
81
82impl TlsTcpConnector {
83    pub fn new(cfg: TlsConfig, our_id: PeerId) -> TlsTcpConnector {
84        let mut cert_store = RootCertStore::empty();
85        for cert in cfg.peer_certs.values() {
86            cert_store
87                .add(cert)
88                .expect("Could not add peer certificate");
89        }
90
91        TlsTcpConnector {
92            our_certificate: cfg.peer_certs.get(&our_id).expect("exists").clone(),
93            our_private_key: cfg.our_private_key,
94            peer_certs: Arc::new(PeerCertStore::new(cfg.peer_certs)),
95            cert_store,
96            peer_names: cfg.peer_names,
97        }
98    }
99}
100
101impl PeerCertStore {
102    fn new(certs: impl IntoIterator<Item = (PeerId, rustls::Certificate)>) -> PeerCertStore {
103        PeerCertStore {
104            peer_certificates: certs.into_iter().collect(),
105        }
106    }
107
108    fn get_peer_by_cert(&self, cert: &rustls::Certificate) -> Option<PeerId> {
109        self.peer_certificates
110            .iter()
111            .find_map(|(peer, peer_cert)| if peer_cert == cert { Some(*peer) } else { None })
112    }
113
114    fn authenticate_peer(
115        &self,
116        received: Option<&[rustls::Certificate]>,
117    ) -> Result<PeerId, anyhow::Error> {
118        let cert_chain =
119            received.ok_or_else(|| anyhow::anyhow!("Peer did not authenticate itself"))?;
120
121        if cert_chain.len() != 1 {
122            return Err(anyhow::anyhow!(
123                "Received certificate chain of len={}, expected=1",
124                cert_chain.len()
125            ));
126        }
127
128        let received_cert = cert_chain.first().expect("Checked above");
129
130        self.get_peer_by_cert(received_cert)
131            .ok_or_else(|| anyhow::anyhow!("Unknown certificate"))
132    }
133
134    async fn accept_connection<M>(
135        &self,
136        listener: &mut TcpListener,
137        acceptor: &TlsAcceptor,
138    ) -> Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>
139    where
140        M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
141    {
142        let (connection, _) = listener.accept().await?;
143        let tls_conn = acceptor.accept(connection).await?;
144
145        let (_, tls_session) = tls_conn.get_ref();
146        let auth_peer = self.authenticate_peer(tls_session.peer_certificates())?;
147
148        let framed =
149            BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new(
150                tls_conn,
151            )
152            .into_dyn();
153        Ok((auth_peer, framed))
154    }
155}
156
157#[async_trait]
158impl<M> Connector<M> for TlsTcpConnector
159where
160    M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
161{
162    async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M> {
163        let cfg = rustls::ClientConfig::builder()
164            .with_safe_defaults()
165            .with_root_certificates(self.cert_store.clone())
166            .with_client_auth_cert(
167                vec![self.our_certificate.clone()],
168                self.our_private_key.clone(),
169            )
170            .expect("Failed to create TLS config");
171
172        let fake_domain =
173            rustls::ServerName::try_from(dns_sanitize(&self.peer_names[&peer]).as_str())
174                .expect("Always a valid DNS name");
175
176        let connector = TlsConnector::from(Arc::new(cfg));
177        let tls_conn = connector
178            .connect(
179                fake_domain,
180                TcpStream::connect(parse_host_port(&destination)?).await?,
181            )
182            .await?;
183
184        let (_, tls_session) = tls_conn.get_ref();
185        let auth_peer = self
186            .peer_certs
187            .authenticate_peer(tls_session.peer_certificates())?;
188
189        if auth_peer != peer {
190            return Err(anyhow::anyhow!("Connected to unexpected peer"));
191        }
192
193        let framed =
194            BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new(
195                tls_conn,
196            )
197            .into_dyn();
198
199        Ok((peer, framed))
200    }
201
202    async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error> {
203        let verifier = AllowAnyAuthenticatedClient::new(self.cert_store.clone());
204        let config = rustls::ServerConfig::builder()
205            .with_safe_defaults()
206            .with_client_cert_verifier(Arc::from(verifier))
207            .with_single_cert(
208                vec![self.our_certificate.clone()],
209                self.our_private_key.clone(),
210            )
211            .unwrap();
212        let listener = TcpListener::bind(bind_addr).await?;
213        let peer_certs = self.peer_certs.clone();
214
215        let stream = futures::stream::unfold(listener, move |mut listener| {
216            let acceptor = TlsAcceptor::from(Arc::new(config.clone()));
217            let peer_certs = peer_certs.clone();
218
219            Box::pin(async move {
220                let res = peer_certs.accept_connection(&mut listener, &acceptor).await;
221                Some((res, listener))
222            })
223        });
224        Ok(Box::pin(stream))
225    }
226}
227
/// Turns `name` into a string usable as a domain name label.
///
/// Every character that is not an ASCII letter or digit is replaced by `_`,
/// and the result is prefixed with `peer`.
pub fn dns_sanitize(name: &str) -> String {
    const PREFIX: &str = "peer";

    let mut sanitized = String::with_capacity(PREFIX.len() + name.len());
    sanitized.push_str(PREFIX);
    sanitized.extend(
        name.chars()
            .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' }),
    );
    sanitized
}
233
234/// Parses the host and port from a url
235pub fn parse_host_port(url: &SafeUrl) -> anyhow::Result<String> {
236    let host = url
237        .host_str()
238        .ok_or_else(|| format_err!("Missing host in {url}"))?;
239    let port = url
240        .port_or_known_default()
241        .ok_or_else(|| format_err!("Missing port in {url}"))?;
242
243    Ok(format!("{host}:{port}"))
244}
245
246/// Fake network stack used in tests
247#[allow(unused_imports)]
248pub mod mock {
249    use std::collections::HashMap;
250    use std::fmt::Debug;
251    use std::future::Future;
252    use std::net::SocketAddr;
253    use std::pin::Pin;
254    use std::sync::atomic::{AtomicBool, Ordering};
255    use std::sync::Arc;
256    use std::time::Duration;
257
258    use anyhow::{anyhow, Error};
259    use fedimint_core::runtime::spawn;
260    use fedimint_core::task::sleep;
261    use fedimint_core::util::SafeUrl;
262    use fedimint_core::{task, PeerId};
263    use futures::{pin_mut, FutureExt, SinkExt, Stream, StreamExt};
264    use rand::Rng;
265    use tokio::io::{
266        AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf,
267    };
268    use tokio::sync::mpsc::Sender;
269    use tokio::sync::Mutex;
270    use tokio_util::sync::CancellationToken;
271    use tracing::error;
272
273    use crate::net::connect::{parse_host_port, ConnectResult, Connector};
274    use crate::net::framed::{BidiFramed, FramedTransport};
275
    /// Wrapper around an in-memory [`DuplexStream`] that can delay and fail
    /// its read, write, flush and shutdown operations.
    struct UnreliableDuplexStream {
        /// The stream all operations are forwarded to
        inner: DuplexStream,
        /// Cancelled when an operation fails artificially; from then on every
        /// operation fails with a "Stream is broken" error
        broken: CancellationToken,
        /// Source of latency/failures for reads; `None` forwards reads directly
        read_generator: Option<UnreliabilityGenerator>,
        /// Source of latency/failures for writes; `None` forwards writes directly
        write_generator: Option<UnreliabilityGenerator>,
        /// Source of latency/failures for flushes; `None` forwards flushes directly
        flush_generator: Option<UnreliabilityGenerator>,
        /// Source of latency/failures for shutdowns; `None` forwards shutdowns
        /// directly
        shutdown_generator: Option<UnreliabilityGenerator>,
    }
284
285    impl UnreliableDuplexStream {
286        fn new(inner: DuplexStream, reliability: StreamReliability) -> UnreliableDuplexStream {
287            match reliability {
288                StreamReliability::FullyReliable => Self {
289                    inner,
290                    broken: CancellationToken::new(),
291                    read_generator: None,
292                    write_generator: None,
293                    flush_generator: None,
294                    shutdown_generator: None,
295                },
296                StreamReliability::RandomlyUnreliable {
297                    read_failure_rate,
298                    write_failure_rate,
299                    flush_failure_rate,
300                    shutdown_failure_rate,
301                    read_latency,
302                    write_latency,
303                    flush_latency,
304                    shutdown_latency,
305                } => Self {
306                    inner,
307                    broken: CancellationToken::new(),
308                    read_generator: Some(UnreliabilityGenerator::new(
309                        read_latency,
310                        read_failure_rate,
311                    )),
312                    write_generator: Some(UnreliabilityGenerator::new(
313                        write_latency,
314                        write_failure_rate,
315                    )),
316                    flush_generator: Some(UnreliabilityGenerator::new(
317                        flush_latency,
318                        flush_failure_rate,
319                    )),
320                    shutdown_generator: Some(UnreliabilityGenerator::new(
321                        shutdown_latency,
322                        shutdown_failure_rate,
323                    )),
324                },
325            }
326        }
327
328        fn poll_broken(&self, cx: &mut std::task::Context<'_>) -> bool {
329            let await_cancellation = self.broken.cancelled();
330            pin_mut!(await_cancellation);
331            await_cancellation.poll(cx).is_ready()
332        }
333    }
334
335    impl Debug for UnreliableDuplexStream {
336        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337            f.debug_struct("UnreliableDuplexStream").finish()
338        }
339    }
340
    /// Decides, for one kind of stream operation, how long to delay it and
    /// whether it should fail.
    struct UnreliabilityGenerator {
        /// Interval the random delay of each operation is drawn from
        latency: LatencyInterval,
        /// Probability with which an operation fails once its delay has passed
        failure_rate: FailureRate,
        /// Delay currently being waited on; `None` when no delay is in progress
        sleep_future: Option<Pin<Box<tokio::time::Sleep>>>,
        /// Number of operations that were allowed to proceed so far (only
        /// reported in the debug log when a failure is injected)
        successes: u64,
    }
347
348    impl UnreliabilityGenerator {
349        fn new(latency: LatencyInterval, failure_rate: FailureRate) -> UnreliabilityGenerator {
350            Self {
351                latency,
352                failure_rate,
353                sleep_future: None,
354                successes: 0,
355            }
356        }
357
358        pub fn generate(
359            &mut self,
360            cx: &mut std::task::Context<'_>,
361        ) -> std::task::Poll<std::io::Result<()>> {
362            let sleep = self.sleep_future.get_or_insert_with(|| {
363                Box::pin(
364                    // nosemgrep: ban-tokio-sleep
365                    tokio::time::sleep(self.latency.random()),
366                )
367            });
368            match sleep.poll_unpin(cx) {
369                std::task::Poll::Ready(()) => {
370                    self.sleep_future = None;
371                }
372                std::task::Poll::Pending => return std::task::Poll::Pending,
373            }
374            if self.failure_rate.random_fail() {
375                tracing::debug!(
376                    "Returning random error on unreliable stream after {} successes",
377                    self.successes
378                );
379                std::task::Poll::Ready(Err(std::io::Error::new(
380                    std::io::ErrorKind::Other,
381                    "Randomly failed",
382                )))
383            } else {
384                self.successes += 1;
385                std::task::Poll::Ready(Ok(()))
386            }
387        }
388    }
389
390    impl AsyncRead for UnreliableDuplexStream {
391        fn poll_read(
392            mut self: Pin<&mut Self>,
393            cx: &mut std::task::Context<'_>,
394            buf: &mut tokio::io::ReadBuf<'_>,
395        ) -> std::task::Poll<std::io::Result<()>> {
396            if self.poll_broken(cx) {
397                return std::task::Poll::Ready(Err(std::io::Error::new(
398                    std::io::ErrorKind::Other,
399                    "Stream is broken",
400                )));
401            }
402
403            match self.read_generator.as_mut().map(|g| g.generate(cx)) {
404                Some(std::task::Poll::Ready(Err(e))) => {
405                    self.broken.cancel();
406                    std::task::Poll::Ready(Err(e))
407                }
408                Some(std::task::Poll::Pending) => std::task::Poll::Pending,
409                Some(std::task::Poll::Ready(Ok(()))) | None => {
410                    Pin::new(&mut self.inner).poll_read(cx, buf)
411                }
412            }
413        }
414    }
415
416    impl AsyncWrite for UnreliableDuplexStream {
417        fn poll_write(
418            mut self: Pin<&mut Self>,
419            cx: &mut std::task::Context<'_>,
420            buf: &[u8],
421        ) -> std::task::Poll<Result<usize, std::io::Error>> {
422            if self.poll_broken(cx) {
423                return std::task::Poll::Ready(Err(std::io::Error::new(
424                    std::io::ErrorKind::Other,
425                    "Stream is broken",
426                )));
427            }
428
429            match self.write_generator.as_mut().map(|g| g.generate(cx)) {
430                Some(std::task::Poll::Ready(Err(e))) => {
431                    self.broken.cancel();
432                    std::task::Poll::Ready(Err(e))
433                }
434                Some(std::task::Poll::Pending) => std::task::Poll::Pending,
435                Some(std::task::Poll::Ready(Ok(()))) | None => {
436                    Pin::new(&mut self.inner).poll_write(cx, buf)
437                }
438            }
439        }
440
441        fn poll_flush(
442            mut self: Pin<&mut Self>,
443            cx: &mut std::task::Context<'_>,
444        ) -> std::task::Poll<Result<(), std::io::Error>> {
445            if self.poll_broken(cx) {
446                return std::task::Poll::Ready(Err(std::io::Error::new(
447                    std::io::ErrorKind::Other,
448                    "Stream is broken",
449                )));
450            }
451
452            match self.flush_generator.as_mut().map(|g| g.generate(cx)) {
453                Some(std::task::Poll::Ready(Err(e))) => {
454                    self.broken.cancel();
455                    std::task::Poll::Ready(Err(e))
456                }
457                Some(std::task::Poll::Pending) => std::task::Poll::Pending,
458                Some(std::task::Poll::Ready(Ok(()))) | None => {
459                    Pin::new(&mut self.inner).poll_flush(cx)
460                }
461            }
462        }
463
464        fn poll_shutdown(
465            mut self: Pin<&mut Self>,
466            cx: &mut std::task::Context<'_>,
467        ) -> std::task::Poll<Result<(), std::io::Error>> {
468            if self.poll_broken(cx) {
469                return std::task::Poll::Ready(Err(std::io::Error::new(
470                    std::io::ErrorKind::Other,
471                    "Stream is broken",
472                )));
473            }
474
475            match self.shutdown_generator.as_mut().map(|g| g.generate(cx)) {
476                Some(std::task::Poll::Ready(Err(e))) => {
477                    self.broken.cancel();
478                    std::task::Poll::Ready(Err(e))
479                }
480                Some(std::task::Poll::Pending) => std::task::Poll::Pending,
481                Some(std::task::Poll::Ready(Ok(()))) | None => {
482                    Pin::new(&mut self.inner).poll_shutdown(cx)
483                }
484            }
485        }
486    }
487
    /// In-memory stand-in for a network. Connectors created from the same
    /// `MockNetwork` share one registry of listening addresses.
    pub struct MockNetwork {
        /// Listening address (as `host:port` string) mapped to the channel on
        /// which that listener receives the remote half of new connections
        clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>,
    }
491
    /// [`Connector`] implementation operating on a [`MockNetwork`]
    pub struct MockConnector {
        /// Peer id announced to the other side during the handshake
        id: PeerId,
        /// Registry of listening addresses shared with the `MockNetwork`
        clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>,
        /// Reliability applied to both halves of every connection this
        /// connector opens
        reliability: StreamReliability,
    }
497
498    impl MockNetwork {
499        #[allow(clippy::new_without_default)]
500        pub fn new() -> MockNetwork {
501            MockNetwork {
502                clients: Arc::new(Mutex::new(HashMap::new())),
503            }
504        }
505
506        pub fn connector(&self, id: PeerId, reliability: StreamReliability) -> MockConnector {
507            MockConnector {
508                id,
509                clients: self.clients.clone(),
510                reliability,
511            }
512        }
513    }
514
    /// Inclusive range of delays, in milliseconds, from which random
    /// latencies are drawn
    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
    pub struct LatencyInterval {
        /// Shortest possible delay in milliseconds
        min_millis: u64,
        /// Longest possible delay in milliseconds (inclusive)
        max_millis: u64,
    }
520
521    impl LatencyInterval {
522        const ZERO: LatencyInterval = LatencyInterval {
523            min_millis: 0,
524            max_millis: 0,
525        };
526
527        pub fn new(min: Duration, max: Duration) -> LatencyInterval {
528            assert!(min <= max);
529            LatencyInterval {
530                min_millis: min
531                    .as_millis()
532                    .try_into()
533                    .expect("min duration as millis to fit in a u64"),
534                max_millis: max
535                    .as_millis()
536                    .try_into()
537                    .expect("max duration as millis to fit in a u64"),
538            }
539        }
540
541        pub fn random(&self) -> Duration {
542            let mut rng = rand::thread_rng();
543            Duration::from_millis(rng.gen_range(self.min_millis..=self.max_millis))
544        }
545    }
546
    /// Probability of an operation failing; `FailureRate::new` only accepts
    /// values in `0.0..=1.0`
    #[derive(Debug, Copy, Clone)]
    pub struct FailureRate(f64);
549    impl FailureRate {
550        const MAX: FailureRate = FailureRate(1.0);
551        pub fn new(failure_rate: f64) -> Self {
552            assert!((0.0..=1.0).contains(&failure_rate));
553            Self(failure_rate)
554        }
555
556        pub fn random_fail(&self) -> bool {
557            let mut rng = rand::thread_rng();
558            rng.gen_range(0.0..1.0) < self.0
559        }
560    }
561
    /// Describes how reliable a mock stream should behave
    #[derive(Debug, Copy, Clone)]
    pub enum StreamReliability {
        /// No artificial latency and no artificial failures
        FullyReliable,
        /// Every read, write, flush and shutdown is first delayed by a random
        /// latency and then fails with the given probability
        RandomlyUnreliable {
            /// Probability of a read failing
            read_failure_rate: FailureRate,
            /// Probability of a write failing
            write_failure_rate: FailureRate,
            /// Probability of a flush failing
            flush_failure_rate: FailureRate,
            /// Probability of a shutdown failing
            shutdown_failure_rate: FailureRate,
            /// Delay range applied to reads
            read_latency: LatencyInterval,
            /// Delay range applied to writes
            write_latency: LatencyInterval,
            /// Delay range applied to flushes
            flush_latency: LatencyInterval,
            /// Delay range applied to shutdowns
            shutdown_latency: LatencyInterval,
        },
    }
576
577    impl StreamReliability {
578        pub const MILDLY_UNRELIABLE: StreamReliability = {
579            let failure_rate = FailureRate(0.0);
580            let latency = LatencyInterval {
581                min_millis: 1,
582                max_millis: 10,
583            };
584            Self::RandomlyUnreliable {
585                read_failure_rate: failure_rate,
586                write_failure_rate: failure_rate,
587                flush_failure_rate: failure_rate,
588                shutdown_failure_rate: failure_rate,
589                read_latency: latency,
590                write_latency: latency,
591                flush_latency: latency,
592                shutdown_latency: latency,
593            }
594        };
595
596        pub const INTEGRATION_TEST: StreamReliability = {
597            // Based on empirical testing: creates errors without causing tests to take
598            // additional time compared to StreamReliability::FullyReliable
599            // If an order of magnitude higher, tests may take unreasonable amounts of time.
600            // If an order of magnitude lower, a test may run without any error actually
601            // happening
602            let failure_rate_base = 0.0;
603            let latency = LatencyInterval {
604                min_millis: 1,
605                max_millis: 10,
606            };
607            Self::RandomlyUnreliable {
608                // Try to make read_failure_rate = write_failure_rate + flush_failure_rate
609                read_failure_rate: FailureRate(failure_rate_base * 2.0),
610                write_failure_rate: FailureRate(failure_rate_base),
611                flush_failure_rate: FailureRate(failure_rate_base),
612                shutdown_failure_rate: FailureRate(failure_rate_base),
613                read_latency: latency,
614                write_latency: latency,
615                flush_latency: latency,
616                shutdown_latency: latency,
617            }
618        };
619
620        pub const BROKEN: StreamReliability = {
621            Self::RandomlyUnreliable {
622                read_failure_rate: FailureRate::MAX,
623                write_failure_rate: FailureRate::MAX,
624                flush_failure_rate: FailureRate::MAX,
625                shutdown_failure_rate: FailureRate::MAX,
626                read_latency: LatencyInterval::ZERO,
627                write_latency: LatencyInterval::ZERO,
628                flush_latency: LatencyInterval::ZERO,
629                shutdown_latency: LatencyInterval::ZERO,
630            }
631        };
632    }
633
634    #[async_trait::async_trait]
635    impl<M> Connector<M> for MockConnector
636    where
637        M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
638    {
639        async fn connect_framed(&self, destination: SafeUrl, _peer: PeerId) -> ConnectResult<M> {
640            let mut clients_lock = self.clients.try_lock().map_err(|e| {
641                anyhow!("Mock network mutex busy or poisoned, the network stack will re-try anyway: {e:?}")
642            })?;
643            if let Some(client) = clients_lock.get_mut(&parse_host_port(&destination)?) {
644                let (stream_our, stream_theirs) = tokio::io::duplex(43_689);
645                let mut stream_our = UnreliableDuplexStream::new(stream_our, self.reliability);
646                let stream_theirs = UnreliableDuplexStream::new(stream_theirs, self.reliability);
647                client.send(stream_theirs).await?;
648                let peer = do_handshake(self.id, &mut stream_our).await?;
649                let framed = BidiFramed::<
650                    M,
651                    WriteHalf<UnreliableDuplexStream>,
652                    ReadHalf<UnreliableDuplexStream>,
653                >::new(stream_our)
654                .into_dyn();
655                Ok((peer, framed))
656            } else {
657                return Err(anyhow::anyhow!("can't connect"));
658            }
659        }
660
661        async fn listen(
662            &self,
663            bind_addr: SocketAddr,
664        ) -> Result<Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>, Error>
665        {
666            let (send, receive) = tokio::sync::mpsc::channel(16);
667
668            if self
669                .clients
670                .lock()
671                .await
672                .insert(bind_addr.to_string(), send)
673                .is_some()
674            {
675                return Err(anyhow::anyhow!("Address already bound"));
676            }
677
678            let our_id = self.id;
679            let stream = futures::stream::unfold(receive, move |mut receive| {
680                Box::pin(async move {
681                    let mut connection = receive.recv().await.unwrap();
682                    let peer = match do_handshake(our_id, &mut connection).await {
683                        Ok(peer) => peer,
684                        Err(e) => {
685                            tracing::debug!("Error during handshake: {e:?}");
686                            return Some((Err(e), receive));
687                        }
688                    };
689                    let framed =
690                        BidiFramed::<M, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new(
691                            connection,
692                        )
693                        .into_dyn();
694
695                    Some((Ok((peer, framed)), receive))
696                })
697            });
698            Ok(Box::pin(stream))
699        }
700    }
701
702    async fn do_handshake<S>(our_id: PeerId, stream: &mut S) -> Result<PeerId, anyhow::Error>
703    where
704        S: AsyncRead + AsyncWrite + Unpin,
705    {
706        // Send our id
707        let our_id = our_id.to_usize() as u16;
708        stream.write_all(&our_id.to_be_bytes()[..]).await?;
709
710        // Receive peer id
711        let mut peer_id = [0u8; 2];
712        stream.read_exact(&mut peer_id[..]).await?;
713        Ok(PeerId::from(u16::from_be_bytes(peer_id)))
714    }
715
    /// Two connectors on the same mock network can connect to each other,
    /// learn each other's peer id and exchange messages in both directions.
    #[tokio::test]
    async fn test_mock_network() {
        // Listener address and the URL pointing at the same host and port
        let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
        let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
        let peer_a = PeerId::from(1);
        let peer_b = PeerId::from(2);

        let net = MockNetwork::new();
        let conn_a = net.connector(peer_a, StreamReliability::FullyReliable);
        let conn_b = net.connector(peer_b, StreamReliability::FullyReliable);

        // Peer A listens; accepting runs in its own task so that the
        // handshake can proceed while peer B is connecting.
        let mut listener = Connector::<u64>::listen(&conn_a, bind_addr).await.unwrap();
        let conn_a_fut = spawn("listener next await", async move {
            listener.next().await.unwrap().unwrap()
        });

        let (auth_peer_b, mut conn_b) = Connector::<u64>::connect_framed(&conn_b, url, peer_a)
            .await
            .unwrap();
        let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap();

        // Each side reports the id announced by the other side.
        assert_eq!(auth_peer_a, peer_b);
        assert_eq!(auth_peer_b, peer_a);

        conn_a.send(42).await.unwrap();
        conn_b.send(21).await.unwrap();

        // Messages arrive at the opposite end of the connection.
        assert_eq!(conn_a.next().await.unwrap().unwrap(), 21);
        assert_eq!(conn_b.next().await.unwrap().unwrap(), 42);
    }
746
    /// Checks the building blocks of the unreliable stream: failure rates,
    /// latency intervals and the behavior of reliable and broken streams.
    #[tokio::test]
    async fn test_unreliable_components() {
        // The extreme failure rates are deterministic.
        assert!(!FailureRate::new(0f64).random_fail());
        assert!(FailureRate::new(1f64).random_fail());

        // A random latency lies within the (inclusive) interval bounds.
        let good_interval = (0..=3).contains(
            &LatencyInterval::new(Duration::from_millis(0), Duration::from_millis(3))
                .random()
                .as_millis(),
        );
        assert!(good_interval);

        // Fully reliable on both ends: bytes arrive in order.
        let (a, b) = tokio::io::duplex(43_689);
        let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable);
        let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable);
        assert!(a_stream.write(&[1, 2, 3]).await.is_ok());
        assert!(a_stream.flush().await.is_ok());
        assert_eq!(b_stream.read_u8().await.unwrap(), 1);
        assert_eq!(b_stream.read_u8().await.unwrap(), 2);
        assert_eq!(b_stream.read_u8().await.unwrap(), 3);

        // Broken reading end: writing succeeds, reading fails.
        let (a, b) = tokio::io::duplex(43_689);
        let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable);
        let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::BROKEN);
        assert!(a_stream.write(&[1, 2, 3]).await.is_ok());
        assert!(a_stream.flush().await.is_ok());
        assert!(b_stream.read_u8().await.is_err());

        // Broken writing end: the write itself fails.
        let (a, b) = tokio::io::duplex(43_689);
        let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::BROKEN);
        let mut _b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable);
        assert!(a_stream.write(&[1, 2, 3]).await.is_err());
        // a read on _b_stream would block...
    }
781
782    #[allow(dead_code)]
783    async fn timeout<F, T>(f: F) -> Option<T>
784    where
785        F: Future<Output = T>,
786    {
787        tokio::time::timeout(Duration::from_secs(1), f).await.ok()
788    }
789
    /// A 16000 byte message sent over a mock connection is received intact
    /// within the one second limit of `timeout`.
    #[tokio::test]
    async fn test_large_messages() {
        let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
        let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
        let peer_a = PeerId::from(1);
        let peer_b = PeerId::from(2);

        let net = MockNetwork::new();
        let conn_a = net.connector(peer_a, StreamReliability::FullyReliable);
        let conn_b = net.connector(peer_b, StreamReliability::FullyReliable);

        // Peer A listens; accepting runs in its own task so that the
        // handshake can proceed while peer B is connecting.
        let mut listener = Connector::<Vec<u8>>::listen(&conn_a, bind_addr)
            .await
            .unwrap();
        let conn_a_fut = spawn("listener next await", async move {
            listener.next().await.unwrap().unwrap()
        });

        let (auth_peer_b, mut conn_b) = Connector::<Vec<u8>>::connect_framed(&conn_b, url, peer_a)
            .await
            .unwrap();
        let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap();

        assert_eq!(auth_peer_a, peer_b);
        assert_eq!(auth_peer_b, peer_a);

        let send_future = async {
            conn_a.send(vec![42; 16000]).await.unwrap();
        }
        .boxed();
        let receive_future = async {
            assert_eq!(
                timeout(conn_b.next()).await.unwrap().unwrap().unwrap(),
                vec![42; 16000]
            );
        }
        .boxed();

        // Sending and receiving are driven concurrently.
        tokio::join!(send_future, receive_future);
    }
830}
831
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use fedimint_core::runtime::spawn;
    use fedimint_core::util::SafeUrl;
    use fedimint_core::PeerId;
    use futures::{SinkExt, StreamExt};

    use crate::config::gen_cert_and_key;
    use crate::net::connect::{ConnectionListener, Connector, TlsConfig};
    use crate::net::framed::AnyFramedTransport;
    use crate::TlsTcpConnector;

    /// Maps a zero-based peer index to its [`PeerId`].
    ///
    /// # Panics
    ///
    /// Panics if `index` does not fit into a `u16`, rather than silently
    /// wrapping around to a different peer id.
    fn peer_id(index: usize) -> PeerId {
        PeerId::from(u16::try_from(index).expect("peer index must fit into a u16"))
    }

    /// Generates one [`TlsConfig`] per peer for a test federation of `count`
    /// peers.
    ///
    /// A certificate/key pair named `peer-<index>` is generated for every
    /// peer. Each returned config holds that peer's own private key plus the
    /// certificates and names of all `count` peers (including itself).
    fn gen_connector_config(count: usize) -> Vec<TlsConfig> {
        let peer_keys = (0..count)
            .map(|index| {
                let peer = peer_id(index);
                gen_cert_and_key(&format!("peer-{}", peer.to_usize())).unwrap()
            })
            .collect::<Vec<_>>();

        peer_keys
            .iter()
            .map(|(_cert, key)| TlsConfig {
                our_private_key: key.clone(),
                peer_certs: peer_keys
                    .iter()
                    .enumerate()
                    .map(|(index, (cert, _))| (peer_id(index), cert.clone()))
                    .collect(),
                peer_names: peer_keys
                    .iter()
                    .enumerate()
                    .map(|(index, _)| (peer_id(index), format!("peer-{index}")))
                    .collect(),
            })
            .collect()
    }

    /// Client half of a rejection scenario: connects to `url` expecting to
    /// reach `peer`, sends one message, flushes and waits for a reply.
    ///
    /// Returns the first error hit along the way, so the caller can assert on
    /// the failure regardless of the step at which it surfaces.
    async fn client_handshake(
        connector: &TlsTcpConnector,
        url: SafeUrl,
        peer: PeerId,
    ) -> Result<(), anyhow::Error> {
        let (_peer, mut conn): (_, AnyFramedTransport<u64>) =
            connector.connect_framed(url, peer).await?;

        conn.send(42).await?;
        conn.flush().await?;
        conn.next().await.unwrap()?;

        Ok(())
    }

    /// Server half of a rejection scenario: takes the next incoming connection
    /// from `server` and asserts that it failed with exactly `expected`.
    ///
    /// The listener is owned by this future and dropped when it completes.
    async fn expect_accept_error(mut server: ConnectionListener<u64>, expected: &'static str) {
        let conn_res = server.next().await.unwrap();
        assert_eq!(
            conn_res
                .err()
                .expect("server unexpectedly accepted the connection")
                .to_string()
                .as_str(),
            expected
        );
    }

    /// Peer 2 connects to peer 0 over TLS; both sides check the authenticated
    /// peer id and exchange one message in each direction.
    #[tokio::test]
    async fn connect_success() {
        // FIXME: don't actually bind here, probably requires yet another Box<dyn Trait>
        // layer :(
        let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
        let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
        let connectors = gen_connector_config(5)
            .into_iter()
            .enumerate()
            .map(|(index, cfg)| TlsTcpConnector::new(cfg, peer_id(index)))
            .collect::<Vec<_>>();

        let mut server: ConnectionListener<u64> = connectors[0].listen(bind_addr).await.unwrap();

        let server_task = spawn("server next await", async move {
            let (peer, mut conn) = server.next().await.unwrap().unwrap();
            assert_eq!(peer.to_usize(), 2);
            let received = conn.next().await.unwrap().unwrap();
            assert_eq!(received, 42);
            conn.send(21).await.unwrap();
            // The next item must be an error; the client drops its end after
            // reading the reply.
            assert!(conn.next().await.unwrap().is_err());
        });

        let (peer_of_a, mut client_a): (_, AnyFramedTransport<u64>) = connectors[2]
            .connect_framed(url.clone(), PeerId::from(0))
            .await
            .unwrap();
        assert_eq!(peer_of_a.to_usize(), 0);
        client_a.send(42).await.unwrap();
        let received = client_a.next().await.unwrap().unwrap();
        assert_eq!(received, 21);
        drop(client_a);

        server_task.await.unwrap();
    }

    /// Checks that TLS authentication failures are reported on both ends for a
    /// client with a wrong key, a server with a wrong key, and a server
    /// presenting a certificate for a different peer than the one expected.
    #[tokio::test]
    async fn connect_reject() {
        let bind_addr: SocketAddr = "127.0.0.1:7001".parse().unwrap();
        let url: SafeUrl = "wss://127.0.0.1:7001".parse().unwrap();
        let cfg = gen_connector_config(5);

        let honest = TlsTcpConnector::new(cfg[0].clone(), PeerId::from(0));

        // Peer 1's config, but with peer 2's private key swapped in.
        let mut malicious_wrong_key_cfg = cfg[1].clone();
        malicious_wrong_key_cfg.our_private_key = cfg[2].our_private_key.clone();
        let malicious_wrong_key = TlsTcpConnector::new(malicious_wrong_key_cfg, PeerId::from(1));

        // Honest server, malicious client with wrong private key
        {
            let server: ConnectionListener<u64> = honest.listen(bind_addr).await.unwrap();
            let server_task = spawn(
                "server next await",
                expect_accept_error(server, "invalid peer certificate: BadSignature"),
            );

            let conn_res =
                client_handshake(&malicious_wrong_key, url.clone(), PeerId::from(0)).await;
            assert_eq!(
                conn_res
                    .err()
                    .expect("client with wrong private key must be rejected")
                    .to_string()
                    .as_str(),
                "received fatal alert: DecryptError"
            );

            // Awaiting the task also ensures the listener is gone before the
            // next scenario binds the same address.
            server_task.await.unwrap();
        }

        // Malicious server with wrong key, honest client
        {
            let server: ConnectionListener<u64> =
                malicious_wrong_key.listen(bind_addr).await.unwrap();
            let server_task = spawn(
                "server next await",
                expect_accept_error(server, "received fatal alert: DecryptError"),
            );

            let conn_res = client_handshake(&honest, url.clone(), PeerId::from(1)).await;
            assert_eq!(
                conn_res
                    .err()
                    .expect("client must reject a server with a wrong private key")
                    .to_string()
                    .as_str(),
                "invalid peer certificate: BadSignature"
            );

            server_task.await.unwrap();
        }

        // Server with wrong certificate, honest client
        {
            let server: ConnectionListener<u64> =
                TlsTcpConnector::new(cfg[2].clone(), PeerId::from(2))
                    .listen(bind_addr)
                    .await
                    .unwrap();
            let server_task = spawn(
                "server next await",
                expect_accept_error(server, "received fatal alert: BadCertificate"),
            );

            // The client asks for peer 0 while peer 2 is the one listening.
            let conn_res = client_handshake(&honest, url.clone(), PeerId::from(0)).await;
            assert_eq!(
                conn_res
                    .err()
                    .expect("client must reject a certificate issued for another peer")
                    .to_string()
                    .as_str(),
                "invalid peer certificate: NotValidForName"
            );

            server_task.await.unwrap();
        }
    }
}