//! TTL (time-to-live) cache used by rate limiting.
//! (kraken_async_rs/rate_limiting/ttl_cache.rs)

use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet};
use time::OffsetDateTime;

5#[inline]
6fn now() -> i128 {
7 OffsetDateTime::now_utc().unix_timestamp_nanos() / 1000
8}
9
/// A cache entry: a key, its absolute expiry time, and a payload.
#[derive(Debug, Clone, Copy)]
pub struct TtlEntry<K, T>
where
    K: Ord + Clone,
    T: Clone,
{
    pub id: K,
    // Absolute expiry timestamp in microseconds since the Unix epoch;
    // set to `now() + ttl_us` by `TtlEntry::new`.
    ttl: i128,
    pub data: T,
}

25impl<K, T> TtlEntry<K, T>
26where
27 K: Ord + Clone,
28 T: Clone,
29{
30 pub fn new(id: K, ttl_us: i128, data: T) -> TtlEntry<K, T> {
31 TtlEntry {
32 id,
33 ttl: now() + ttl_us,
34 data,
35 }
36 }
37}
38
// Marker impl: `PartialEq` below compares `id` and `ttl`, both of which
// have total equality under these bounds, so `Eq` is sound.
impl<K, T> Eq for TtlEntry<K, T>
where
    K: Ord + Clone,
    T: Clone,
{
}

46impl<K, T> PartialEq<Self> for TtlEntry<K, T>
47where
48 K: Ord + Clone,
49 T: Clone,
50{
51 fn eq(&self, other: &Self) -> bool {
52 self.id == other.id && self.ttl == other.ttl
53 }
54}
55
56impl<K, T> PartialOrd<Self> for TtlEntry<K, T>
57where
58 K: Ord + Clone,
59 T: Clone,
60{
61 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
62 Some(self.ttl.cmp(&other.ttl))
63 }
64}
65
66impl<K, T> Ord for TtlEntry<K, T>
67where
68 K: Ord + Clone,
69 T: Clone,
70{
71 fn cmp(&self, other: &Self) -> Ordering {
72 self.ttl.cmp(&other.ttl)
73 }
74}
75
/// A cache whose entries expire after a per-entry TTL.
///
/// Intended invariant: `ids` and `ttls` mirror each other — `ids` gives
/// lookup by key, `ttls` keeps the same entries ordered by expiry for
/// the purge scan in `remove_expired_values`.
#[derive(Debug)]
pub struct TtlCache<K, T>
where
    K: Ord + Clone,
    T: Clone,
{
    ids: BTreeMap<K, TtlEntry<K, T>>,
    ttls: BTreeSet<TtlEntry<K, T>>,
}

88impl<K, T> Default for TtlCache<K, T>
89where
90 K: Ord + Clone,
91 T: Clone,
92{
93 fn default() -> Self {
94 TtlCache::new()
95 }
96}
97
98impl<K, T> TtlCache<K, T>
99where
100 K: Ord + Clone,
101 T: Clone,
102{
103 pub fn new() -> TtlCache<K, T> {
105 TtlCache {
106 ids: Default::default(),
107 ttls: Default::default(),
108 }
109 }
110
111 pub fn insert(&mut self, ttl_entry: TtlEntry<K, T>) -> Option<TtlEntry<K, T>> {
114 self.ttls.insert(ttl_entry.clone());
115 self.ids.insert(ttl_entry.id.clone(), ttl_entry)
116 }
117
118 pub fn remove(&mut self, ttl_entry: &TtlEntry<K, T>) -> bool {
124 self.ids.remove(&ttl_entry.id);
125 let removed = self.ttls.remove(ttl_entry);
126 self.remove_expired_values();
127
128 removed
129 }
130
131 pub fn contains(&mut self, id: &K) -> bool {
133 self.remove_expired_values();
134 self.ids.contains_key(id)
135 }
136
137 pub fn get(&mut self, id: &K) -> Option<&TtlEntry<K, T>> {
139 self.remove_expired_values();
140 self.ids.get(id)
141 }
142
143 fn remove_expired_values(&mut self) {
144 let now = now();
145 let mut to_remove = Vec::new();
146
147 for entry in &self.ttls {
148 if entry.ttl < now {
149 to_remove.push(entry.clone());
150 }
151 }
152
153 for entry in to_remove {
154 self.ids.remove(&entry.id);
155 self.ttls.remove(&entry);
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use crate::rate_limiting::ttl_cache::{TtlCache, TtlEntry};
    use std::cmp::Ordering::{Equal, Greater, Less};
    use std::thread::sleep;
    use std::time::Duration as StdDuration;
    use time::Duration;

    // Equality must consider both `id` and `ttl`; ordering is by `ttl`.
    #[test]
    fn test_ttl_entry_eq_partial_cmp() {
        let entry_0 = TtlEntry {
            id: "0x1",
            ttl: 0,
            data: 0,
        };
        let entry_1 = TtlEntry {
            id: "0x1",
            ttl: 1,
            data: 0,
        };
        let entry_2 = TtlEntry {
            id: "0x1",
            ttl: 1,
            data: 0,
        };

        // Same id, different ttl => not equal; same id and ttl => equal.
        assert_ne!(entry_0, entry_1);
        assert_ne!(entry_0, entry_2);
        assert_eq!(entry_1, entry_2);

        assert_eq!(Less, entry_0.partial_cmp(&entry_1).unwrap());
        assert_eq!(Less, entry_0.partial_cmp(&entry_2).unwrap());
        assert_eq!(Equal, entry_1.partial_cmp(&entry_2).unwrap());
    }

    // Total order follows the ttl values (0 < 1 < 2 here).
    #[test]
    fn test_ttl_entry_ord() {
        let entry_0 = TtlEntry {
            id: "0x1",
            ttl: 0,
            data: 0,
        };
        let entry_1 = TtlEntry {
            id: "0x1",
            ttl: 1,
            data: 0,
        };
        let entry_2 = TtlEntry {
            id: "0x2",
            ttl: 2,
            data: 0,
        };

        assert_eq!(Less, entry_0.cmp(&entry_1));
        assert_eq!(Less, entry_0.cmp(&entry_2));
        assert_eq!(Less, entry_1.cmp(&entry_2));

        assert_eq!(Greater, entry_1.cmp(&entry_0));
        assert_eq!(Greater, entry_2.cmp(&entry_1));
        assert_eq!(Greater, entry_2.cmp(&entry_1));

        // Reflexive: every entry compares equal to itself.
        assert_eq!(Equal, entry_0.cmp(&entry_0));
        assert_eq!(Equal, entry_1.cmp(&entry_1));
        assert_eq!(Equal, entry_2.cmp(&entry_2));
    }

    // Insert then remove: membership is tracked per id.
    #[test]
    fn test_ttl_cache_insert_remove() {
        let ttl = Duration::seconds(1).whole_microseconds();
        let entry_1 = TtlEntry::new("0x1".to_string(), ttl, 0);
        let entry_2 = TtlEntry::new("0x2".to_string(), ttl, 0);

        let mut cache = TtlCache::new();

        cache.insert(entry_1.clone());

        assert!(cache.contains(&entry_1.id));
        assert!(!cache.contains(&entry_2.id));

        assert!(cache.remove(&entry_1));

        assert!(!cache.contains(&entry_1.id));
        assert!(!cache.contains(&entry_2.id));
    }

    // `get` returns the stored entry while it is still live.
    #[test]
    fn test_ttl_cache_insert_get() {
        let ttl = Duration::seconds(1).whole_microseconds();
        let entry_1 = TtlEntry::new("0x1".to_string(), ttl, 0);

        let mut cache = TtlCache::new();

        cache.insert(entry_1.clone());

        assert!(cache.contains(&entry_1.id));

        let result = cache.get(&entry_1.id);
        assert!(result.is_some());
        assert_eq!(entry_1, *result.unwrap())
    }

    // Entries vanish once their TTL elapses: 250ms entry after ~300ms,
    // 500ms entry after ~600ms.
    // NOTE(review): timing-based with only ~50ms of slack — could flake
    // on a heavily loaded machine; consider wider margins.
    #[test]
    fn test_ttl_cache_expiry() {
        let entry_1 = TtlEntry::new(
            "0x1".to_string(),
            Duration::milliseconds(250).whole_microseconds(),
            "",
        );
        let entry_2 = TtlEntry::new(
            "0x2".to_string(),
            Duration::milliseconds(500).whole_microseconds(),
            "",
        );

        let mut cache = TtlCache::new();

        cache.insert(entry_1.clone());
        cache.insert(entry_2.clone());

        assert!(cache.contains(&entry_1.id));
        assert!(cache.contains(&entry_2.id));

        sleep(StdDuration::from_millis(300));
        assert!(!cache.contains(&entry_1.id));
        assert!(cache.contains(&entry_2.id));

        sleep(StdDuration::from_millis(300));
        assert!(!cache.contains(&entry_1.id));
        assert!(!cache.contains(&entry_2.id));
    }
}