1use std::collections::{HashMap, VecDeque};
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use async_trait::async_trait;
6use fedimint_core::net::peers::{IMuxPeerConnections, PeerConnections};
7use fedimint_core::runtime::spawn;
8use fedimint_core::task::{Cancellable, Cancelled};
9use fedimint_core::PeerId;
10use fedimint_logging::LOG_NET_PEER;
11use serde::de::DeserializeOwned;
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc::{channel, Receiver, Sender};
14use tokio::sync::oneshot;
15use tracing::{debug, warn};
16
/// Owned module identifier, represented as a plain [`String`].
pub type ModuleId = String;
/// Borrowed form of a [`ModuleId`].
pub type ModuleIdRef<'a> = &'a str;

/// Limit on the number of not-yet-delivered messages that are buffered for a
/// single peer; once a peer is over the limit its new messages are dropped.
pub const MAX_PEER_OUT_OF_ORDER_MESSAGES: u64 = 10_000;
26
27#[derive(Serialize, Deserialize, Debug, Clone)]
29pub struct ModuleMultiplexed<MuxKey, Msg> {
30 pub key: MuxKey,
31 pub msg: Msg,
32}
33
34struct ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
35 msgs: HashMap<MuxKey, VecDeque<(PeerId, Msg)>>,
37 callbacks: HashMap<MuxKey, VecDeque<oneshot::Sender<(PeerId, Msg)>>>,
39 peer_counts: HashMap<PeerId, u64>,
41}
42
43impl<MuxKey, Msg> Default for ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
44 fn default() -> Self {
45 Self {
46 msgs: HashMap::new(),
47 callbacks: HashMap::new(),
48 peer_counts: HashMap::new(),
49 }
50 }
51}
52
53#[derive(Clone)]
61pub struct PeerConnectionMultiplexer<MuxKey, Msg> {
62 send_requests_tx: Sender<(Vec<PeerId>, MuxKey, Msg)>,
64 receive_callbacks_tx: Sender<Callback<MuxKey, Msg>>,
66 peer_bans_tx: Sender<PeerId>,
68}
69
70type Callback<MuxKey, Msg> = (MuxKey, oneshot::Sender<(PeerId, Msg)>);
71
72impl<MuxKey, Msg> PeerConnectionMultiplexer<MuxKey, Msg>
73where
74 Msg: Serialize + DeserializeOwned + Unpin + Send + Debug + 'static,
75 MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone + 'static,
76{
77 pub fn new(connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>) -> Self {
78 let (send_requests_tx, send_requests_rx) = channel(1000);
79 let (receive_callbacks_tx, receive_callbacks_rx) = channel(1000);
80 let (peer_bans_tx, peer_bans_rx) = channel(1000);
81
82 spawn(
83 "peer connection multiplexer",
84 Self::run(
85 connections,
86 ModuleMultiplexerOutOfOrder::default(),
87 send_requests_rx,
88 receive_callbacks_rx,
89 peer_bans_rx,
90 ),
91 );
92
93 Self {
94 send_requests_tx,
95 receive_callbacks_tx,
96 peer_bans_tx,
97 }
98 }
99
100 async fn run(
101 mut connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>,
102 mut out_of_order: ModuleMultiplexerOutOfOrder<MuxKey, Msg>,
103 mut send_requests_rx: Receiver<(Vec<PeerId>, MuxKey, Msg)>,
104 mut receive_callbacks_rx: Receiver<Callback<MuxKey, Msg>>,
105 mut peer_bans_rx: Receiver<PeerId>,
106 ) -> Cancellable<()> {
107 loop {
108 let mut key_inserted: Option<MuxKey> = None;
109 tokio::select! {
110 send_request = send_requests_rx.recv() => {
112 let (peers, key, msg) = send_request.ok_or(Cancelled)?;
113 connections.send(&peers, ModuleMultiplexed { key, msg }).await?;
114 }
115 peer_ban = peer_bans_rx.recv() => {
117 let peer = peer_ban.ok_or(Cancelled)?;
118 connections.ban_peer(peer).await;
119 }
120 receive_callback = receive_callbacks_rx.recv() => {
122 let (key, callback) = receive_callback.ok_or(Cancelled)?;
123 out_of_order.callbacks.entry(key.clone()).or_default().push_back(callback);
124 key_inserted = Some(key);
125 }
126 receive = connections.receive() => {
128 let (peer, ModuleMultiplexed { key, msg }) = receive?;
129 let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
130 if *peer_pending > MAX_PEER_OUT_OF_ORDER_MESSAGES {
133 warn!(
134 target: LOG_NET_PEER,
135 "Peer {peer} has {peer_pending} pending messages. Dropping new message."
136 );
137 } else {
138 *peer_pending += 1;
139 out_of_order.msgs.entry(key.clone()).or_default().push_back((peer, msg));
140 key_inserted = Some(key);
141 }
142 }
143 }
144
145 if let Some(key) = key_inserted {
147 let callbacks = out_of_order.callbacks.entry(key.clone()).or_default();
148 let msgs = out_of_order.msgs.entry(key.clone()).or_default();
149
150 if !callbacks.is_empty() && !msgs.is_empty() {
151 let callback = callbacks.pop_front().expect("checked");
152 let (peer, msg) = msgs.pop_front().expect("checked");
153 let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
154 *peer_pending -= 1;
155 callback.send((peer, msg)).map_err(|_| Cancelled)?;
156 }
157 }
158 }
159 }
160}
161
162#[async_trait]
163impl<MuxKey, Msg> IMuxPeerConnections<MuxKey, Msg> for PeerConnectionMultiplexer<MuxKey, Msg>
164where
165 Msg: Serialize + DeserializeOwned + Unpin + Send + Debug,
166 MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone,
167{
168 async fn send(&self, peers: &[PeerId], key: MuxKey, msg: Msg) -> Cancellable<()> {
169 debug!("Sending to {peers:?}/{key:?}, {msg:?}");
170 self.send_requests_tx
171 .send((peers.to_vec(), key, msg))
172 .await
173 .map_err(|_e| Cancelled)
174 }
175
176 async fn receive(&self, key: MuxKey) -> Cancellable<(PeerId, Msg)> {
178 let (callback_tx, callback_rx) = oneshot::channel();
179 self.receive_callbacks_tx
180 .send((key, callback_tx))
181 .await
182 .map_err(|_e| Cancelled)?;
183 callback_rx.await.map_err(|_e| Cancelled)
184 }
185
186 async fn ban_peer(&self, peer: PeerId) {
187 let _ = self.peer_bans_tx.send(peer).await;
189 }
190}
191
#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use fedimint_core::net::peers::fake::make_fake_peer_connection;
    use fedimint_core::net::peers::IMuxPeerConnections;
    use fedimint_core::task::{self, TaskGroup};
    use fedimint_core::PeerId;
    use rand::rngs::OsRng;
    use rand::seq::SliceRandom;
    use rand::{thread_rng, Rng};

    use crate::multiplexed::PeerConnectionMultiplexer;

    /// Sends `NUM_MSGS_PER_MODULE` numbered messages under each of
    /// `NUM_MODULES` keys from peer 1 to peer 2, with random short delays on
    /// both sides, and checks that the receiver of every key sees that key's
    /// messages in sending order. The scenario is repeated `NUM_REPEAT_TEST`
    /// times.
    #[test_log::test(tokio::test)]
    async fn test_multiplexer() {
        const NUM_MODULES: usize = 128;
        const NUM_MSGS_PER_MODULE: usize = 128;
        const NUM_REPEAT_TEST: usize = 10;

        for _round in 0..NUM_REPEAT_TEST {
            let task_group = TaskGroup::new();
            let task_handle = task_group.make_handle();

            let peer1 = PeerId::from(0);
            let peer2 = PeerId::from(1);

            let (raw_conn1, raw_conn2) =
                make_fake_peer_connection(peer1, peer2, 1000, task_handle.clone());
            let conn1 = PeerConnectionMultiplexer::new(raw_conn1).into_dyn();
            let conn2 = PeerConnectionMultiplexer::new(raw_conn2).into_dyn();

            let mut modules = (0..NUM_MODULES).collect::<Vec<_>>();

            // One sender task per key, started in random key order.
            modules.shuffle(&mut thread_rng());
            for mux_key in modules.iter().copied() {
                let sender_conn = conn1.clone();
                let shutdown_handle = task_handle.clone();
                task_group.spawn(format!("sender-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // Randomly stall to interleave the keys on the wire.
                        let stall: bool = OsRng.gen();
                        if stall {
                            task::sleep(Duration::from_millis(2)).await;
                        }
                        if shutdown_handle.is_shutting_down() {
                            break;
                        }
                        sender_conn.send(&[peer2], mux_key, msg_i).await.unwrap();
                    }
                });
            }

            // One receiver task per key, started in a different random order.
            modules.shuffle(&mut thread_rng());
            for mux_key in modules.iter().copied() {
                let receiver_conn = conn2.clone();
                task_group.spawn(format!("receiver-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        let stall: bool = OsRng.gen();
                        if stall {
                            task::sleep(Duration::from_millis(1)).await;
                        }
                        let received = receiver_conn.receive(mux_key).await.unwrap();
                        assert_eq!(received, (peer1, msg_i));
                    }
                });
            }

            task_group.join_all(None).await.expect("no failures");
        }
    }
}