fedimint_server/net/queue.rs

1use std::collections::VecDeque;
2
3use serde::{Deserialize, Serialize};
4use tracing::{debug, trace};
5
6#[derive(Debug, Clone, Eq, PartialEq)]
7pub struct MessageQueue<M> {
8    pub(super) queue: VecDeque<UniqueMessage<M>>,
9    pub(super) next_id: MessageId,
10}
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
13pub struct MessageId(pub u64);
14
15#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
16pub struct UniqueMessage<M> {
17    pub id: MessageId,
18    pub msg: M,
19}
20
21impl MessageId {
22    pub fn increment(self) -> MessageId {
23        MessageId(self.0 + 1)
24    }
25}
26
27impl<M> Default for MessageQueue<M> {
28    fn default() -> Self {
29        MessageQueue {
30            queue: VecDeque::default(),
31            next_id: MessageId(1),
32        }
33    }
34}
35
36impl<M> MessageQueue<M>
37where
38    M: Clone,
39{
40    pub fn push(&mut self, msg: M) -> UniqueMessage<M> {
41        let id_msg = UniqueMessage {
42            id: self.next_id,
43            msg,
44        };
45
46        self.queue.push_back(id_msg.clone());
47        self.next_id = self.next_id.increment();
48
49        id_msg
50    }
51
52    pub fn ack(&mut self, msg_id: MessageId) {
53        debug!("Received ACK for {:?}", msg_id);
54        while self.queue.front().is_some_and(|msg| msg.id <= msg_id) {
55            let msg = self.queue.pop_front().expect("Checked in while head");
56            trace!("Removing message {:?} from resend buffer", msg.id);
57        }
58    }
59
60    pub fn iter(&self) -> impl Iterator<Item = &UniqueMessage<M>> {
61        self.queue.iter()
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::{MessageId, MessageQueue};

    /// Asserts that `queue` holds exactly the messages pushed at the zero-based
    /// positions in `expected`, in order: position `i` carries payload `42 * i`
    /// and id `i + 1`.
    fn assert_contains(queue: &MessageQueue<u64>, expected: impl Iterator<Item = u64>) {
        let mut queue_iter = queue.iter();

        for i in expected {
            let umsg = queue_iter.next().unwrap();
            assert_eq!(umsg.msg, 42 * i);
            assert_eq!(umsg.id.0, i + 1);
        }

        assert_eq!(queue_iter.next(), None);
    }

    #[test]
    fn test_queue() {
        let mut queue = MessageQueue::default();

        for i in 0u64..10 {
            let umsg = queue.push(42 * i);
            assert_eq!(umsg.msg, 42 * i);
            assert_eq!(umsg.id.0, i + 1);
        }

        assert_eq!(queue.iter().count(), 10);
        assert_contains(&queue, 0..10);

        queue.ack(MessageId(1));
        assert_contains(&queue, 1..10);

        queue.ack(MessageId(4));
        assert_contains(&queue, 4..10);

        // Acking an already-acknowledged id removes nothing.
        queue.ack(MessageId(2)); // TODO: should that throw an error?
        assert_contains(&queue, 4..10);

        // Acking the newest id drains the buffer completely...
        queue.ack(MessageId(10));
        assert_eq!(queue.iter().count(), 0);

        // ...and numbering continues instead of restarting.
        let umsg = queue.push(42 * 10);
        assert_eq!(umsg.id, MessageId(11));
        assert_contains(&queue, 10..11);
    }

    #[test]
    fn ack_on_empty_queue_is_noop() {
        let mut queue = MessageQueue::<u64>::default();

        queue.ack(MessageId(5));
        assert_eq!(queue.iter().count(), 0);

        // The ack does not touch id assignment.
        assert_eq!(queue.push(0).id, MessageId(1));
    }
}