kraken_async_rs/rate_limiting/keyed_rate_limits.rs

1use async_rate_limit::limiters::{RateLimiter, VariableCostRateLimiter};
2use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
3use std::collections::BTreeMap;
4use std::time::Duration;
5
6/// Create a new public rate limiter.
7///
8/// All tiers for Kraken are limited to on the order of one request per second.
9pub fn new_public_rate_limiter() -> SlidingWindowRateLimiter {
10    SlidingWindowRateLimiter::new(Duration::from_secs(1), 1)
11}
12
13/// A rate limiter that utilizes a [BTreeMap] to map K -> [SlidingWindowRateLimiter], allowing for
14/// a per-argument rate limiter.
15///
16/// This is used for several endpoints that are rate limited by IP and trading pair, so each pair
17/// is given a unique rate limiter.
18#[derive(Debug, Clone)]
19pub struct KeyedRateLimiter<K>
20where
21    K: Ord,
22{
23    rate_limiters: BTreeMap<K, SlidingWindowRateLimiter>,
24    default: fn() -> SlidingWindowRateLimiter,
25}
26
27impl<K> Default for KeyedRateLimiter<K>
28where
29    K: Ord,
30{
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<K> KeyedRateLimiter<K>
37where
38    K: Ord,
39{
40    /// Create an empty instance with no rate limiters.
41    pub fn new() -> Self {
42        KeyedRateLimiter {
43            rate_limiters: Default::default(),
44            default: new_public_rate_limiter,
45        }
46    }
47
48    /// Add a rate limiter implementation for a given key, such that `self.wait_until_ready(key)`
49    /// will use this rate limiter.
50    ///
51    /// This can overwrite previous rate limiters if the key already exists and returns/follows the
52    /// same semantics as [BTreeMap]'s insert method.
53    pub fn add_rate_limiter(
54        &mut self,
55        key: K,
56        rate_limiter: SlidingWindowRateLimiter,
57    ) -> Option<SlidingWindowRateLimiter> {
58        self.rate_limiters.insert(key, rate_limiter)
59    }
60
61    /// Remove a rate limiter from the internal map. This may result in subsequent usages of this
62    /// key using a default rate limiter.
63    ///
64    /// This follows the same return semantics as [BTreeMap]'s remove method.
65    pub fn remove_rate_limiter(&mut self, key: &K) -> Option<SlidingWindowRateLimiter> {
66        self.rate_limiters.remove(key)
67    }
68
69    /// Follows the same semantics as [SlidingWindowRateLimiter], except it looks up a rate limiter
70    /// by key, and creates a rate limiter if none is found.
71    pub async fn wait_until_ready(&mut self, key: K) {
72        self.rate_limiters
73            .entry(key)
74            .or_insert((self.default)())
75            .wait_until_ready()
76            .await
77    }
78
79    /// Follows the same semantics as [SlidingWindowRateLimiter], except it looks up a rate limiter
80    /// by key, and creates a rate limiter if none is found.
81    pub async fn wait_with_cost(&mut self, cost: usize, key: K) {
82        self.rate_limiters
83            .entry(key)
84            .or_insert((self.default)())
85            .wait_with_cost(cost)
86            .await
87    }
88}
89
#[cfg(test)]
mod tests {
    use crate::rate_limiting::keyed_rate_limits::KeyedRateLimiter;
    use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
    use std::time::Duration;
    use tokio::time::{pause, Instant};

    #[test]
    fn test_add_remove() {
        let mut limiter = KeyedRateLimiter::new();

        let sub_limiter_1 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 1);
        let sub_limiter_2 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 2);

        let added = limiter.add_rate_limiter("k1", sub_limiter_1.clone());
        assert!(added.is_none());

        let added = limiter.add_rate_limiter("k2", sub_limiter_2);
        assert!(added.is_none());

        assert_eq!(2, limiter.rate_limiters.len());

        // re-adding an existing key overwrites it and returns the previous limiter
        let replaced = limiter.add_rate_limiter("k1", sub_limiter_1);
        assert!(replaced.is_some());
        assert_eq!(2, limiter.rate_limiters.len());

        let removed = limiter.remove_rate_limiter(&"k1");
        assert!(removed.is_some());

        let removed = limiter.remove_rate_limiter(&"k2");
        assert!(removed.is_some());

        // removing a key that is no longer present returns nothing
        let removed = limiter.remove_rate_limiter(&"k1");
        assert!(removed.is_none());

        assert_eq!(0, limiter.rate_limiters.len());
    }

    #[tokio::test]
    async fn test_waiting_separately() {
        pause();

        let mut limiter = KeyedRateLimiter::new();

        let sub_limiter_1 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 1);
        let sub_limiter_2 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 2);

        limiter.add_rate_limiter("k1", sub_limiter_1);
        limiter.add_rate_limiter("k2", sub_limiter_2);

        let start = Instant::now();

        for _ in 0..3 {
            limiter.wait_until_ready("k1").await;
        }

        let mid = Instant::now();

        for _ in 0..2 {
            limiter.wait_with_cost(2, "k2").await;
        }

        let end = Instant::now();

        // three calls to the first rate limiter should wait twice, taking 2s
        let elapsed_start_mid = mid - start;

        // 2 calls to the second rate limiter should wait once for 1s
        let elapsed_mid_end = end - mid;

        assert!(elapsed_start_mid > Duration::from_secs(2));
        assert!(elapsed_start_mid < Duration::from_millis(3300));

        assert!(elapsed_mid_end > Duration::from_secs(1));
        assert!(elapsed_mid_end < Duration::from_millis(2200));
    }

    #[tokio::test]
    async fn test_waiting_separately_default() {
        pause();

        let mut limiter = KeyedRateLimiter::new();

        let sub_limiter_1 = SlidingWindowRateLimiter::new(Duration::from_secs(2), 1);

        limiter.add_rate_limiter("k1", sub_limiter_1);

        let start = Instant::now();

        for _ in 0..3 {
            limiter.wait_until_ready("k1").await;
        }

        let mid = Instant::now();

        for _ in 0..3 {
            limiter.wait_with_cost(1, "k2").await;
        }

        let end = Instant::now();

        // three calls to the first rate limiter should wait twice, taking 4s
        let elapsed_start_mid = mid - start;

        // 3 calls to the second (default-inserted) rate limiter should wait twice for 2s total
        let elapsed_mid_end = end - mid;

        assert!(elapsed_start_mid > Duration::from_secs(4));
        assert!(elapsed_start_mid < Duration::from_millis(4300));

        assert!(elapsed_mid_end > Duration::from_secs(2));
        assert!(elapsed_mid_end < Duration::from_millis(2200));
    }
}