fedimint_server/multiplexed.rs

use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;

use async_trait::async_trait;
use fedimint_core::net::peers::{IMuxPeerConnections, PeerConnections};
use fedimint_core::runtime::spawn;
use fedimint_core::task::{Cancellable, Cancelled};
use fedimint_core::PeerId;
use fedimint_logging::LOG_NET_PEER;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tracing::{debug, warn};

/// TODO: Use proper `ModuleId` after modularization is complete
pub type ModuleId = String;
pub type ModuleIdRef<'a> = &'a str;

/// Number of buffered per-peer messages after which we will start throwing
/// new ones away.
///
/// It's hard to predict how many messages are too many, but we have
/// to draw the line somewhere.
pub const MAX_PEER_OUT_OF_ORDER_MESSAGES: u64 = 10000;

/// A `Msg` that can target a specific destination module
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModuleMultiplexed<MuxKey, Msg> {
    pub key: MuxKey,
    pub msg: Msg,
}

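/// Buffer of messages that arrived before any task asked for them, and of
/// tasks currently waiting to receive, both indexed by `MuxKey`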
struct ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    /// Cached messages per `MuxKey` waiting for a callback
    msgs: HashMap<MuxKey, VecDeque<(PeerId, Msg)>>,
    /// Callback queue from tasks that want to receive
    callbacks: HashMap<MuxKey, VecDeque<oneshot::Sender<(PeerId, Msg)>>>,
    /// Track pending messages per peer to avoid a potential DoS
    peer_counts: HashMap<PeerId, u64>,
}

impl<MuxKey, Msg> Default for ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    fn default() -> Self {
        Self {
            msgs: HashMap::new(),
            callbacks: HashMap::new(),
            peer_counts: HashMap::new(),
        }
    }
}

/// A wrapper around `PeerConnections`, multiplexing communication between
/// multiple modules over it
///
/// This works by addressing each module when sending, and by buffering
/// messages received out of order until they are requested.
///
/// This type is thread-safe and can be cheaply cloned.
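///
/// A minimal usage sketch (illustrative only; assumes an already-established
/// `connections: PeerConnections<ModuleMultiplexed<String, u64>>` and that
/// the `IMuxPeerConnections` trait is in scope):
///
/// ```ignore
/// let mux = PeerConnectionMultiplexer::new(connections);
/// // Address module "a" on peer 1 explicitly when sending.
/// mux.send(&[PeerId::from(1)], "a".to_string(), 42u64).await?;
/// // Wait for the next message addressed to module "a", from any peer;
/// // a message that already arrived out of order is returned immediately.
/// let (peer, msg) = mux.receive("a".to_string()).await?;
/// ```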
#[derive(Clone)]
pub struct PeerConnectionMultiplexer<MuxKey, Msg> {
    /// Sender of send requests
    send_requests_tx: Sender<(Vec<PeerId>, MuxKey, Msg)>,
    /// Sender of receive callbacks
    receive_callbacks_tx: Sender<Callback<MuxKey, Msg>>,
    /// Sender of peer bans
    peer_bans_tx: Sender<PeerId>,
}

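/// A pending receive request: the mux key a task wants to receive on, plus
/// the oneshot sender used to deliver the matching `(PeerId, Msg)` to it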
type Callback<MuxKey, Msg> = (MuxKey, oneshot::Sender<(PeerId, Msg)>);

impl<MuxKey, Msg> PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug + 'static,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone + 'static,
{
    pub fn new(connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>) -> Self {
        let (send_requests_tx, send_requests_rx) = channel(1000);
        let (receive_callbacks_tx, receive_callbacks_rx) = channel(1000);
        let (peer_bans_tx, peer_bans_rx) = channel(1000);

        spawn(
            "peer connection multiplexer",
            Self::run(
                connections,
                ModuleMultiplexerOutOfOrder::default(),
                send_requests_rx,
                receive_callbacks_rx,
                peer_bans_rx,
            ),
        );

        Self {
            send_requests_tx,
            receive_callbacks_tx,
            peer_bans_tx,
        }
    }

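    /// Event loop that owns the underlying `connections` and the out-of-order
    /// buffers: it forwards send and ban requests, queues receive callbacks
    /// and incoming messages, and matches the two queues up by key.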
    async fn run(
        mut connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>,
        mut out_of_order: ModuleMultiplexerOutOfOrder<MuxKey, Msg>,
        mut send_requests_rx: Receiver<(Vec<PeerId>, MuxKey, Msg)>,
        mut receive_callbacks_rx: Receiver<Callback<MuxKey, Msg>>,
        mut peer_bans_rx: Receiver<PeerId>,
    ) -> Cancellable<()> {
        loop {
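            // Records the key under which this iteration queued a new message
            // or callback, so the matching step below only checks that key.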
            let mut key_inserted: Option<MuxKey> = None;
            tokio::select! {
                // Send requests are forwarded to underlying connections
                send_request = send_requests_rx.recv() => {
                    let (peers, key, msg) = send_request.ok_or(Cancelled)?;
                    connections.send(&peers, ModuleMultiplexed { key, msg }).await?;
                }
                // Ban requests are forwarded to underlying connections
                peer_ban = peer_bans_rx.recv() => {
                    let peer = peer_ban.ok_or(Cancelled)?;
                    connections.ban_peer(peer).await;
                }
                // Receive callbacks are added to the callback queue by key
                receive_callback = receive_callbacks_rx.recv() => {
                    let (key, callback) = receive_callback.ok_or(Cancelled)?;
                    out_of_order.callbacks.entry(key.clone()).or_default().push_back(callback);
                    key_inserted = Some(key);
                }
                // Received messages are added to the message queue by key
                receive = connections.receive() => {
                    let (peer, ModuleMultiplexed { key, msg }) = receive?;
                    let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
                    // We limit the number of buffered messages from any given
                    // peer to avoid OOM. In practice, hitting this limit
                    // would halt DKG.
                    if *peer_pending > MAX_PEER_OUT_OF_ORDER_MESSAGES {
                        warn!(
                            target: LOG_NET_PEER,
                            "Peer {peer} has {peer_pending} pending messages. Dropping new message."
                        );
                    } else {
                        *peer_pending += 1;
                        out_of_order.msgs.entry(key.clone()).or_default().push_back((peer, msg));
                        key_inserted = Some(key);
                    }
                }
            }

            // If a key was inserted, check to see if we can fulfill a callback
            if let Some(key) = key_inserted {
                let callbacks = out_of_order.callbacks.entry(key.clone()).or_default();
                let msgs = out_of_order.msgs.entry(key.clone()).or_default();

                if !callbacks.is_empty() && !msgs.is_empty() {
                    let callback = callbacks.pop_front().expect("checked");
                    let (peer, msg) = msgs.pop_front().expect("checked");
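                    // The message is now being delivered, so it stops counting
                    // against this peer's out-of-order budget.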
                    let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
                    *peer_pending -= 1;
                    callback.send((peer, msg)).map_err(|_| Cancelled)?;
                }
            }
        }
    }
}

#[async_trait]
impl<MuxKey, Msg> IMuxPeerConnections<MuxKey, Msg> for PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone,
{
    async fn send(&self, peers: &[PeerId], key: MuxKey, msg: Msg) -> Cancellable<()> {
        debug!("Sending to {peers:?}/{key:?}, {msg:?}");
        self.send_requests_tx
            .send((peers.to_vec(), key, msg))
            .await
            .map_err(|_e| Cancelled)
    }

    /// Await receipt of a message from any connected peer.
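    ///
    /// Resolves immediately if a message for `key` is already buffered;
    /// otherwise waits until one arrives.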
    async fn receive(&self, key: MuxKey) -> Cancellable<(PeerId, Msg)> {
        let (callback_tx, callback_rx) = oneshot::channel();
        self.receive_callbacks_tx
            .send((key, callback_tx))
            .await
            .map_err(|_e| Cancelled)?;
        callback_rx.await.map_err(|_e| Cancelled)
    }

    async fn ban_peer(&self, peer: PeerId) {
        // We don't return a `Cancellable` for bans
        let _ = self.peer_bans_tx.send(peer).await;
    }
}

#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use fedimint_core::net::peers::fake::make_fake_peer_connection;
    use fedimint_core::net::peers::IMuxPeerConnections;
    use fedimint_core::task::{self, TaskGroup};
    use fedimint_core::PeerId;
    use rand::rngs::OsRng;
    use rand::seq::SliceRandom;
    use rand::{thread_rng, Rng};

    use crate::multiplexed::PeerConnectionMultiplexer;

    /// Send many messages over a multiplexed fake link
    ///
    /// Some things this is checking for:
    ///
    /// * no messages were missed
    /// * messages arrived in order (from the PoV of each module)
    /// * nothing deadlocked anywhere.
    #[test_log::test(tokio::test)]
    async fn test_multiplexer() {
        const NUM_MODULES: usize = 128;
        const NUM_MSGS_PER_MODULE: usize = 128;
        const NUM_REPEAT_TEST: usize = 10;

        for _ in 0..NUM_REPEAT_TEST {
            let task_group = TaskGroup::new();
            let task_handle = task_group.make_handle();

            let peer1 = PeerId::from(0);
            let peer2 = PeerId::from(1);

            let (conn1, conn2) = make_fake_peer_connection(peer1, peer2, 1000, task_handle.clone());
            let (conn1, conn2) = (
                PeerConnectionMultiplexer::new(conn1).into_dyn(),
                PeerConnectionMultiplexer::new(conn2).into_dyn(),
            );

            let mut modules: Vec<_> = (0..NUM_MODULES).collect();
            modules.shuffle(&mut thread_rng());

            for mux_key in modules.clone() {
                let conn1 = conn1.clone();
                let task_handle = task_handle.clone();
                task_group.spawn(format!("sender-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // add some random jitter
                        if OsRng.gen() {
                            // Note that the randomized sleep in the sender is
                            // longer than in the receiver, to avoid running
                            // with permanently full queues.
                            task::sleep(Duration::from_millis(2)).await;
                        }
                        if task_handle.is_shutting_down() {
                            break;
                        }
                        conn1.send(&[peer2], mux_key, msg_i).await.unwrap();
                    }
                });
            }

            modules.shuffle(&mut thread_rng());
            for mux_key in modules.clone() {
                let conn2 = conn2.clone();
                task_group.spawn(format!("receiver-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // add some random jitter
                        if OsRng.gen() {
                            task::sleep(Duration::from_millis(1)).await;
                        }
                        assert_eq!(conn2.receive(mux_key).await.unwrap(), (peer1, msg_i));
                    }
                });
            }

            task_group.join_all(None).await.expect("no failures");
        }
    }
}