use {
crate::nonblocking::rate_limiter::RateLimiter,
dashmap::DashMap,
std::{hash::Hash, time::Duration},
};
/// A set of independent [`RateLimiter`]s, one per key of type `K`.
///
/// Every per-key limiter is created with the same `limit` and `interval`,
/// which are fixed when the `KeyedRateLimiter` is constructed.
pub struct KeyedRateLimiter<K> {
    /// Per-key limiter state; an entry is added the first time a key is checked.
    limiters: DashMap<K, RateLimiter>,
    /// `limit` argument handed to `RateLimiter::new` for each new key.
    limit: u64,
    /// `interval` argument handed to `RateLimiter::new` for each new key; also
    /// the age threshold used when pruning entries.
    interval: Duration,
}
impl<K> KeyedRateLimiter<K>
where
    K: Eq + Hash,
{
    /// Creates an empty keyed rate limiter.
    ///
    /// `limit` and `interval` are stored and later passed unchanged to
    /// `RateLimiter::new` whenever a previously unseen key is checked.
    pub fn new(limit: u64, interval: Duration) -> Self {
        let limiters = DashMap::default();
        Self {
            limiters,
            interval,
            limit,
        }
    }

    /// Runs the rate-limit check for `key`, creating its limiter on first use.
    ///
    /// Returns the result of the key's `RateLimiter::check_and_update`:
    /// `true` when the call is allowed, `false` when it is throttled.
    pub fn check_and_update(&self, key: K) -> bool {
        // The map entry stays borrowed for the duration of the check, so the
        // lookup/insert and the limiter update happen under a single access.
        let mut slot = self
            .limiters
            .entry(key)
            .or_insert_with(|| RateLimiter::new(self.limit, self.interval));
        slot.check_and_update()
    }

    /// Drops every entry whose limiter reports a `throttle_start_instant()`
    /// more than `interval` before the current time.
    ///
    /// Entries within `interval` (inclusive) of now are kept.
    pub fn retain_recent(&self) {
        let now = tokio::time::Instant::now();
        self.limiters.retain(|_, limiter| {
            let since_window_start = now.duration_since(*limiter.throttle_start_instant());
            since_window_start <= self.interval
        });
    }

    /// Returns the number of keys that currently have a limiter entry.
    pub fn len(&self) -> usize {
        self.limiters.len()
    }

    /// Returns `true` if no key currently has a limiter entry.
    pub fn is_empty(&self) -> bool {
        self.limiters.is_empty()
    }
}
#[cfg(test)]
pub mod test {
    use {super::*, tokio::time::sleep};

    /// Asserts that `key` is allowed exactly twice and then denied, matching
    /// the limit of 2 configured by the test below.
    #[track_caller]
    fn assert_two_allowed_then_denied(limiter: &KeyedRateLimiter<u64>, key: u64) {
        assert!(limiter.check_and_update(key));
        assert!(limiter.check_and_update(key));
        assert!(!limiter.check_and_update(key));
    }

    // `len()` is compared against zero on purpose: both `len` and `is_empty`
    // are exercised here.
    #[allow(clippy::len_zero)]
    #[tokio::test]
    async fn test_rate_limiter() {
        let window = Duration::from_millis(100);
        let pause = Duration::from_millis(150);
        let limiter = KeyedRateLimiter::<u64>::new(2, window);

        // A fresh limiter tracks no keys.
        assert_eq!(limiter.len(), 0);
        assert!(limiter.is_empty());

        // Each key gets its own budget, and its entry appears on first use.
        assert_two_allowed_then_denied(&limiter, 1);
        assert_eq!(limiter.len(), 1);
        assert_two_allowed_then_denied(&limiter, 2);
        assert_eq!(limiter.len(), 2);

        // Waiting longer than the interval does not remove entries by itself,
        // and both keys are allowed again afterwards.
        sleep(pause).await;
        assert_eq!(limiter.len(), 2);
        assert_two_allowed_then_denied(&limiter, 1);
        assert_two_allowed_then_denied(&limiter, 2);
        assert_eq!(limiter.len(), 2);

        // After another wait only key 1 is used again; pruning is then
        // expected to leave a single entry.
        sleep(pause).await;
        assert_two_allowed_then_denied(&limiter, 1);
        limiter.retain_recent();
        assert_eq!(limiter.len(), 1);
    }
}