1use {
2 crate::{
3 client_connection::ClientConnection as BlockingClientConnection,
4 connection_cache_stats::{ConnectionCacheStats, CONNECTION_STAT_SUBMISSION_INTERVAL},
5 nonblocking::client_connection::ClientConnection as NonblockingClientConnection,
6 },
7 crossbeam_channel::{Receiver, RecvError, Sender},
8 indexmap::map::IndexMap,
9 log::*,
10 rand::{thread_rng, Rng},
11 solana_measure::measure::Measure,
12 solana_sdk::{signature::Keypair, timing::AtomicInterval},
13 std::{
14 net::SocketAddr,
15 sync::{atomic::Ordering, Arc, RwLock},
16 thread::{Builder, JoinHandle},
17 },
18 thiserror::Error,
19};
20
/// Upper bound on the number of remote addresses held by a `ConnectionCache`.
/// Adding a not-yet-cached address while the cache is at this size evicts a
/// randomly chosen entry first.
const MAX_CONNECTIONS: usize = 1 << 10;

/// Number of connections per remote address used when callers do not pick one.
pub const DEFAULT_CONNECTION_POOL_SIZE: usize = 2;
26
/// Transport protocol served by a `ConnectionManager` (see its `PROTOCOL` constant).
///
/// Derives `Debug` so the value can be logged and compared in `assert_eq!`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
    /// Plain UDP.
    UDP,
    /// QUIC.
    QUIC,
}
32
33pub trait ConnectionManager: Send + Sync + 'static {
34 type ConnectionPool: ConnectionPool;
35 type NewConnectionConfig: NewConnectionConfig;
36
37 const PROTOCOL: Protocol;
38
39 fn new_connection_pool(&self) -> Self::ConnectionPool;
40 fn new_connection_config(&self) -> Self::NewConnectionConfig;
41 fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>>;
42}
43
44pub struct ConnectionCache<
45 R, S, T, > {
49 name: &'static str,
50 map: Arc<RwLock<IndexMap<SocketAddr, R>>>,
51 connection_manager: Arc<S>,
52 stats: Arc<ConnectionCacheStats>,
53 last_stats: AtomicInterval,
54 connection_pool_size: usize,
55 connection_config: Arc<T>,
56 sender: Sender<(usize, SocketAddr)>,
57}
58
59impl<P, M, C> ConnectionCache<P, M, C>
60where
61 P: ConnectionPool<NewConnectionConfig = C>,
62 M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
63 C: NewConnectionConfig,
64{
65 pub fn new(
66 name: &'static str,
67 connection_manager: M,
68 connection_pool_size: usize,
69 ) -> Result<Self, ClientError> {
70 let config = connection_manager.new_connection_config();
71 Ok(Self::new_with_config(
72 name,
73 connection_pool_size,
74 config,
75 connection_manager,
76 ))
77 }
78
79 pub fn new_with_config(
80 name: &'static str,
81 connection_pool_size: usize,
82 connection_config: C,
83 connection_manager: M,
84 ) -> Self {
85 info!("Creating ConnectionCache {name}, pool size: {connection_pool_size}");
86 let (sender, receiver) = crossbeam_channel::unbounded();
87
88 let map = Arc::new(RwLock::new(IndexMap::with_capacity(MAX_CONNECTIONS)));
89 let config = Arc::new(connection_config);
90 let connection_manager = Arc::new(connection_manager);
91 let connection_pool_size = 1.max(connection_pool_size); let stats = Arc::new(ConnectionCacheStats::default());
94
95 let _async_connection_thread =
96 Self::create_connection_async_thread(map.clone(), receiver, stats.clone());
97 Self {
98 name,
99 map,
100 stats,
101 connection_manager,
102 last_stats: AtomicInterval::default(),
103 connection_pool_size,
104 connection_config: config,
105 sender,
106 }
107 }
108
109 fn create_connection_async_thread(
111 map: Arc<RwLock<IndexMap<SocketAddr, P>>>,
112 receiver: Receiver<(usize, SocketAddr)>,
113 stats: Arc<ConnectionCacheStats>,
114 ) -> JoinHandle<()> {
115 Builder::new()
116 .name("solQAsynCon".to_string())
117 .spawn(move || loop {
118 let recv_result = receiver.recv();
119 match recv_result {
120 Err(RecvError) => {
121 break;
122 }
123 Ok((idx, addr)) => {
124 let map = map.read().unwrap();
125 let pool = map.get(&addr);
126 if let Some(pool) = pool {
127 let conn = pool.get(idx);
128 if let Ok(conn) = conn {
129 drop(map);
130 let conn = conn.new_blocking_connection(addr, stats.clone());
131 let result = conn.send_data(&[]);
132 debug!("Create async connection result {result:?} for {addr}");
133 }
134 }
135 }
136 }
137 })
138 .unwrap()
139 }
140
141 pub fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
142 let mut map = self.map.write().unwrap();
143 map.clear();
144 self.connection_manager.update_key(key)
145 }
146 fn create_connection(
150 &self,
151 lock_timing_ms: &mut u64,
152 addr: &SocketAddr,
153 ) -> CreateConnectionResult<<P as ConnectionPool>::BaseClientConnection> {
154 let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
155 let mut map = self.map.write().unwrap();
156 get_connection_map_lock_measure.stop();
157 *lock_timing_ms = lock_timing_ms.saturating_add(get_connection_map_lock_measure.as_ms());
158 let pool_status = map
162 .get(addr)
163 .map(|pool| pool.check_pool_status(self.connection_pool_size))
164 .unwrap_or(PoolStatus::Empty);
165
166 let (cache_hit, num_evictions, eviction_timing_ms) =
167 if matches!(pool_status, PoolStatus::Empty) {
168 Self::create_connection_internal(
169 &self.connection_config,
170 &self.connection_manager,
171 &mut map,
172 addr,
173 self.connection_pool_size,
174 None,
175 )
176 } else {
177 (true, 0, 0)
178 };
179
180 if matches!(pool_status, PoolStatus::PartiallyFull) {
181 debug!("Triggering async connection for {addr:?}");
183 Self::create_connection_internal(
184 &self.connection_config,
185 &self.connection_manager,
186 &mut map,
187 addr,
188 self.connection_pool_size,
189 Some(&self.sender),
190 );
191 }
192
193 let pool = map.get(addr).unwrap();
194 let connection = pool.borrow_connection();
195
196 CreateConnectionResult {
197 connection,
198 cache_hit,
199 connection_cache_stats: self.stats.clone(),
200 num_evictions,
201 eviction_timing_ms,
202 }
203 }
204
205 fn create_connection_internal(
206 config: &C,
207 connection_manager: &M,
208 map: &mut std::sync::RwLockWriteGuard<'_, IndexMap<SocketAddr, P>>,
209 addr: &SocketAddr,
210 connection_pool_size: usize,
211 async_connection_sender: Option<&Sender<(usize, SocketAddr)>>,
212 ) -> (bool, u64, u64) {
213 let mut num_evictions = 0;
215 let mut get_connection_cache_eviction_measure =
216 Measure::start("get_connection_cache_eviction_measure");
217 let existing_index = map.get_index_of(addr);
218 while map.len() >= MAX_CONNECTIONS {
219 let mut rng = thread_rng();
220 let n = rng.gen_range(0..MAX_CONNECTIONS);
221 if let Some(index) = existing_index {
222 if n == index {
223 continue;
224 }
225 }
226 map.swap_remove_index(n);
227 num_evictions += 1;
228 }
229 get_connection_cache_eviction_measure.stop();
230
231 let mut hit_cache = false;
232 map.entry(*addr)
233 .and_modify(|pool| {
234 if matches!(
235 pool.check_pool_status(connection_pool_size),
236 PoolStatus::PartiallyFull
237 ) {
238 let idx = pool.add_connection(config, addr);
239 if let Some(sender) = async_connection_sender {
240 debug!(
241 "Sending async connection creation {} for {addr}",
242 pool.num_connections() - 1
243 );
244 sender.send((idx, *addr)).unwrap();
245 };
246 } else {
247 hit_cache = true;
248 }
249 })
250 .or_insert_with(|| {
251 let mut pool = connection_manager.new_connection_pool();
252 pool.add_connection(config, addr);
253 pool
254 });
255 (
256 hit_cache,
257 num_evictions,
258 get_connection_cache_eviction_measure.as_ms(),
259 )
260 }
261
262 fn get_or_add_connection(
263 &self,
264 addr: &SocketAddr,
265 ) -> GetConnectionResult<<P as ConnectionPool>::BaseClientConnection> {
266 let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
267 let map = self.map.read().unwrap();
268 get_connection_map_lock_measure.stop();
269
270 let mut lock_timing_ms = get_connection_map_lock_measure.as_ms();
271
272 let report_stats = self
273 .last_stats
274 .should_update(CONNECTION_STAT_SUBMISSION_INTERVAL);
275
276 let mut get_connection_map_measure = Measure::start("get_connection_hit_measure");
277 let CreateConnectionResult {
278 connection,
279 cache_hit,
280 connection_cache_stats,
281 num_evictions,
282 eviction_timing_ms,
283 } = match map.get(addr) {
284 Some(pool) => {
285 let pool_status = pool.check_pool_status(self.connection_pool_size);
286 match pool_status {
287 PoolStatus::Empty => {
288 drop(map);
290 self.create_connection(&mut lock_timing_ms, addr)
291 }
292 PoolStatus::PartiallyFull | PoolStatus::Full => {
293 let connection = pool.borrow_connection();
294 if matches!(pool_status, PoolStatus::PartiallyFull) {
295 debug!("Creating connection async for {addr}");
296 drop(map);
297 let mut map = self.map.write().unwrap();
298 Self::create_connection_internal(
299 &self.connection_config,
300 &self.connection_manager,
301 &mut map,
302 addr,
303 self.connection_pool_size,
304 Some(&self.sender),
305 );
306 }
307 CreateConnectionResult {
308 connection,
309 cache_hit: true,
310 connection_cache_stats: self.stats.clone(),
311 num_evictions: 0,
312 eviction_timing_ms: 0,
313 }
314 }
315 }
316 }
317 None => {
318 drop(map);
320 self.create_connection(&mut lock_timing_ms, addr)
321 }
322 };
323 get_connection_map_measure.stop();
324
325 GetConnectionResult {
326 connection,
327 cache_hit,
328 report_stats,
329 map_timing_ms: get_connection_map_measure.as_ms(),
330 lock_timing_ms,
331 connection_cache_stats,
332 num_evictions,
333 eviction_timing_ms,
334 }
335 }
336
337 fn get_connection_and_log_stats(
338 &self,
339 addr: &SocketAddr,
340 ) -> (
341 Arc<<P as ConnectionPool>::BaseClientConnection>,
342 Arc<ConnectionCacheStats>,
343 ) {
344 let mut get_connection_measure = Measure::start("get_connection_measure");
345 let GetConnectionResult {
346 connection,
347 cache_hit,
348 report_stats,
349 map_timing_ms,
350 lock_timing_ms,
351 connection_cache_stats,
352 num_evictions,
353 eviction_timing_ms,
354 } = self.get_or_add_connection(addr);
355
356 if report_stats {
357 connection_cache_stats.report(self.name);
358 }
359
360 if cache_hit {
361 connection_cache_stats
362 .cache_hits
363 .fetch_add(1, Ordering::Relaxed);
364 connection_cache_stats
365 .get_connection_hit_ms
366 .fetch_add(map_timing_ms, Ordering::Relaxed);
367 } else {
368 connection_cache_stats
369 .cache_misses
370 .fetch_add(1, Ordering::Relaxed);
371 connection_cache_stats
372 .get_connection_miss_ms
373 .fetch_add(map_timing_ms, Ordering::Relaxed);
374 connection_cache_stats
375 .cache_evictions
376 .fetch_add(num_evictions, Ordering::Relaxed);
377 connection_cache_stats
378 .eviction_time_ms
379 .fetch_add(eviction_timing_ms, Ordering::Relaxed);
380 }
381
382 get_connection_measure.stop();
383 connection_cache_stats
384 .get_connection_lock_ms
385 .fetch_add(lock_timing_ms, Ordering::Relaxed);
386 connection_cache_stats
387 .get_connection_ms
388 .fetch_add(get_connection_measure.as_ms(), Ordering::Relaxed);
389
390 (connection, connection_cache_stats)
391 }
392
393 pub fn get_connection(&self, addr: &SocketAddr) -> Arc<<<P as ConnectionPool>::BaseClientConnection as BaseClientConnection>::BlockingClientConnection>{
394 let (connection, connection_cache_stats) = self.get_connection_and_log_stats(addr);
395 connection.new_blocking_connection(*addr, connection_cache_stats)
396 }
397
398 pub fn get_nonblocking_connection(
399 &self,
400 addr: &SocketAddr,
401 ) -> Arc<<<P as ConnectionPool>::BaseClientConnection as BaseClientConnection>::NonblockingClientConnection>{
402 let (connection, connection_cache_stats) = self.get_connection_and_log_stats(addr);
403 connection.new_nonblocking_connection(*addr, connection_cache_stats)
404 }
405}
406
407#[derive(Error, Debug)]
408pub enum ConnectionPoolError {
409 #[error("connection index is out of range of the pool")]
410 IndexOutOfRange,
411}
412
413#[derive(Error, Debug)]
414pub enum ClientError {
415 #[error("IO error: {0:?}")]
416 IoError(#[from] std::io::Error),
417}
418
419pub trait NewConnectionConfig: Sized + Send + Sync + 'static {
420 fn new() -> Result<Self, ClientError>;
421}
422
/// How full a [`ConnectionPool`] is relative to the requested pool size.
///
/// Derives the standard value traits so callers can compare with `==`, copy it
/// freely and log it, instead of being limited to `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PoolStatus {
    /// The pool holds no connections.
    Empty,
    /// The pool holds at least one connection but fewer than requested.
    PartiallyFull,
    /// The pool holds at least the requested number of connections.
    Full,
}
428
429pub trait ConnectionPool: Send + Sync + 'static {
430 type NewConnectionConfig: NewConnectionConfig;
431 type BaseClientConnection: BaseClientConnection;
432
433 fn add_connection(&mut self, config: &Self::NewConnectionConfig, addr: &SocketAddr) -> usize;
435
436 fn num_connections(&self) -> usize;
438
439 fn get(&self, index: usize) -> Result<Arc<Self::BaseClientConnection>, ConnectionPoolError>;
441
442 fn borrow_connection(&self) -> Arc<Self::BaseClientConnection> {
445 let mut rng = thread_rng();
446 let n = rng.gen_range(0..self.num_connections());
447 self.get(n).expect("index is within num_connections")
448 }
449
450 fn check_pool_status(&self, required_pool_size: usize) -> PoolStatus {
453 if self.num_connections() == 0 {
454 PoolStatus::Empty
455 } else if self.num_connections() < required_pool_size {
456 PoolStatus::PartiallyFull
457 } else {
458 PoolStatus::Full
459 }
460 }
461
462 fn create_pool_entry(
463 &self,
464 config: &Self::NewConnectionConfig,
465 addr: &SocketAddr,
466 ) -> Arc<Self::BaseClientConnection>;
467}
468
469pub trait BaseClientConnection {
470 type BlockingClientConnection: BlockingClientConnection;
471 type NonblockingClientConnection: NonblockingClientConnection;
472
473 fn new_blocking_connection(
474 &self,
475 addr: SocketAddr,
476 stats: Arc<ConnectionCacheStats>,
477 ) -> Arc<Self::BlockingClientConnection>;
478
479 fn new_nonblocking_connection(
480 &self,
481 addr: SocketAddr,
482 stats: Arc<ConnectionCacheStats>,
483 ) -> Arc<Self::NonblockingClientConnection>;
484}
485
486struct GetConnectionResult<T> {
487 connection: Arc<T>,
488 cache_hit: bool,
489 report_stats: bool,
490 map_timing_ms: u64,
491 lock_timing_ms: u64,
492 connection_cache_stats: Arc<ConnectionCacheStats>,
493 num_evictions: u64,
494 eviction_timing_ms: u64,
495}
496
497struct CreateConnectionResult<T> {
498 connection: Arc<T>,
499 cache_hit: bool,
500 connection_cache_stats: Arc<ConnectionCacheStats>,
501 num_evictions: u64,
502 eviction_timing_ms: u64,
503}
504
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{
            client_connection::ClientConnection as BlockingClientConnection,
            nonblocking::client_connection::ClientConnection as NonblockingClientConnection,
        },
        async_trait::async_trait,
        rand::{Rng, SeedableRng},
        rand_chacha::ChaChaRng,
        solana_sdk::transport::Result as TransportResult,
        std::{
            net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
            sync::Arc,
        },
    };

    /// Vec-backed pool; every connection shares the config's UDP socket.
    struct MockUdpPool {
        connections: Vec<Arc<MockUdp>>,
    }
    impl ConnectionPool for MockUdpPool {
        type NewConnectionConfig = MockUdpConfig;
        type BaseClientConnection = MockUdp;

        fn add_connection(
            &mut self,
            config: &Self::NewConnectionConfig,
            addr: &SocketAddr,
        ) -> usize {
            let connection = self.create_pool_entry(config, addr);
            let idx = self.connections.len();
            self.connections.push(connection);
            idx
        }

        fn num_connections(&self) -> usize {
            self.connections.len()
        }

        fn get(
            &self,
            index: usize,
        ) -> Result<Arc<Self::BaseClientConnection>, ConnectionPoolError> {
            self.connections
                .get(index)
                .cloned()
                .ok_or(ConnectionPoolError::IndexOutOfRange)
        }

        fn create_pool_entry(
            &self,
            config: &Self::NewConnectionConfig,
            _addr: &SocketAddr,
        ) -> Arc<Self::BaseClientConnection> {
            Arc::new(MockUdp(config.udp_socket.clone()))
        }
    }

    /// Config holding a single UDP socket bound to an ephemeral port.
    struct MockUdpConfig {
        udp_socket: Arc<UdpSocket>,
    }

    impl Default for MockUdpConfig {
        fn default() -> Self {
            // Reuse the fallible constructor instead of duplicating the bind logic.
            Self::new().expect("Unable to bind to UDP socket")
        }
    }

    impl NewConnectionConfig for MockUdpConfig {
        fn new() -> Result<Self, ClientError> {
            Ok(Self {
                udp_socket: Arc::new(
                    solana_net_utils::bind_with_any_port(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
                        .map_err(Into::<ClientError>::into)?,
                ),
            })
        }
    }

    struct MockUdp(Arc<UdpSocket>);
    impl BaseClientConnection for MockUdp {
        type BlockingClientConnection = MockUdpConnection;
        type NonblockingClientConnection = MockUdpConnection;

        fn new_blocking_connection(
            &self,
            addr: SocketAddr,
            _stats: Arc<ConnectionCacheStats>,
        ) -> Arc<Self::BlockingClientConnection> {
            Arc::new(MockUdpConnection {
                _socket: self.0.clone(),
                addr,
            })
        }

        fn new_nonblocking_connection(
            &self,
            addr: SocketAddr,
            _stats: Arc<ConnectionCacheStats>,
        ) -> Arc<Self::NonblockingClientConnection> {
            Arc::new(MockUdpConnection {
                _socket: self.0.clone(),
                addr,
            })
        }
    }

    /// Client handle that only remembers its server address; sending is unimplemented.
    struct MockUdpConnection {
        _socket: Arc<UdpSocket>,
        addr: SocketAddr,
    }

    #[derive(Default)]
    struct MockConnectionManager {}

    impl ConnectionManager for MockConnectionManager {
        type ConnectionPool = MockUdpPool;
        type NewConnectionConfig = MockUdpConfig;

        const PROTOCOL: Protocol = Protocol::QUIC;

        fn new_connection_pool(&self) -> Self::ConnectionPool {
            MockUdpPool {
                connections: Vec::default(),
            }
        }

        fn new_connection_config(&self) -> Self::NewConnectionConfig {
            MockUdpConfig::new().unwrap()
        }

        fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
            Ok(())
        }
    }

    impl BlockingClientConnection for MockUdpConnection {
        fn server_addr(&self) -> &SocketAddr {
            &self.addr
        }
        fn send_data(&self, _buffer: &[u8]) -> TransportResult<()> {
            unimplemented!()
        }
        fn send_data_async(&self, _data: Vec<u8>) -> TransportResult<()> {
            unimplemented!()
        }
        fn send_data_batch(&self, _buffers: &[Vec<u8>]) -> TransportResult<()> {
            unimplemented!()
        }
        fn send_data_batch_async(&self, _buffers: Vec<Vec<u8>>) -> TransportResult<()> {
            unimplemented!()
        }
    }

    #[async_trait]
    impl NonblockingClientConnection for MockUdpConnection {
        fn server_addr(&self) -> &SocketAddr {
            &self.addr
        }
        async fn send_data(&self, _data: &[u8]) -> TransportResult<()> {
            unimplemented!()
        }
        async fn send_data_batch(&self, _buffers: &[Vec<u8>]) -> TransportResult<()> {
            unimplemented!()
        }
    }

    /// Random IPv4 address with octets in 1..255 and port 80.
    fn get_addr(rng: &mut ChaChaRng) -> SocketAddr {
        let a = rng.gen_range(1..255);
        let b = rng.gen_range(1..255);
        let c = rng.gen_range(1..255);
        let d = rng.gen_range(1..255);

        let addr_str = format!("{a}.{b}.{c}.{d}:80");

        addr_str.parse().expect("Invalid address")
    }

    #[test]
    fn test_connection_cache() {
        solana_logger::setup();
        let mut rng = ChaChaRng::seed_from_u64(42);

        let connection_manager = MockConnectionManager::default();
        let connection_cache = ConnectionCache::new(
            "connection_cache_test",
            connection_manager,
            DEFAULT_CONNECTION_POOL_SIZE,
        )
        .unwrap();

        // Fill the cache to capacity.
        let addrs = (0..MAX_CONNECTIONS)
            .map(|_| {
                let addr = get_addr(&mut rng);
                connection_cache.get_connection(&addr);
                addr
            })
            .collect::<Vec<_>>();
        {
            let map = connection_cache.map.read().unwrap();
            assert_eq!(map.len(), MAX_CONNECTIONS);
            for addr in &addrs {
                let conn = map.get(addr).expect("Address not found").get(0).unwrap();
                let conn = conn.new_blocking_connection(*addr, connection_cache.stats.clone());
                assert_eq!(
                    BlockingClientConnection::server_addr(&*conn).ip(),
                    addr.ip(),
                );
                assert_eq!(
                    NonblockingClientConnection::server_addr(&*conn).ip(),
                    addr.ip(),
                );
            }
        }

        // One more address forces an eviction: the cache stays at capacity and
        // contains the new address.
        let addr = get_addr(&mut rng);
        connection_cache.get_connection(&addr);

        let map = connection_cache.map.read().unwrap();
        assert_eq!(map.len(), MAX_CONNECTIONS);
        assert!(map.contains_key(&addr), "Address not found");
    }

    #[test]
    fn test_overflow_address() {
        // The largest port number must round-trip through the cache unchanged.
        let port = u16::MAX;
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
        let connection_manager = MockConnectionManager::default();
        let connection_cache =
            ConnectionCache::new("connection_cache_test", connection_manager, 1).unwrap();

        let conn = connection_cache.get_connection(&addr);
        assert_eq!(BlockingClientConnection::server_addr(&*conn).port(), port);
        assert_eq!(
            NonblockingClientConnection::server_addr(&*conn).port(),
            port
        );
    }
}