kraken_async_rs/rate_limiting/
ttl_cache.rs

1use std::cmp::Ordering;
2use std::collections::{BTreeMap, BTreeSet};
3use time::OffsetDateTime;
4
5#[inline]
6fn now() -> i128 {
7    OffsetDateTime::now_utc().unix_timestamp_nanos() / 1000
8}
9
/// A time-to-live entry that should remain available until its expiry time (`ttl`) passes.
///
/// These hold an order id (`id`) together with that order's creation time (`data`), but the
/// type was left generic so it can carry other keys and payloads later.
#[derive(Debug, Clone, Copy)]
pub struct TtlEntry<K: Ord + Clone, T: Clone> {
    /// Key the entry is looked up by.
    pub id: K,
    /// Absolute expiry time; `TtlEntry::new` sets it to the current time in microseconds plus
    /// the requested time-to-live.
    ttl: i128,
    /// Payload carried alongside the id.
    pub data: T,
}
24
25impl<K, T> TtlEntry<K, T>
26where
27    K: Ord + Clone,
28    T: Clone,
29{
30    pub fn new(id: K, ttl_us: i128, data: T) -> TtlEntry<K, T> {
31        TtlEntry {
32            id,
33            ttl: now() + ttl_us,
34            data,
35        }
36    }
37}
38
39impl<K, T> Eq for TtlEntry<K, T>
40where
41    K: Ord + Clone,
42    T: Clone,
43{
44}
45
46impl<K, T> PartialEq<Self> for TtlEntry<K, T>
47where
48    K: Ord + Clone,
49    T: Clone,
50{
51    fn eq(&self, other: &Self) -> bool {
52        self.id == other.id && self.ttl == other.ttl
53    }
54}
55
56impl<K, T> PartialOrd<Self> for TtlEntry<K, T>
57where
58    K: Ord + Clone,
59    T: Clone,
60{
61    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
62        Some(self.ttl.cmp(&other.ttl))
63    }
64}
65
66impl<K, T> Ord for TtlEntry<K, T>
67where
68    K: Ord + Clone,
69    T: Clone,
70{
71    fn cmp(&self, other: &Self) -> Ordering {
72        self.ttl.cmp(&other.ttl)
73    }
74}
75
76/// A time-to-live cache that removes values when they expire. This is used to store and look up
77/// orders, to determine how old they are for rate limiting penalties when editing or cancelling.
78#[derive(Debug)]
79pub struct TtlCache<K, T>
80where
81    K: Ord + Clone,
82    T: Clone,
83{
84    ids: BTreeMap<K, TtlEntry<K, T>>,
85    ttls: BTreeSet<TtlEntry<K, T>>,
86}
87
88impl<K, T> Default for TtlCache<K, T>
89where
90    K: Ord + Clone,
91    T: Clone,
92{
93    fn default() -> Self {
94        TtlCache::new()
95    }
96}
97
98impl<K, T> TtlCache<K, T>
99where
100    K: Ord + Clone,
101    T: Clone,
102{
103    /// Create a new, empty cache.
104    pub fn new() -> TtlCache<K, T> {
105        TtlCache {
106            ids: Default::default(),
107            ttls: Default::default(),
108        }
109    }
110
111    /// Insert the provided [TtlEntry] by it's id for future lookup. Entries beyond their ttl are
112    /// removed automatically any time the `remove`, `get`, or `contains` methods are called.
113    pub fn insert(&mut self, ttl_entry: TtlEntry<K, T>) -> Option<TtlEntry<K, T>> {
114        self.ttls.insert(ttl_entry.clone());
115        self.ids.insert(ttl_entry.id.clone(), ttl_entry)
116    }
117
118    /// Removes an entry manually, returning if the entry was removed.
119    ///
120    /// The cache is cleaned of any expired values after checking if this value was removed.
121    ///
122    /// This follows the same semantics as [BTreeSet]'s `remove` method.
123    pub fn remove(&mut self, ttl_entry: &TtlEntry<K, T>) -> bool {
124        self.ids.remove(&ttl_entry.id);
125        let removed = self.ttls.remove(ttl_entry);
126        self.remove_expired_values();
127
128        removed
129    }
130
131    /// Returns if the provided key is in the cache, after removing any expired values.
132    pub fn contains(&mut self, id: &K) -> bool {
133        self.remove_expired_values();
134        self.ids.contains_key(id)
135    }
136
137    /// Gets a [TtlEntry] by id after removing any expired values.
138    pub fn get(&mut self, id: &K) -> Option<&TtlEntry<K, T>> {
139        self.remove_expired_values();
140        self.ids.get(id)
141    }
142
143    fn remove_expired_values(&mut self) {
144        let now = now();
145        let mut to_remove = Vec::new();
146
147        for entry in &self.ttls {
148            if entry.ttl < now {
149                to_remove.push(entry.clone());
150            }
151        }
152
153        for entry in to_remove {
154            self.ids.remove(&entry.id);
155            self.ttls.remove(&entry);
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use crate::rate_limiting::ttl_cache::{TtlCache, TtlEntry};
    use std::cmp::Ordering::{Equal, Greater, Less};
    use std::thread::sleep;
    use std::time::Duration as StdDuration;
    use time::Duration;

    #[test]
    fn test_ttl_entry_eq_partial_cmp() {
        let entry_0 = TtlEntry {
            id: "0x1",
            ttl: 0,
            data: 0,
        };
        let entry_1 = TtlEntry {
            id: "0x1",
            ttl: 1,
            data: 0,
        };
        let entry_2 = TtlEntry {
            id: "0x1",
            ttl: 1,
            data: 0,
        };

        assert_ne!(entry_0, entry_1);
        assert_ne!(entry_0, entry_2);
        assert_eq!(entry_1, entry_2);

        assert_eq!(Less, entry_0.partial_cmp(&entry_1).unwrap());
        assert_eq!(Less, entry_0.partial_cmp(&entry_2).unwrap());
        assert_eq!(Equal, entry_1.partial_cmp(&entry_2).unwrap());
    }

    #[test]
    fn test_ttl_entry_ord() {
        let entry_0 = TtlEntry {
            id: "0x1",
            ttl: 0,
            data: 0,
        };
        let entry_1 = TtlEntry {
            id: "0x1",
            ttl: 1,
            data: 0,
        };
        let entry_2 = TtlEntry {
            id: "0x2",
            ttl: 2,
            data: 0,
        };

        assert_eq!(Less, entry_0.cmp(&entry_1));
        assert_eq!(Less, entry_0.cmp(&entry_2));
        assert_eq!(Less, entry_1.cmp(&entry_2));

        // Mirror of the `Less` assertions above: each pair is checked in the reverse direction.
        assert_eq!(Greater, entry_1.cmp(&entry_0));
        assert_eq!(Greater, entry_2.cmp(&entry_0));
        assert_eq!(Greater, entry_2.cmp(&entry_1));

        assert_eq!(Equal, entry_0.cmp(&entry_0));
        assert_eq!(Equal, entry_1.cmp(&entry_1));
        assert_eq!(Equal, entry_2.cmp(&entry_2));
    }

    #[test]
    fn test_ttl_cache_insert_remove() {
        let ttl = Duration::seconds(1).whole_microseconds();
        let entry_1 = TtlEntry::new("0x1".to_string(), ttl, 0);
        let entry_2 = TtlEntry::new("0x2".to_string(), ttl, 0);

        let mut cache = TtlCache::new();

        cache.insert(entry_1.clone());

        assert!(cache.contains(&entry_1.id));
        assert!(!cache.contains(&entry_2.id));

        assert!(cache.remove(&entry_1));

        assert!(!cache.contains(&entry_1.id));
        assert!(!cache.contains(&entry_2.id));
    }

    #[test]
    fn test_ttl_cache_insert_get() {
        let ttl = Duration::seconds(1).whole_microseconds();
        let entry_1 = TtlEntry::new("0x1".to_string(), ttl, 0);

        let mut cache = TtlCache::new();

        cache.insert(entry_1.clone());

        assert!(cache.contains(&entry_1.id));

        let result = cache.get(&entry_1.id);
        assert!(result.is_some());
        assert_eq!(entry_1, *result.unwrap())
    }

    #[test]
    fn test_ttl_cache_expiry() {
        let entry_1 = TtlEntry::new(
            "0x1".to_string(),
            Duration::milliseconds(250).whole_microseconds(),
            "",
        );
        let entry_2 = TtlEntry::new(
            "0x2".to_string(),
            Duration::milliseconds(500).whole_microseconds(),
            "",
        );

        let mut cache = TtlCache::new();

        cache.insert(entry_1.clone());
        cache.insert(entry_2.clone());

        assert!(cache.contains(&entry_1.id));
        assert!(cache.contains(&entry_2.id));

        // let first entry expire
        sleep(StdDuration::from_millis(300));
        assert!(!cache.contains(&entry_1.id));
        assert!(cache.contains(&entry_2.id));

        // let second entry expire
        sleep(StdDuration::from_millis(300));
        assert!(!cache.contains(&entry_1.id));
        assert!(!cache.contains(&entry_2.id));
    }
}