// solana_streamer/nonblocking/connection_rate_limiter.rs
use {
    governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Quota, RateLimiter},
    std::{net::IpAddr, num::NonZeroU32},
};

6pub struct ConnectionRateLimiter {
7 limiter: DefaultKeyedRateLimiter<IpAddr>,
8}
9
10impl ConnectionRateLimiter {
11 pub fn new(limit_per_minute: u64) -> Self {
14 let quota =
15 Quota::per_minute(NonZeroU32::new(u32::try_from(limit_per_minute).unwrap()).unwrap());
16 Self {
17 limiter: DefaultKeyedRateLimiter::keyed(quota),
18 }
19 }
20
21 pub fn is_allowed(&self, ip: &IpAddr) -> bool {
23 if self.limiter.check_key(ip).is_ok() {
25 debug!("Request from IP {:?} allowed", ip);
26 true } else {
28 debug!("Request from IP {:?} blocked", ip);
29 false }
31 }
32
33 pub fn retain_recent(&self) {
36 self.limiter.retain_recent()
37 }
38
39 pub fn len(&self) -> usize {
41 self.limiter.len()
42 }
43
44 pub fn is_empty(&self) -> bool {
46 self.limiter.is_empty()
47 }
48}
49
50pub struct TotalConnectionRateLimiter {
53 limiter: DefaultDirectRateLimiter,
54}
55
56impl TotalConnectionRateLimiter {
57 pub fn new(limit_per_second: u64) -> Self {
59 let quota =
60 Quota::per_second(NonZeroU32::new(u32::try_from(limit_per_second).unwrap()).unwrap());
61 Self {
62 limiter: RateLimiter::direct(quota),
63 }
64 }
65
66 pub fn is_allowed(&self) -> bool {
68 if self.limiter.check().is_ok() {
69 true } else {
71 false }
73 }
74}
75
#[cfg(test)]
pub mod test {
    use {super::*, std::net::Ipv4Addr};

    // The limiters are synchronous, so no async runtime is needed here.

    #[test]
    fn test_total_connection_rate_limiter() {
        let limiter = TotalConnectionRateLimiter::new(2);
        assert!(limiter.is_allowed());
        assert!(limiter.is_allowed());
        // The third connection within the same second exceeds the burst of 2.
        assert!(!limiter.is_allowed());
    }

    #[test]
    #[should_panic]
    fn test_total_connection_rate_limiter_zero_limit_panics() {
        let _ = TotalConnectionRateLimiter::new(0);
    }

    #[test]
    fn test_connection_rate_limiter() {
        let limiter = ConnectionRateLimiter::new(4);
        assert!(limiter.is_empty());

        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        for _ in 0..4 {
            assert!(limiter.is_allowed(&ip1));
        }
        assert!(!limiter.is_allowed(&ip1));
        assert_eq!(limiter.len(), 1);

        // A different IP has its own independent quota.
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        assert!(limiter.is_allowed(&ip2));
        assert_eq!(limiter.len(), 2);
        for _ in 0..3 {
            assert!(limiter.is_allowed(&ip2));
        }
        assert!(!limiter.is_allowed(&ip2));
    }

    #[test]
    #[should_panic]
    fn test_connection_rate_limiter_zero_limit_panics() {
        let _ = ConnectionRateLimiter::new(0);
    }
}