solana_streamer/nonblocking/
connection_rate_limiter.rs

1use {
2    governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Quota, RateLimiter},
3    std::{net::IpAddr, num::NonZeroU32},
4};
5
6pub struct ConnectionRateLimiter {
7    limiter: DefaultKeyedRateLimiter<IpAddr>,
8}
9
10impl ConnectionRateLimiter {
11    /// Create a new rate limiter per IpAddr. The rate is specified as the count per minute to allow for
12    /// less frequent connections.
13    pub fn new(limit_per_minute: u64) -> Self {
14        let quota =
15            Quota::per_minute(NonZeroU32::new(u32::try_from(limit_per_minute).unwrap()).unwrap());
16        Self {
17            limiter: DefaultKeyedRateLimiter::keyed(quota),
18        }
19    }
20
21    /// Check if the connection from the said `ip` is allowed.
22    pub fn is_allowed(&self, ip: &IpAddr) -> bool {
23        // Acquire a permit from the rate limiter for the given IP address
24        if self.limiter.check_key(ip).is_ok() {
25            debug!("Request from IP {:?} allowed", ip);
26            true // Request allowed
27        } else {
28            debug!("Request from IP {:?} blocked", ip);
29            false // Request blocked
30        }
31    }
32
33    /// retain only keys whose rate-limiting start date is within the rate-limiting interval.
34    /// Otherwise drop them as inactive
35    pub fn retain_recent(&self) {
36        self.limiter.retain_recent()
37    }
38
39    /// Returns the number of "live" keys in the rate limiter.
40    pub fn len(&self) -> usize {
41        self.limiter.len()
42    }
43
44    /// Returns `true` if the rate limiter has no keys in it.
45    pub fn is_empty(&self) -> bool {
46        self.limiter.is_empty()
47    }
48}
49
50/// Connection rate limiter for enforcing connection rates from
51/// all clients.
52pub struct TotalConnectionRateLimiter {
53    limiter: DefaultDirectRateLimiter,
54}
55
56impl TotalConnectionRateLimiter {
57    /// Create a new rate limiter. The rate is specified as the count per second.
58    pub fn new(limit_per_second: u64) -> Self {
59        let quota =
60            Quota::per_second(NonZeroU32::new(u32::try_from(limit_per_second).unwrap()).unwrap());
61        Self {
62            limiter: RateLimiter::direct(quota),
63        }
64    }
65
66    /// Check if a connection is allowed.
67    pub fn is_allowed(&self) -> bool {
68        if self.limiter.check().is_ok() {
69            true // Request allowed
70        } else {
71            false // Request blocked
72        }
73    }
74}
75
#[cfg(test)]
pub mod test {
    use {super::*, std::net::Ipv4Addr};

    #[test]
    fn test_total_connection_rate_limiter() {
        let limiter = TotalConnectionRateLimiter::new(2);
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        // A third connection within the same second exceeds the quota of 2.
        assert!(!limiter.is_allowed());
    }

    #[test]
    #[should_panic]
    fn test_total_connection_rate_limiter_zero_limit_panics() {
        let _ = TotalConnectionRateLimiter::new(0);
    }

    #[test]
    fn test_connection_rate_limiter() {
        let limiter = ConnectionRateLimiter::new(4);
        assert!(limiter.is_empty());

        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        for _ in 0..4 {
            assert!(limiter.is_allowed(&ip1));
        }
        assert!(!limiter.is_allowed(&ip1));
        assert_eq!(limiter.len(), 1);

        // A second IP has its own quota, unaffected by ip1 being exhausted.
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        assert!(limiter.is_allowed(&ip2));
        assert_eq!(limiter.len(), 2);
        for _ in 0..3 {
            assert!(limiter.is_allowed(&ip2));
        }
        assert!(!limiter.is_allowed(&ip2));
    }

    #[test]
    #[should_panic]
    fn test_connection_rate_limiter_zero_limit_panics() {
        let _ = ConnectionRateLimiter::new(0);
    }
}