fedimint_core/net/peers/
fake.rs

//! Fake (channel-based) implementation of [`super::PeerConnections`].
2use std::time::Duration;
3
4use async_trait::async_trait;
5use fedimint_core::net::peers::{IPeerConnections, PeerConnections};
6use fedimint_core::runtime::sleep;
7use fedimint_core::task::{Cancellable, Cancelled, TaskHandle};
8use fedimint_core::PeerId;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use tokio::sync::mpsc::{self, Receiver, Sender};
12
/// One end of an in-memory, channel-backed link to exactly one remote peer.
///
/// Messages sent through `tx` are received on the opposite end's `rx`; the two
/// ends are wired together by [`make_fake_peer_connection`].
struct FakePeerConnections<Msg> {
    /// Outgoing messages towards `peer_id`.
    tx: Sender<Msg>,
    /// Incoming messages from `peer_id`.
    rx: Receiver<Msg>,
    /// The single peer on the other side of this link.
    peer_id: PeerId,
    /// Polled via `is_shutting_down` so `receive` can stop waiting on shutdown.
    task_handle: TaskHandle,
}
19
20#[async_trait]
21impl<Msg> IPeerConnections<Msg> for FakePeerConnections<Msg>
22where
23    Msg: Serialize + DeserializeOwned + Unpin + Send,
24{
25    async fn send(&mut self, peers: &[PeerId], msg: Msg) -> Cancellable<()> {
26        assert_eq!(peers, &[self.peer_id]);
27
28        // If the peer is gone, just pretend we are going to resend
29        // the msg eventually, even if it will never happen.
30        let _ = self.tx.send(msg).await;
31        Ok(())
32    }
33
34    async fn receive(&mut self) -> Cancellable<(PeerId, Msg)> {
35        // Just like a real implementation, do not return
36        // if the peer is gone.
37        while !self.task_handle.is_shutting_down() {
38            if let Some(msg) = self.rx.recv().await {
39                return Ok((self.peer_id, msg));
40            }
41
42            sleep(Duration::from_secs(10)).await;
43        }
44        Err(Cancelled)
45    }
46
47    /// Removes a peer connection in case of misbehavior
48    async fn ban_peer(&mut self, _peer: PeerId) {
49        unimplemented!();
50    }
51}
52
53/// Create a fake link between `peer1` and `peer2` for test purposes
54///
55/// `buf_size` controls the size of the `tokio::mpsc::channel` used
56/// under the hood (both ways).
57pub fn make_fake_peer_connection<Msg>(
58    peer1: PeerId,
59    peer2: PeerId,
60    buf_size: usize,
61    task_handle: TaskHandle,
62) -> (PeerConnections<Msg>, PeerConnections<Msg>)
63where
64    Msg: Serialize + DeserializeOwned + Unpin + Send + 'static,
65{
66    let (tx1, rx1) = mpsc::channel(buf_size);
67    let (tx2, rx2) = mpsc::channel(buf_size);
68
69    (
70        FakePeerConnections {
71            tx: tx1,
72            rx: rx2,
73            peer_id: peer2,
74            task_handle: task_handle.clone(),
75        }
76        .into_dyn(),
77        FakePeerConnections {
78            tx: tx2,
79            rx: rx1,
80            peer_id: peer1,
81            task_handle,
82        }
83        .into_dyn(),
84    )
85}