use {
crate::nonblocking::{keyed_rate_limiter::KeyedRateLimiter, rate_limiter::RateLimiter},
std::{net::IpAddr, time::Duration},
};
/// Per-IP connection rate limiter.
///
/// Each remote `IpAddr` gets its own request budget over a 60-second window.
/// The exact windowing/refill policy is implemented by `KeyedRateLimiter`.
pub struct ConnectionRateLimiter {
    /// Underlying limiter, keyed by the requesting IP address.
    limiter: KeyedRateLimiter<IpAddr>,
}

impl ConnectionRateLimiter {
    /// Creates a limiter that gives every IP a budget of `limit_per_minute`
    /// requests per 60-second window.
    pub fn new(limit_per_minute: u64) -> Self {
        let window = Duration::from_secs(60);
        Self {
            limiter: KeyedRateLimiter::new(limit_per_minute, window),
        }
    }

    /// Charges one request against `ip`'s budget and reports whether it was
    /// admitted.
    ///
    /// Logs the outcome at debug level. Returns `false` once `ip` has used up
    /// its budget for the current window.
    pub fn is_allowed(&self, ip: &IpAddr) -> bool {
        let allowed = self.limiter.check_and_update(*ip);
        if allowed {
            debug!("Request from IP {:?} allowed", ip);
        } else {
            debug!("Request from IP {:?} blocked", ip);
        }
        allowed
    }

    /// Delegates to `KeyedRateLimiter::retain_recent`.
    ///
    /// NOTE(review): presumably this prunes entries for IPs that have not been
    /// seen recently; confirm against `KeyedRateLimiter`.
    pub fn retain_recent(&self) {
        self.limiter.retain_recent();
    }

    /// Number of distinct IP addresses the limiter is currently tracking.
    pub fn len(&self) -> usize {
        self.limiter.len()
    }

    /// Returns `true` when no IP addresses are being tracked.
    pub fn is_empty(&self) -> bool {
        self.limiter.is_empty()
    }
}
/// Global (un-keyed) connection rate limiter.
///
/// Caps the total number of admitted connections across all sources to a
/// budget per one-second window. The exact refill policy is implemented by
/// `RateLimiter`.
pub struct TotalConnectionRateLimiter {
    /// Single shared limiter covering every incoming connection.
    limiter: RateLimiter,
}

impl TotalConnectionRateLimiter {
    /// Creates a limiter that admits up to `limit_per_second` connections per
    /// one-second window.
    pub fn new(limit_per_second: u64) -> Self {
        let window = Duration::from_secs(1);
        let limiter = RateLimiter::new(limit_per_second, window);
        Self { limiter }
    }

    /// Charges one connection against the shared budget.
    ///
    /// Returns whether the connection was admitted.
    pub fn is_allowed(&mut self) -> bool {
        self.limiter.check_and_update()
    }
}
#[cfg(test)]
pub mod test {
    use {super::*, std::net::Ipv4Addr};

    #[tokio::test]
    async fn test_total_connection_rate_limiter() {
        // Budget of two per second. The test finishes well inside one window,
        // so the third call must be rejected.
        let mut limiter = TotalConnectionRateLimiter::new(2);
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        assert!(!limiter.is_allowed());
    }

    #[tokio::test]
    async fn test_connection_rate_limiter() {
        let limiter = ConnectionRateLimiter::new(4);
        assert!(limiter.is_empty());

        // ip1 gets four requests in the one-minute window; the fifth is blocked.
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        for _ in 0..4 {
            assert!(limiter.is_allowed(&ip1));
        }
        assert!(!limiter.is_allowed(&ip1));
        assert_eq!(limiter.len(), 1);

        // A second IP is tracked separately and has its own full budget.
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        assert!(limiter.is_allowed(&ip2));
        assert_eq!(limiter.len(), 2);
        for _ in 0..3 {
            assert!(limiter.is_allowed(&ip2));
        }
        assert!(!limiter.is_allowed(&ip2));

        // Spending ip2's budget must not have replenished ip1's.
        assert!(!limiter.is_allowed(&ip1));
        assert_eq!(limiter.len(), 2);
        assert!(!limiter.is_empty());
    }
}