kraken_async_rs/rate_limiting/
keyed_rate_limits.rs1use async_rate_limit::limiters::{RateLimiter, VariableCostRateLimiter};
2use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
3use std::collections::BTreeMap;
4use std::time::Duration;
5
6pub fn new_public_rate_limiter() -> SlidingWindowRateLimiter {
10 SlidingWindowRateLimiter::new(Duration::from_secs(1), 1)
11}
12
13#[derive(Debug, Clone)]
19pub struct KeyedRateLimiter<K>
20where
21 K: Ord,
22{
23 rate_limiters: BTreeMap<K, SlidingWindowRateLimiter>,
24 default: fn() -> SlidingWindowRateLimiter,
25}
26
27impl<K> Default for KeyedRateLimiter<K>
28where
29 K: Ord,
30{
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl<K> KeyedRateLimiter<K>
37where
38 K: Ord,
39{
40 pub fn new() -> Self {
42 KeyedRateLimiter {
43 rate_limiters: Default::default(),
44 default: new_public_rate_limiter,
45 }
46 }
47
48 pub fn add_rate_limiter(
54 &mut self,
55 key: K,
56 rate_limiter: SlidingWindowRateLimiter,
57 ) -> Option<SlidingWindowRateLimiter> {
58 self.rate_limiters.insert(key, rate_limiter)
59 }
60
61 pub fn remove_rate_limiter(&mut self, key: &K) -> Option<SlidingWindowRateLimiter> {
66 self.rate_limiters.remove(key)
67 }
68
69 pub async fn wait_until_ready(&mut self, key: K) {
72 self.rate_limiters
73 .entry(key)
74 .or_insert((self.default)())
75 .wait_until_ready()
76 .await
77 }
78
79 pub async fn wait_with_cost(&mut self, cost: usize, key: K) {
82 self.rate_limiters
83 .entry(key)
84 .or_insert((self.default)())
85 .wait_with_cost(cost)
86 .await
87 }
88}
89
#[cfg(test)]
mod tests {
    use crate::rate_limiting::keyed_rate_limits::KeyedRateLimiter;
    use async_rate_limit::sliding_window::SlidingWindowRateLimiter;
    use std::time::Duration;
    use tokio::time::{pause, Instant};

    /// Adding a new key returns `None`. Removing a registered key returns its limiter, and
    /// removing it again returns `None`.
    #[test]
    fn test_add_remove() {
        let mut limiter = KeyedRateLimiter::new();

        let sub_limiter_1 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 1);
        let sub_limiter_2 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 2);

        let added = limiter.add_rate_limiter("k1", sub_limiter_1);
        assert!(added.is_none());

        let added = limiter.add_rate_limiter("k2", sub_limiter_2);
        assert!(added.is_none());

        assert_eq!(2, limiter.rate_limiters.len());

        let removed = limiter.remove_rate_limiter(&"k1");
        assert!(removed.is_some());

        let removed = limiter.remove_rate_limiter(&"k2");
        assert!(removed.is_some());

        assert_eq!(0, limiter.rate_limiters.len());

        // Removing a key that is no longer registered yields nothing.
        assert!(limiter.remove_rate_limiter(&"k1").is_none());
    }

    /// Each registered key is throttled only by its own limiter.
    #[tokio::test]
    async fn test_waiting_separately() {
        pause();

        let mut limiter = KeyedRateLimiter::new();

        let sub_limiter_1 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 1);
        let sub_limiter_2 = SlidingWindowRateLimiter::new(Duration::from_secs(1), 2);

        limiter.add_rate_limiter("k1", sub_limiter_1);
        limiter.add_rate_limiter("k2", sub_limiter_2);

        let start = Instant::now();

        // k1 admits 1 request per second: the first wait is immediate, the next two take ~2s.
        for _ in 0..3 {
            limiter.wait_until_ready("k1").await;
        }

        let mid = Instant::now();

        // k2 admits 2 units per second: the first cost-2 wait is immediate, the second takes ~1s.
        for _ in 0..2 {
            limiter.wait_with_cost(2, "k2").await;
        }

        let end = Instant::now();

        let elapsed_start_mid = mid - start;
        let elapsed_mid_end = end - mid;

        assert!(elapsed_start_mid > Duration::from_secs(2));
        assert!(elapsed_start_mid < Duration::from_millis(3300));

        assert!(elapsed_mid_end > Duration::from_secs(1));
        assert!(elapsed_mid_end < Duration::from_millis(2200));
    }

    /// An unregistered key falls back to the default limiter (1 request per second).
    #[tokio::test]
    async fn test_waiting_separately_default() {
        pause();

        let mut limiter = KeyedRateLimiter::new();

        let sub_limiter_1 = SlidingWindowRateLimiter::new(Duration::from_secs(2), 1);

        limiter.add_rate_limiter("k1", sub_limiter_1);

        let start = Instant::now();

        // k1 admits 1 request per 2 seconds: three waits take ~4s.
        for _ in 0..3 {
            limiter.wait_until_ready("k1").await;
        }

        let mid = Instant::now();

        // k2 is unregistered, so it gets the default 1-per-second limiter: three waits take ~2s.
        for _ in 0..3 {
            limiter.wait_with_cost(1, "k2").await;
        }

        let end = Instant::now();

        let elapsed_start_mid = mid - start;
        let elapsed_mid_end = end - mid;

        assert!(elapsed_start_mid > Duration::from_secs(4));
        assert!(elapsed_start_mid < Duration::from_millis(4300));

        assert!(elapsed_mid_end > Duration::from_secs(2));
        assert!(elapsed_mid_end < Duration::from_millis(2200));
    }
}
197}