1use std::collections::BTreeMap;
5use std::fmt::Debug;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use anyhow::format_err;
11use async_trait::async_trait;
12use fedimint_core::util::SafeUrl;
13use fedimint_core::PeerId;
14use futures::Stream;
15use tokio::io::{ReadHalf, WriteHalf};
16use tokio::net::{TcpListener, TcpStream};
17use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
18use tokio_rustls::rustls::RootCertStore;
19use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
20
21use crate::net::framed::{AnyFramedTransport, BidiFramed, FramedTransport};
22
23pub type SharedAnyConnector<M> = Arc<dyn Connector<M> + Send + Sync + Unpin + 'static>;
25
26pub type AnyConnector<M> = Box<dyn Connector<M> + Send + Sync + Unpin + 'static>;
28
29pub type ConnectResult<M> = Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>;
31
32pub type ConnectionListener<M> =
34 Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>;
35
36#[async_trait]
41pub trait Connector<M> {
42 async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M>;
44
45 async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error>;
47
48 fn into_dyn(self) -> AnyConnector<M>
51 where
52 Self: Sized + Send + Sync + Unpin + 'static,
53 {
54 Box::new(self)
55 }
56}
57
58#[derive(Debug)]
60pub struct TlsTcpConnector {
61 our_certificate: rustls::Certificate,
62 our_private_key: rustls::PrivateKey,
63 peer_certs: Arc<PeerCertStore>,
64 cert_store: RootCertStore,
67 peer_names: BTreeMap<PeerId, String>,
68}
69
70#[derive(Debug, Clone)]
71pub struct TlsConfig {
72 pub our_private_key: rustls::PrivateKey,
73 pub peer_certs: BTreeMap<PeerId, rustls::Certificate>,
74 pub peer_names: BTreeMap<PeerId, String>,
75}
76
77#[derive(Debug, Clone)]
78pub struct PeerCertStore {
79 peer_certificates: Vec<(PeerId, rustls::Certificate)>,
80}
81
82impl TlsTcpConnector {
83 pub fn new(cfg: TlsConfig, our_id: PeerId) -> TlsTcpConnector {
84 let mut cert_store = RootCertStore::empty();
85 for cert in cfg.peer_certs.values() {
86 cert_store
87 .add(cert)
88 .expect("Could not add peer certificate");
89 }
90
91 TlsTcpConnector {
92 our_certificate: cfg.peer_certs.get(&our_id).expect("exists").clone(),
93 our_private_key: cfg.our_private_key,
94 peer_certs: Arc::new(PeerCertStore::new(cfg.peer_certs)),
95 cert_store,
96 peer_names: cfg.peer_names,
97 }
98 }
99}
100
101impl PeerCertStore {
102 fn new(certs: impl IntoIterator<Item = (PeerId, rustls::Certificate)>) -> PeerCertStore {
103 PeerCertStore {
104 peer_certificates: certs.into_iter().collect(),
105 }
106 }
107
108 fn get_peer_by_cert(&self, cert: &rustls::Certificate) -> Option<PeerId> {
109 self.peer_certificates
110 .iter()
111 .find_map(|(peer, peer_cert)| if peer_cert == cert { Some(*peer) } else { None })
112 }
113
114 fn authenticate_peer(
115 &self,
116 received: Option<&[rustls::Certificate]>,
117 ) -> Result<PeerId, anyhow::Error> {
118 let cert_chain =
119 received.ok_or_else(|| anyhow::anyhow!("Peer did not authenticate itself"))?;
120
121 if cert_chain.len() != 1 {
122 return Err(anyhow::anyhow!(
123 "Received certificate chain of len={}, expected=1",
124 cert_chain.len()
125 ));
126 }
127
128 let received_cert = cert_chain.first().expect("Checked above");
129
130 self.get_peer_by_cert(received_cert)
131 .ok_or_else(|| anyhow::anyhow!("Unknown certificate"))
132 }
133
134 async fn accept_connection<M>(
135 &self,
136 listener: &mut TcpListener,
137 acceptor: &TlsAcceptor,
138 ) -> Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>
139 where
140 M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
141 {
142 let (connection, _) = listener.accept().await?;
143 let tls_conn = acceptor.accept(connection).await?;
144
145 let (_, tls_session) = tls_conn.get_ref();
146 let auth_peer = self.authenticate_peer(tls_session.peer_certificates())?;
147
148 let framed =
149 BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new(
150 tls_conn,
151 )
152 .into_dyn();
153 Ok((auth_peer, framed))
154 }
155}
156
157#[async_trait]
158impl<M> Connector<M> for TlsTcpConnector
159where
160 M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
161{
162 async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M> {
163 let cfg = rustls::ClientConfig::builder()
164 .with_safe_defaults()
165 .with_root_certificates(self.cert_store.clone())
166 .with_client_auth_cert(
167 vec![self.our_certificate.clone()],
168 self.our_private_key.clone(),
169 )
170 .expect("Failed to create TLS config");
171
172 let fake_domain =
173 rustls::ServerName::try_from(dns_sanitize(&self.peer_names[&peer]).as_str())
174 .expect("Always a valid DNS name");
175
176 let connector = TlsConnector::from(Arc::new(cfg));
177 let tls_conn = connector
178 .connect(
179 fake_domain,
180 TcpStream::connect(parse_host_port(&destination)?).await?,
181 )
182 .await?;
183
184 let (_, tls_session) = tls_conn.get_ref();
185 let auth_peer = self
186 .peer_certs
187 .authenticate_peer(tls_session.peer_certificates())?;
188
189 if auth_peer != peer {
190 return Err(anyhow::anyhow!("Connected to unexpected peer"));
191 }
192
193 let framed =
194 BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new(
195 tls_conn,
196 )
197 .into_dyn();
198
199 Ok((peer, framed))
200 }
201
202 async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error> {
203 let verifier = AllowAnyAuthenticatedClient::new(self.cert_store.clone());
204 let config = rustls::ServerConfig::builder()
205 .with_safe_defaults()
206 .with_client_cert_verifier(Arc::from(verifier))
207 .with_single_cert(
208 vec![self.our_certificate.clone()],
209 self.our_private_key.clone(),
210 )
211 .unwrap();
212 let listener = TcpListener::bind(bind_addr).await?;
213 let peer_certs = self.peer_certs.clone();
214
215 let stream = futures::stream::unfold(listener, move |mut listener| {
216 let acceptor = TlsAcceptor::from(Arc::new(config.clone()));
217 let peer_certs = peer_certs.clone();
218
219 Box::pin(async move {
220 let res = peer_certs.accept_connection(&mut listener, &acceptor).await;
221 Some((res, listener))
222 })
223 });
224 Ok(Box::pin(stream))
225 }
226}
227
/// Turns an arbitrary peer name into a DNS-label-like string.
///
/// Every character that is not an ASCII letter or digit becomes `_`, one per character.
/// The result is prefixed with `peer`, so it always starts with a letter.
pub fn dns_sanitize(name: &str) -> String {
    let mut out = String::with_capacity("peer".len() + name.len());
    out.push_str("peer");
    out.extend(
        name.chars()
            .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '_' }),
    );
    out
}
233
234pub fn parse_host_port(url: &SafeUrl) -> anyhow::Result<String> {
236 let host = url
237 .host_str()
238 .ok_or_else(|| format_err!("Missing host in {url}"))?;
239 let port = url
240 .port_or_known_default()
241 .ok_or_else(|| format_err!("Missing port in {url}"))?;
242
243 Ok(format!("{host}:{port}"))
244}
245
246#[allow(unused_imports)]
248pub mod mock {
249 use std::collections::HashMap;
250 use std::fmt::Debug;
251 use std::future::Future;
252 use std::net::SocketAddr;
253 use std::pin::Pin;
254 use std::sync::atomic::{AtomicBool, Ordering};
255 use std::sync::Arc;
256 use std::time::Duration;
257
258 use anyhow::{anyhow, Error};
259 use fedimint_core::runtime::spawn;
260 use fedimint_core::task::sleep;
261 use fedimint_core::util::SafeUrl;
262 use fedimint_core::{task, PeerId};
263 use futures::{pin_mut, FutureExt, SinkExt, Stream, StreamExt};
264 use rand::Rng;
265 use tokio::io::{
266 AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf,
267 };
268 use tokio::sync::mpsc::Sender;
269 use tokio::sync::Mutex;
270 use tokio_util::sync::CancellationToken;
271 use tracing::error;
272
273 use crate::net::connect::{parse_host_port, ConnectResult, Connector};
274 use crate::net::framed::{BidiFramed, FramedTransport};
275
276 struct UnreliableDuplexStream {
277 inner: DuplexStream,
278 broken: CancellationToken,
279 read_generator: Option<UnreliabilityGenerator>,
280 write_generator: Option<UnreliabilityGenerator>,
281 flush_generator: Option<UnreliabilityGenerator>,
282 shutdown_generator: Option<UnreliabilityGenerator>,
283 }
284
285 impl UnreliableDuplexStream {
286 fn new(inner: DuplexStream, reliability: StreamReliability) -> UnreliableDuplexStream {
287 match reliability {
288 StreamReliability::FullyReliable => Self {
289 inner,
290 broken: CancellationToken::new(),
291 read_generator: None,
292 write_generator: None,
293 flush_generator: None,
294 shutdown_generator: None,
295 },
296 StreamReliability::RandomlyUnreliable {
297 read_failure_rate,
298 write_failure_rate,
299 flush_failure_rate,
300 shutdown_failure_rate,
301 read_latency,
302 write_latency,
303 flush_latency,
304 shutdown_latency,
305 } => Self {
306 inner,
307 broken: CancellationToken::new(),
308 read_generator: Some(UnreliabilityGenerator::new(
309 read_latency,
310 read_failure_rate,
311 )),
312 write_generator: Some(UnreliabilityGenerator::new(
313 write_latency,
314 write_failure_rate,
315 )),
316 flush_generator: Some(UnreliabilityGenerator::new(
317 flush_latency,
318 flush_failure_rate,
319 )),
320 shutdown_generator: Some(UnreliabilityGenerator::new(
321 shutdown_latency,
322 shutdown_failure_rate,
323 )),
324 },
325 }
326 }
327
328 fn poll_broken(&self, cx: &mut std::task::Context<'_>) -> bool {
329 let await_cancellation = self.broken.cancelled();
330 pin_mut!(await_cancellation);
331 await_cancellation.poll(cx).is_ready()
332 }
333 }
334
335 impl Debug for UnreliableDuplexStream {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.debug_struct("UnreliableDuplexStream").finish()
338 }
339 }
340
341 struct UnreliabilityGenerator {
342 latency: LatencyInterval,
343 failure_rate: FailureRate,
344 sleep_future: Option<Pin<Box<tokio::time::Sleep>>>,
345 successes: u64,
346 }
347
348 impl UnreliabilityGenerator {
349 fn new(latency: LatencyInterval, failure_rate: FailureRate) -> UnreliabilityGenerator {
350 Self {
351 latency,
352 failure_rate,
353 sleep_future: None,
354 successes: 0,
355 }
356 }
357
358 pub fn generate(
359 &mut self,
360 cx: &mut std::task::Context<'_>,
361 ) -> std::task::Poll<std::io::Result<()>> {
362 let sleep = self.sleep_future.get_or_insert_with(|| {
363 Box::pin(
364 tokio::time::sleep(self.latency.random()),
366 )
367 });
368 match sleep.poll_unpin(cx) {
369 std::task::Poll::Ready(()) => {
370 self.sleep_future = None;
371 }
372 std::task::Poll::Pending => return std::task::Poll::Pending,
373 }
374 if self.failure_rate.random_fail() {
375 tracing::debug!(
376 "Returning random error on unreliable stream after {} successes",
377 self.successes
378 );
379 std::task::Poll::Ready(Err(std::io::Error::new(
380 std::io::ErrorKind::Other,
381 "Randomly failed",
382 )))
383 } else {
384 self.successes += 1;
385 std::task::Poll::Ready(Ok(()))
386 }
387 }
388 }
389
390 impl AsyncRead for UnreliableDuplexStream {
391 fn poll_read(
392 mut self: Pin<&mut Self>,
393 cx: &mut std::task::Context<'_>,
394 buf: &mut tokio::io::ReadBuf<'_>,
395 ) -> std::task::Poll<std::io::Result<()>> {
396 if self.poll_broken(cx) {
397 return std::task::Poll::Ready(Err(std::io::Error::new(
398 std::io::ErrorKind::Other,
399 "Stream is broken",
400 )));
401 }
402
403 match self.read_generator.as_mut().map(|g| g.generate(cx)) {
404 Some(std::task::Poll::Ready(Err(e))) => {
405 self.broken.cancel();
406 std::task::Poll::Ready(Err(e))
407 }
408 Some(std::task::Poll::Pending) => std::task::Poll::Pending,
409 Some(std::task::Poll::Ready(Ok(()))) | None => {
410 Pin::new(&mut self.inner).poll_read(cx, buf)
411 }
412 }
413 }
414 }
415
416 impl AsyncWrite for UnreliableDuplexStream {
417 fn poll_write(
418 mut self: Pin<&mut Self>,
419 cx: &mut std::task::Context<'_>,
420 buf: &[u8],
421 ) -> std::task::Poll<Result<usize, std::io::Error>> {
422 if self.poll_broken(cx) {
423 return std::task::Poll::Ready(Err(std::io::Error::new(
424 std::io::ErrorKind::Other,
425 "Stream is broken",
426 )));
427 }
428
429 match self.write_generator.as_mut().map(|g| g.generate(cx)) {
430 Some(std::task::Poll::Ready(Err(e))) => {
431 self.broken.cancel();
432 std::task::Poll::Ready(Err(e))
433 }
434 Some(std::task::Poll::Pending) => std::task::Poll::Pending,
435 Some(std::task::Poll::Ready(Ok(()))) | None => {
436 Pin::new(&mut self.inner).poll_write(cx, buf)
437 }
438 }
439 }
440
441 fn poll_flush(
442 mut self: Pin<&mut Self>,
443 cx: &mut std::task::Context<'_>,
444 ) -> std::task::Poll<Result<(), std::io::Error>> {
445 if self.poll_broken(cx) {
446 return std::task::Poll::Ready(Err(std::io::Error::new(
447 std::io::ErrorKind::Other,
448 "Stream is broken",
449 )));
450 }
451
452 match self.flush_generator.as_mut().map(|g| g.generate(cx)) {
453 Some(std::task::Poll::Ready(Err(e))) => {
454 self.broken.cancel();
455 std::task::Poll::Ready(Err(e))
456 }
457 Some(std::task::Poll::Pending) => std::task::Poll::Pending,
458 Some(std::task::Poll::Ready(Ok(()))) | None => {
459 Pin::new(&mut self.inner).poll_flush(cx)
460 }
461 }
462 }
463
464 fn poll_shutdown(
465 mut self: Pin<&mut Self>,
466 cx: &mut std::task::Context<'_>,
467 ) -> std::task::Poll<Result<(), std::io::Error>> {
468 if self.poll_broken(cx) {
469 return std::task::Poll::Ready(Err(std::io::Error::new(
470 std::io::ErrorKind::Other,
471 "Stream is broken",
472 )));
473 }
474
475 match self.shutdown_generator.as_mut().map(|g| g.generate(cx)) {
476 Some(std::task::Poll::Ready(Err(e))) => {
477 self.broken.cancel();
478 std::task::Poll::Ready(Err(e))
479 }
480 Some(std::task::Poll::Pending) => std::task::Poll::Pending,
481 Some(std::task::Poll::Ready(Ok(()))) | None => {
482 Pin::new(&mut self.inner).poll_shutdown(cx)
483 }
484 }
485 }
486 }
487
488 pub struct MockNetwork {
489 clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>,
490 }
491
492 pub struct MockConnector {
493 id: PeerId,
494 clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>,
495 reliability: StreamReliability,
496 }
497
498 impl MockNetwork {
499 #[allow(clippy::new_without_default)]
500 pub fn new() -> MockNetwork {
501 MockNetwork {
502 clients: Arc::new(Mutex::new(HashMap::new())),
503 }
504 }
505
506 pub fn connector(&self, id: PeerId, reliability: StreamReliability) -> MockConnector {
507 MockConnector {
508 id,
509 clients: self.clients.clone(),
510 reliability,
511 }
512 }
513 }
514
515 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
516 pub struct LatencyInterval {
517 min_millis: u64,
518 max_millis: u64,
519 }
520
521 impl LatencyInterval {
522 const ZERO: LatencyInterval = LatencyInterval {
523 min_millis: 0,
524 max_millis: 0,
525 };
526
527 pub fn new(min: Duration, max: Duration) -> LatencyInterval {
528 assert!(min <= max);
529 LatencyInterval {
530 min_millis: min
531 .as_millis()
532 .try_into()
533 .expect("min duration as millis to fit in a u64"),
534 max_millis: max
535 .as_millis()
536 .try_into()
537 .expect("max duration as millis to fit in a u64"),
538 }
539 }
540
541 pub fn random(&self) -> Duration {
542 let mut rng = rand::thread_rng();
543 Duration::from_millis(rng.gen_range(self.min_millis..=self.max_millis))
544 }
545 }
546
547 #[derive(Debug, Copy, Clone)]
548 pub struct FailureRate(f64);
549 impl FailureRate {
550 const MAX: FailureRate = FailureRate(1.0);
551 pub fn new(failure_rate: f64) -> Self {
552 assert!((0.0..=1.0).contains(&failure_rate));
553 Self(failure_rate)
554 }
555
556 pub fn random_fail(&self) -> bool {
557 let mut rng = rand::thread_rng();
558 rng.gen_range(0.0..1.0) < self.0
559 }
560 }
561
562 #[derive(Debug, Copy, Clone)]
563 pub enum StreamReliability {
564 FullyReliable,
565 RandomlyUnreliable {
566 read_failure_rate: FailureRate,
567 write_failure_rate: FailureRate,
568 flush_failure_rate: FailureRate,
569 shutdown_failure_rate: FailureRate,
570 read_latency: LatencyInterval,
571 write_latency: LatencyInterval,
572 flush_latency: LatencyInterval,
573 shutdown_latency: LatencyInterval,
574 },
575 }
576
577 impl StreamReliability {
578 pub const MILDLY_UNRELIABLE: StreamReliability = {
579 let failure_rate = FailureRate(0.0);
580 let latency = LatencyInterval {
581 min_millis: 1,
582 max_millis: 10,
583 };
584 Self::RandomlyUnreliable {
585 read_failure_rate: failure_rate,
586 write_failure_rate: failure_rate,
587 flush_failure_rate: failure_rate,
588 shutdown_failure_rate: failure_rate,
589 read_latency: latency,
590 write_latency: latency,
591 flush_latency: latency,
592 shutdown_latency: latency,
593 }
594 };
595
596 pub const INTEGRATION_TEST: StreamReliability = {
597 let failure_rate_base = 0.0;
603 let latency = LatencyInterval {
604 min_millis: 1,
605 max_millis: 10,
606 };
607 Self::RandomlyUnreliable {
608 read_failure_rate: FailureRate(failure_rate_base * 2.0),
610 write_failure_rate: FailureRate(failure_rate_base),
611 flush_failure_rate: FailureRate(failure_rate_base),
612 shutdown_failure_rate: FailureRate(failure_rate_base),
613 read_latency: latency,
614 write_latency: latency,
615 flush_latency: latency,
616 shutdown_latency: latency,
617 }
618 };
619
620 pub const BROKEN: StreamReliability = {
621 Self::RandomlyUnreliable {
622 read_failure_rate: FailureRate::MAX,
623 write_failure_rate: FailureRate::MAX,
624 flush_failure_rate: FailureRate::MAX,
625 shutdown_failure_rate: FailureRate::MAX,
626 read_latency: LatencyInterval::ZERO,
627 write_latency: LatencyInterval::ZERO,
628 flush_latency: LatencyInterval::ZERO,
629 shutdown_latency: LatencyInterval::ZERO,
630 }
631 };
632 }
633
634 #[async_trait::async_trait]
635 impl<M> Connector<M> for MockConnector
636 where
637 M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
638 {
639 async fn connect_framed(&self, destination: SafeUrl, _peer: PeerId) -> ConnectResult<M> {
640 let mut clients_lock = self.clients.try_lock().map_err(|e| {
641 anyhow!("Mock network mutex busy or poisoned, the network stack will re-try anyway: {e:?}")
642 })?;
643 if let Some(client) = clients_lock.get_mut(&parse_host_port(&destination)?) {
644 let (stream_our, stream_theirs) = tokio::io::duplex(43_689);
645 let mut stream_our = UnreliableDuplexStream::new(stream_our, self.reliability);
646 let stream_theirs = UnreliableDuplexStream::new(stream_theirs, self.reliability);
647 client.send(stream_theirs).await?;
648 let peer = do_handshake(self.id, &mut stream_our).await?;
649 let framed = BidiFramed::<
650 M,
651 WriteHalf<UnreliableDuplexStream>,
652 ReadHalf<UnreliableDuplexStream>,
653 >::new(stream_our)
654 .into_dyn();
655 Ok((peer, framed))
656 } else {
657 return Err(anyhow::anyhow!("can't connect"));
658 }
659 }
660
661 async fn listen(
662 &self,
663 bind_addr: SocketAddr,
664 ) -> Result<Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>, Error>
665 {
666 let (send, receive) = tokio::sync::mpsc::channel(16);
667
668 if self
669 .clients
670 .lock()
671 .await
672 .insert(bind_addr.to_string(), send)
673 .is_some()
674 {
675 return Err(anyhow::anyhow!("Address already bound"));
676 }
677
678 let our_id = self.id;
679 let stream = futures::stream::unfold(receive, move |mut receive| {
680 Box::pin(async move {
681 let mut connection = receive.recv().await.unwrap();
682 let peer = match do_handshake(our_id, &mut connection).await {
683 Ok(peer) => peer,
684 Err(e) => {
685 tracing::debug!("Error during handshake: {e:?}");
686 return Some((Err(e), receive));
687 }
688 };
689 let framed =
690 BidiFramed::<M, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new(
691 connection,
692 )
693 .into_dyn();
694
695 Some((Ok((peer, framed)), receive))
696 })
697 });
698 Ok(Box::pin(stream))
699 }
700 }
701
702 async fn do_handshake<S>(our_id: PeerId, stream: &mut S) -> Result<PeerId, anyhow::Error>
703 where
704 S: AsyncRead + AsyncWrite + Unpin,
705 {
706 let our_id = our_id.to_usize() as u16;
708 stream.write_all(&our_id.to_be_bytes()[..]).await?;
709
710 let mut peer_id = [0u8; 2];
712 stream.read_exact(&mut peer_id[..]).await?;
713 Ok(PeerId::from(u16::from_be_bytes(peer_id)))
714 }
715
716 #[tokio::test]
717 async fn test_mock_network() {
718 let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
719 let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
720 let peer_a = PeerId::from(1);
721 let peer_b = PeerId::from(2);
722
723 let net = MockNetwork::new();
724 let conn_a = net.connector(peer_a, StreamReliability::FullyReliable);
725 let conn_b = net.connector(peer_b, StreamReliability::FullyReliable);
726
727 let mut listener = Connector::<u64>::listen(&conn_a, bind_addr).await.unwrap();
728 let conn_a_fut = spawn("listener next await", async move {
729 listener.next().await.unwrap().unwrap()
730 });
731
732 let (auth_peer_b, mut conn_b) = Connector::<u64>::connect_framed(&conn_b, url, peer_a)
733 .await
734 .unwrap();
735 let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap();
736
737 assert_eq!(auth_peer_a, peer_b);
738 assert_eq!(auth_peer_b, peer_a);
739
740 conn_a.send(42).await.unwrap();
741 conn_b.send(21).await.unwrap();
742
743 assert_eq!(conn_a.next().await.unwrap().unwrap(), 21);
744 assert_eq!(conn_b.next().await.unwrap().unwrap(), 42);
745 }
746
747 #[tokio::test]
748 async fn test_unreliable_components() {
749 assert!(!FailureRate::new(0f64).random_fail());
750 assert!(FailureRate::new(1f64).random_fail());
751
752 let good_interval = (0..=3).contains(
753 &LatencyInterval::new(Duration::from_millis(0), Duration::from_millis(3))
754 .random()
755 .as_millis(),
756 );
757 assert!(good_interval);
758
759 let (a, b) = tokio::io::duplex(43_689);
760 let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable);
761 let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable);
762 assert!(a_stream.write(&[1, 2, 3]).await.is_ok());
763 assert!(a_stream.flush().await.is_ok());
764 assert_eq!(b_stream.read_u8().await.unwrap(), 1);
765 assert_eq!(b_stream.read_u8().await.unwrap(), 2);
766 assert_eq!(b_stream.read_u8().await.unwrap(), 3);
767
768 let (a, b) = tokio::io::duplex(43_689);
769 let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable);
770 let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::BROKEN);
771 assert!(a_stream.write(&[1, 2, 3]).await.is_ok());
772 assert!(a_stream.flush().await.is_ok());
773 assert!(b_stream.read_u8().await.is_err());
774
775 let (a, b) = tokio::io::duplex(43_689);
776 let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::BROKEN);
777 let mut _b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable);
778 assert!(a_stream.write(&[1, 2, 3]).await.is_err());
779 }
781
782 #[allow(dead_code)]
783 async fn timeout<F, T>(f: F) -> Option<T>
784 where
785 F: Future<Output = T>,
786 {
787 tokio::time::timeout(Duration::from_secs(1), f).await.ok()
788 }
789
790 #[tokio::test]
791 async fn test_large_messages() {
792 let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
793 let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
794 let peer_a = PeerId::from(1);
795 let peer_b = PeerId::from(2);
796
797 let net = MockNetwork::new();
798 let conn_a = net.connector(peer_a, StreamReliability::FullyReliable);
799 let conn_b = net.connector(peer_b, StreamReliability::FullyReliable);
800
801 let mut listener = Connector::<Vec<u8>>::listen(&conn_a, bind_addr)
802 .await
803 .unwrap();
804 let conn_a_fut = spawn("listener next await", async move {
805 listener.next().await.unwrap().unwrap()
806 });
807
808 let (auth_peer_b, mut conn_b) = Connector::<Vec<u8>>::connect_framed(&conn_b, url, peer_a)
809 .await
810 .unwrap();
811 let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap();
812
813 assert_eq!(auth_peer_a, peer_b);
814 assert_eq!(auth_peer_b, peer_a);
815
816 let send_future = async {
817 conn_a.send(vec![42; 16000]).await.unwrap();
818 }
819 .boxed();
820 let receive_future = async {
821 assert_eq!(
822 timeout(conn_b.next()).await.unwrap().unwrap().unwrap(),
823 vec![42; 16000]
824 );
825 }
826 .boxed();
827
828 tokio::join!(send_future, receive_future);
829 }
830}
831
832#[cfg(test)]
833mod tests {
834 use std::net::SocketAddr;
835
836 use fedimint_core::runtime::spawn;
837 use fedimint_core::util::SafeUrl;
838 use fedimint_core::PeerId;
839 use futures::{SinkExt, StreamExt};
840
841 use crate::config::gen_cert_and_key;
842 use crate::net::connect::{ConnectionListener, Connector, TlsConfig};
843 use crate::net::framed::AnyFramedTransport;
844 use crate::TlsTcpConnector;
845
846 fn gen_connector_config(count: usize) -> Vec<TlsConfig> {
847 let peer_keys = (0..count)
848 .map(|id| {
849 let peer = PeerId::from(id as u16);
850 gen_cert_and_key(&format!("peer-{}", peer.to_usize())).unwrap()
851 })
852 .collect::<Vec<_>>();
853
854 peer_keys
855 .iter()
856 .map(|(_cert, key)| TlsConfig {
857 our_private_key: key.clone(),
858 peer_certs: peer_keys
859 .iter()
860 .enumerate()
861 .map(|(peer, (cert, _))| (PeerId::from(peer as u16), cert.clone()))
862 .collect(),
863 peer_names: peer_keys
864 .iter()
865 .enumerate()
866 .map(|(peer, (_, _))| (PeerId::from(peer as u16), format!("peer-{peer}")))
867 .collect(),
868 })
869 .collect()
870 }
871
872 #[tokio::test]
873 async fn connect_success() {
874 let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
877 let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
878 let connectors = gen_connector_config(5)
879 .into_iter()
880 .enumerate()
881 .map(|(id, cfg)| TlsTcpConnector::new(cfg, PeerId::from(id as u16)))
882 .collect::<Vec<_>>();
883
884 let mut server: ConnectionListener<u64> = connectors[0].listen(bind_addr).await.unwrap();
885
886 let server_task = spawn("server next await", async move {
887 let (peer, mut conn) = server.next().await.unwrap().unwrap();
888 assert_eq!(peer.to_usize(), 2);
889 let received = conn.next().await.unwrap().unwrap();
890 assert_eq!(received, 42);
891 conn.send(21).await.unwrap();
892 assert!(conn.next().await.unwrap().is_err());
893 });
894
895 let (peer_of_a, mut client_a): (_, AnyFramedTransport<u64>) = connectors[2]
896 .connect_framed(url.clone(), PeerId::from(0))
897 .await
898 .unwrap();
899 assert_eq!(peer_of_a.to_usize(), 0);
900 client_a.send(42).await.unwrap();
901 let received = client_a.next().await.unwrap().unwrap();
902 assert_eq!(received, 21);
903 drop(client_a);
904
905 server_task.await.unwrap();
906 }
907
908 #[tokio::test]
909 async fn connect_reject() {
910 let bind_addr: SocketAddr = "127.0.0.1:7001".parse().unwrap();
911 let url: SafeUrl = "wss://127.0.0.1:7001".parse().unwrap();
912 let cfg = gen_connector_config(5);
913
914 let honest = TlsTcpConnector::new(cfg[0].clone(), PeerId::from(0));
915
916 let mut malicious_wrong_key_cfg = cfg[1].clone();
917 malicious_wrong_key_cfg.our_private_key = cfg[2].our_private_key.clone();
918 let malicious_wrong_key = TlsTcpConnector::new(malicious_wrong_key_cfg, PeerId::from(1));
919
920 {
922 let mut server: ConnectionListener<u64> = honest.listen(bind_addr).await.unwrap();
923
924 let server_task = spawn("server next await", async move {
925 let conn_res = server.next().await.unwrap();
926 assert_eq!(
927 conn_res.err().unwrap().to_string().as_str(),
928 "invalid peer certificate: BadSignature"
929 );
930 });
931
932 let err_anytime = async {
933 let (_peer, mut conn): (_, AnyFramedTransport<u64>) = malicious_wrong_key
934 .connect_framed(url.clone(), PeerId::from(0))
935 .await?;
936
937 conn.send(42).await?;
938 conn.flush().await?;
939 conn.next().await.unwrap()?;
940
941 Result::<_, anyhow::Error>::Ok(())
942 };
943
944 let conn_res = err_anytime.await;
945 assert_eq!(
946 conn_res.err().unwrap().to_string().as_str(),
947 "received fatal alert: DecryptError"
948 );
949
950 server_task.await.unwrap();
951 }
952
953 {
955 let mut server: ConnectionListener<u64> =
956 malicious_wrong_key.listen(bind_addr).await.unwrap();
957
958 let server_task = spawn("server next await", async move {
959 let conn_res = server.next().await.unwrap();
960 assert_eq!(
961 conn_res.err().unwrap().to_string().as_str(),
962 "received fatal alert: DecryptError"
963 );
964 });
965
966 let err_anytime = async {
967 let (_peer, mut conn): (_, AnyFramedTransport<u64>) =
968 honest.connect_framed(url.clone(), PeerId::from(1)).await?;
969
970 conn.send(42).await?;
971 conn.flush().await?;
972 conn.next().await.unwrap()?;
973
974 Result::<_, anyhow::Error>::Ok(())
975 };
976
977 let conn_res = err_anytime.await;
978 assert_eq!(
979 conn_res.err().unwrap().to_string().as_str(),
980 "invalid peer certificate: BadSignature"
981 );
982
983 server_task.await.unwrap();
984 }
985
986 {
988 let mut server: ConnectionListener<u64> =
989 TlsTcpConnector::new(cfg[2].clone(), PeerId::from(2))
990 .listen(bind_addr)
991 .await
992 .unwrap();
993
994 let server_task = spawn("server next await", async move {
995 let conn_res = server.next().await.unwrap();
996 assert_eq!(
997 conn_res.err().unwrap().to_string().as_str(),
998 "received fatal alert: BadCertificate"
999 );
1000 });
1001
1002 let err_anytime = async {
1003 let (_peer, mut conn): (_, AnyFramedTransport<u64>) =
1004 honest.connect_framed(url.clone(), PeerId::from(0)).await?;
1005
1006 conn.send(42).await?;
1007 conn.flush().await?;
1008 conn.next().await.unwrap()?;
1009
1010 Result::<_, anyhow::Error>::Ok(())
1011 };
1012
1013 let conn_res = err_anytime.await;
1014 assert_eq!(
1015 conn_res.err().unwrap().to_string().as_str(),
1016 "invalid peer certificate: NotValidForName"
1017 );
1018
1019 server_task.await.unwrap();
1020 }
1021 }
1022}