fedimint_server/net/
queue.rs1use std::collections::VecDeque;
2
3use serde::{Deserialize, Serialize};
4use tracing::{debug, trace};
5
6#[derive(Debug, Clone, Eq, PartialEq)]
7pub struct MessageQueue<M> {
8 pub(super) queue: VecDeque<UniqueMessage<M>>,
9 pub(super) next_id: MessageId,
10}
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
13pub struct MessageId(pub u64);
14
15#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
16pub struct UniqueMessage<M> {
17 pub id: MessageId,
18 pub msg: M,
19}
20
21impl MessageId {
22 pub fn increment(self) -> MessageId {
23 MessageId(self.0 + 1)
24 }
25}
26
27impl<M> Default for MessageQueue<M> {
28 fn default() -> Self {
29 MessageQueue {
30 queue: VecDeque::default(),
31 next_id: MessageId(1),
32 }
33 }
34}
35
36impl<M> MessageQueue<M>
37where
38 M: Clone,
39{
40 pub fn push(&mut self, msg: M) -> UniqueMessage<M> {
41 let id_msg = UniqueMessage {
42 id: self.next_id,
43 msg,
44 };
45
46 self.queue.push_back(id_msg.clone());
47 self.next_id = self.next_id.increment();
48
49 id_msg
50 }
51
52 pub fn ack(&mut self, msg_id: MessageId) {
53 debug!("Received ACK for {:?}", msg_id);
54 while self.queue.front().is_some_and(|msg| msg.id <= msg_id) {
55 let msg = self.queue.pop_front().expect("Checked in while head");
56 trace!("Removing message {:?} from resend buffer", msg.id);
57 }
58 }
59
60 pub fn iter(&self) -> impl Iterator<Item = &UniqueMessage<M>> {
61 self.queue.iter()
62 }
63}
64
#[cfg(test)]
mod tests {
    use crate::net::queue::{MessageId, MessageQueue};

    /// Asserts that `queue` holds exactly the messages for the zero-based push positions in
    /// `iter`, in order. The message at position `i` has payload `42 * i` and id `i + 1`.
    fn assert_contains(queue: &MessageQueue<u64>, iter: impl Iterator<Item = u64>) {
        let mut queue_iter = queue.iter();

        for i in iter {
            let umsg = queue_iter.next().unwrap();
            assert_eq!(umsg.msg, 42 * i);
            assert_eq!(umsg.id.0, i + 1);
        }

        assert_eq!(queue_iter.next(), None);
    }

    #[test]
    fn test_queue() {
        let mut queue = MessageQueue::default();

        for i in 0u64..10 {
            let umsg = queue.push(42 * i);
            assert_eq!(umsg.msg, 42 * i);
            assert_eq!(umsg.id.0, i + 1);
        }

        assert_eq!(queue.iter().count(), 10);
        assert_contains(&queue, 0..10);

        queue.ack(MessageId(1));
        assert_contains(&queue, 1..10);

        queue.ack(MessageId(4));
        assert_contains(&queue, 4..10);

        // A stale ACK for an already-removed message must not remove anything.
        queue.ack(MessageId(2));
        assert_contains(&queue, 4..10);

        // Acknowledging the newest message empties the buffer.
        queue.ack(MessageId(10));
        assert_contains(&queue, 0..0);

        // Numbering continues where it left off rather than restarting.
        let umsg = queue.push(42 * 10);
        assert_eq!(umsg.id, MessageId(11));
        assert_contains(&queue, 10..11);

        // An ACK beyond the highest buffered id clears everything.
        queue.ack(MessageId(u64::MAX));
        assert_contains(&queue, 0..0);
    }

    #[test]
    fn ack_on_empty_queue_is_noop() {
        let mut queue = MessageQueue::<u64>::default();
        queue.ack(MessageId(5));
        assert_eq!(queue.iter().next(), None);

        // A spurious ACK must not consume ids: the first push still gets id 1.
        assert_eq!(queue.push(7).id, MessageId(1));
    }
}