1use {
2 crate::{
3 client_connection::ClientConnection as BlockingClientConnection,
4 connection_cache_stats::{ConnectionCacheStats, CONNECTION_STAT_SUBMISSION_INTERVAL},
5 nonblocking::client_connection::ClientConnection as NonblockingClientConnection,
6 },
7 crossbeam_channel::{Receiver, RecvError, Sender},
8 indexmap::map::IndexMap,
9 log::*,
10 rand::{thread_rng, Rng},
11 solana_keypair::Keypair,
12 solana_measure::measure::Measure,
13 solana_time_utils::AtomicInterval,
14 std::{
15 net::SocketAddr,
16 sync::{atomic::Ordering, Arc, RwLock},
17 thread::{Builder, JoinHandle},
18 },
19 thiserror::Error,
20};
21
/// Upper bound on the number of distinct remote addresses a
/// [`ConnectionCache`] keeps; inserting a new address at this bound evicts a
/// randomly chosen entry first.
const MAX_CONNECTIONS: usize = 1_024;

/// Number of connections kept per remote address unless the caller chooses
/// a different pool size.
pub const DEFAULT_CONNECTION_POOL_SIZE: usize = 2;
27
/// Transport protocol used by the connections a [`ConnectionManager`] creates
/// (see `ConnectionManager::PROTOCOL`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// UDP transport.
    UDP,
    /// QUIC transport.
    QUIC,
}
33
34pub trait ConnectionManager: Send + Sync + 'static {
35 type ConnectionPool: ConnectionPool;
36 type NewConnectionConfig: NewConnectionConfig;
37
38 const PROTOCOL: Protocol;
39
40 fn new_connection_pool(&self) -> Self::ConnectionPool;
41 fn new_connection_config(&self) -> Self::NewConnectionConfig;
42 fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>>;
43}
44
45pub struct ConnectionCache<
46 R, S, T, > {
50 name: &'static str,
51 map: Arc<RwLock<IndexMap<SocketAddr, R>>>,
52 connection_manager: Arc<S>,
53 stats: Arc<ConnectionCacheStats>,
54 last_stats: AtomicInterval,
55 connection_pool_size: usize,
56 connection_config: Arc<T>,
57 sender: Sender<(usize, SocketAddr)>,
58}
59
60impl<P, M, C> ConnectionCache<P, M, C>
61where
62 P: ConnectionPool<NewConnectionConfig = C>,
63 M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
64 C: NewConnectionConfig,
65{
66 pub fn new(
67 name: &'static str,
68 connection_manager: M,
69 connection_pool_size: usize,
70 ) -> Result<Self, ClientError> {
71 let config = connection_manager.new_connection_config();
72 Ok(Self::new_with_config(
73 name,
74 connection_pool_size,
75 config,
76 connection_manager,
77 ))
78 }
79
80 pub fn new_with_config(
81 name: &'static str,
82 connection_pool_size: usize,
83 connection_config: C,
84 connection_manager: M,
85 ) -> Self {
86 info!("Creating ConnectionCache {name}, pool size: {connection_pool_size}");
87 let (sender, receiver) = crossbeam_channel::unbounded();
88
89 let map = Arc::new(RwLock::new(IndexMap::with_capacity(MAX_CONNECTIONS)));
90 let config = Arc::new(connection_config);
91 let connection_manager = Arc::new(connection_manager);
92 let connection_pool_size = 1.max(connection_pool_size); let stats = Arc::new(ConnectionCacheStats::default());
95
96 let _async_connection_thread =
97 Self::create_connection_async_thread(map.clone(), receiver, stats.clone());
98 Self {
99 name,
100 map,
101 stats,
102 connection_manager,
103 last_stats: AtomicInterval::default(),
104 connection_pool_size,
105 connection_config: config,
106 sender,
107 }
108 }
109
110 fn create_connection_async_thread(
112 map: Arc<RwLock<IndexMap<SocketAddr, P>>>,
113 receiver: Receiver<(usize, SocketAddr)>,
114 stats: Arc<ConnectionCacheStats>,
115 ) -> JoinHandle<()> {
116 Builder::new()
117 .name("solQAsynCon".to_string())
118 .spawn(move || loop {
119 let recv_result = receiver.recv();
120 match recv_result {
121 Err(RecvError) => {
122 break;
123 }
124 Ok((idx, addr)) => {
125 let map = map.read().unwrap();
126 let pool = map.get(&addr);
127 if let Some(pool) = pool {
128 let conn = pool.get(idx);
129 if let Ok(conn) = conn {
130 drop(map);
131 let conn = conn.new_blocking_connection(addr, stats.clone());
132 let result = conn.send_data(&[]);
133 debug!("Create async connection result {result:?} for {addr}");
134 }
135 }
136 }
137 }
138 })
139 .unwrap()
140 }
141
142 pub fn update_key(&self, key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
143 let mut map = self.map.write().unwrap();
144 map.clear();
145 self.connection_manager.update_key(key)
146 }
147 fn create_connection(
151 &self,
152 lock_timing_ms: &mut u64,
153 addr: &SocketAddr,
154 ) -> CreateConnectionResult<<P as ConnectionPool>::BaseClientConnection> {
155 let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
156 let mut map = self.map.write().unwrap();
157 get_connection_map_lock_measure.stop();
158 *lock_timing_ms = lock_timing_ms.saturating_add(get_connection_map_lock_measure.as_ms());
159 let pool_status = map
163 .get(addr)
164 .map(|pool| pool.check_pool_status(self.connection_pool_size))
165 .unwrap_or(PoolStatus::Empty);
166
167 let (cache_hit, num_evictions, eviction_timing_ms) =
168 if matches!(pool_status, PoolStatus::Empty) {
169 Self::create_connection_internal(
170 &self.connection_config,
171 &self.connection_manager,
172 &mut map,
173 addr,
174 self.connection_pool_size,
175 None,
176 )
177 } else {
178 (true, 0, 0)
179 };
180
181 if matches!(pool_status, PoolStatus::PartiallyFull) {
182 debug!("Triggering async connection for {addr:?}");
184 Self::create_connection_internal(
185 &self.connection_config,
186 &self.connection_manager,
187 &mut map,
188 addr,
189 self.connection_pool_size,
190 Some(&self.sender),
191 );
192 }
193
194 let pool = map.get(addr).unwrap();
195 let connection = pool.borrow_connection();
196
197 CreateConnectionResult {
198 connection,
199 cache_hit,
200 connection_cache_stats: self.stats.clone(),
201 num_evictions,
202 eviction_timing_ms,
203 }
204 }
205
206 fn create_connection_internal(
207 config: &C,
208 connection_manager: &M,
209 map: &mut std::sync::RwLockWriteGuard<'_, IndexMap<SocketAddr, P>>,
210 addr: &SocketAddr,
211 connection_pool_size: usize,
212 async_connection_sender: Option<&Sender<(usize, SocketAddr)>>,
213 ) -> (bool, u64, u64) {
214 let mut num_evictions = 0;
216 let mut get_connection_cache_eviction_measure =
217 Measure::start("get_connection_cache_eviction_measure");
218 let existing_index = map.get_index_of(addr);
219 while map.len() >= MAX_CONNECTIONS {
220 let mut rng = thread_rng();
221 let n = rng.gen_range(0..MAX_CONNECTIONS);
222 if let Some(index) = existing_index {
223 if n == index {
224 continue;
225 }
226 }
227 map.swap_remove_index(n);
228 num_evictions += 1;
229 }
230 get_connection_cache_eviction_measure.stop();
231
232 let mut hit_cache = false;
233 map.entry(*addr)
234 .and_modify(|pool| {
235 if matches!(
236 pool.check_pool_status(connection_pool_size),
237 PoolStatus::PartiallyFull
238 ) {
239 let idx = pool.add_connection(config, addr);
240 if let Some(sender) = async_connection_sender {
241 debug!(
242 "Sending async connection creation {} for {addr}",
243 pool.num_connections() - 1
244 );
245 sender.send((idx, *addr)).unwrap();
246 };
247 } else {
248 hit_cache = true;
249 }
250 })
251 .or_insert_with(|| {
252 let mut pool = connection_manager.new_connection_pool();
253 pool.add_connection(config, addr);
254 pool
255 });
256 (
257 hit_cache,
258 num_evictions,
259 get_connection_cache_eviction_measure.as_ms(),
260 )
261 }
262
263 fn get_or_add_connection(
264 &self,
265 addr: &SocketAddr,
266 ) -> GetConnectionResult<<P as ConnectionPool>::BaseClientConnection> {
267 let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
268 let map = self.map.read().unwrap();
269 get_connection_map_lock_measure.stop();
270
271 let mut lock_timing_ms = get_connection_map_lock_measure.as_ms();
272
273 let report_stats = self
274 .last_stats
275 .should_update(CONNECTION_STAT_SUBMISSION_INTERVAL);
276
277 let mut get_connection_map_measure = Measure::start("get_connection_hit_measure");
278 let CreateConnectionResult {
279 connection,
280 cache_hit,
281 connection_cache_stats,
282 num_evictions,
283 eviction_timing_ms,
284 } = match map.get(addr) {
285 Some(pool) => {
286 let pool_status = pool.check_pool_status(self.connection_pool_size);
287 match pool_status {
288 PoolStatus::Empty => {
289 drop(map);
291 self.create_connection(&mut lock_timing_ms, addr)
292 }
293 PoolStatus::PartiallyFull | PoolStatus::Full => {
294 let connection = pool.borrow_connection();
295 if matches!(pool_status, PoolStatus::PartiallyFull) {
296 debug!("Creating connection async for {addr}");
297 drop(map);
298 let mut map = self.map.write().unwrap();
299 Self::create_connection_internal(
300 &self.connection_config,
301 &self.connection_manager,
302 &mut map,
303 addr,
304 self.connection_pool_size,
305 Some(&self.sender),
306 );
307 }
308 CreateConnectionResult {
309 connection,
310 cache_hit: true,
311 connection_cache_stats: self.stats.clone(),
312 num_evictions: 0,
313 eviction_timing_ms: 0,
314 }
315 }
316 }
317 }
318 None => {
319 drop(map);
321 self.create_connection(&mut lock_timing_ms, addr)
322 }
323 };
324 get_connection_map_measure.stop();
325
326 GetConnectionResult {
327 connection,
328 cache_hit,
329 report_stats,
330 map_timing_ms: get_connection_map_measure.as_ms(),
331 lock_timing_ms,
332 connection_cache_stats,
333 num_evictions,
334 eviction_timing_ms,
335 }
336 }
337
338 fn get_connection_and_log_stats(
339 &self,
340 addr: &SocketAddr,
341 ) -> (
342 Arc<<P as ConnectionPool>::BaseClientConnection>,
343 Arc<ConnectionCacheStats>,
344 ) {
345 let mut get_connection_measure = Measure::start("get_connection_measure");
346 let GetConnectionResult {
347 connection,
348 cache_hit,
349 report_stats,
350 map_timing_ms,
351 lock_timing_ms,
352 connection_cache_stats,
353 num_evictions,
354 eviction_timing_ms,
355 } = self.get_or_add_connection(addr);
356
357 if report_stats {
358 connection_cache_stats.report(self.name);
359 }
360
361 if cache_hit {
362 connection_cache_stats
363 .cache_hits
364 .fetch_add(1, Ordering::Relaxed);
365 connection_cache_stats
366 .get_connection_hit_ms
367 .fetch_add(map_timing_ms, Ordering::Relaxed);
368 } else {
369 connection_cache_stats
370 .cache_misses
371 .fetch_add(1, Ordering::Relaxed);
372 connection_cache_stats
373 .get_connection_miss_ms
374 .fetch_add(map_timing_ms, Ordering::Relaxed);
375 connection_cache_stats
376 .cache_evictions
377 .fetch_add(num_evictions, Ordering::Relaxed);
378 connection_cache_stats
379 .eviction_time_ms
380 .fetch_add(eviction_timing_ms, Ordering::Relaxed);
381 }
382
383 get_connection_measure.stop();
384 connection_cache_stats
385 .get_connection_lock_ms
386 .fetch_add(lock_timing_ms, Ordering::Relaxed);
387 connection_cache_stats
388 .get_connection_ms
389 .fetch_add(get_connection_measure.as_ms(), Ordering::Relaxed);
390
391 (connection, connection_cache_stats)
392 }
393
394 pub fn get_connection(&self, addr: &SocketAddr) -> Arc<<<P as ConnectionPool>::BaseClientConnection as BaseClientConnection>::BlockingClientConnection>{
395 let (connection, connection_cache_stats) = self.get_connection_and_log_stats(addr);
396 connection.new_blocking_connection(*addr, connection_cache_stats)
397 }
398
399 pub fn get_nonblocking_connection(
400 &self,
401 addr: &SocketAddr,
402 ) -> Arc<<<P as ConnectionPool>::BaseClientConnection as BaseClientConnection>::NonblockingClientConnection>{
403 let (connection, connection_cache_stats) = self.get_connection_and_log_stats(addr);
404 connection.new_nonblocking_connection(*addr, connection_cache_stats)
405 }
406}
407
408#[derive(Error, Debug)]
409pub enum ConnectionPoolError {
410 #[error("connection index is out of range of the pool")]
411 IndexOutOfRange,
412}
413
414#[derive(Error, Debug)]
415pub enum ClientError {
416 #[error("IO error: {0:?}")]
417 IoError(#[from] std::io::Error),
418}
419
420pub trait NewConnectionConfig: Sized + Send + Sync + 'static {
421 fn new() -> Result<Self, ClientError>;
422}
423
/// Fill level of a [`ConnectionPool`] relative to the required pool size, as
/// computed by `ConnectionPool::check_pool_status`.
///
/// Derives the same common traits as `Protocol`, so the status can be logged,
/// copied and compared.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PoolStatus {
    /// The pool holds no connections.
    Empty,
    /// The pool holds at least one connection, but fewer than required.
    PartiallyFull,
    /// The pool holds at least the required number of connections.
    Full,
}
429
430pub trait ConnectionPool: Send + Sync + 'static {
431 type NewConnectionConfig: NewConnectionConfig;
432 type BaseClientConnection: BaseClientConnection;
433
434 fn add_connection(&mut self, config: &Self::NewConnectionConfig, addr: &SocketAddr) -> usize;
436
437 fn num_connections(&self) -> usize;
439
440 fn get(&self, index: usize) -> Result<Arc<Self::BaseClientConnection>, ConnectionPoolError>;
442
443 fn borrow_connection(&self) -> Arc<Self::BaseClientConnection> {
446 let mut rng = thread_rng();
447 let n = rng.gen_range(0..self.num_connections());
448 self.get(n).expect("index is within num_connections")
449 }
450
451 fn check_pool_status(&self, required_pool_size: usize) -> PoolStatus {
454 if self.num_connections() == 0 {
455 PoolStatus::Empty
456 } else if self.num_connections() < required_pool_size {
457 PoolStatus::PartiallyFull
458 } else {
459 PoolStatus::Full
460 }
461 }
462
463 fn create_pool_entry(
464 &self,
465 config: &Self::NewConnectionConfig,
466 addr: &SocketAddr,
467 ) -> Arc<Self::BaseClientConnection>;
468}
469
470pub trait BaseClientConnection {
471 type BlockingClientConnection: BlockingClientConnection;
472 type NonblockingClientConnection: NonblockingClientConnection;
473
474 fn new_blocking_connection(
475 &self,
476 addr: SocketAddr,
477 stats: Arc<ConnectionCacheStats>,
478 ) -> Arc<Self::BlockingClientConnection>;
479
480 fn new_nonblocking_connection(
481 &self,
482 addr: SocketAddr,
483 stats: Arc<ConnectionCacheStats>,
484 ) -> Arc<Self::NonblockingClientConnection>;
485}
486
487struct GetConnectionResult<T> {
488 connection: Arc<T>,
489 cache_hit: bool,
490 report_stats: bool,
491 map_timing_ms: u64,
492 lock_timing_ms: u64,
493 connection_cache_stats: Arc<ConnectionCacheStats>,
494 num_evictions: u64,
495 eviction_timing_ms: u64,
496}
497
498struct CreateConnectionResult<T> {
499 connection: Arc<T>,
500 cache_hit: bool,
501 connection_cache_stats: Arc<ConnectionCacheStats>,
502 num_evictions: u64,
503 eviction_timing_ms: u64,
504}
505
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{
            client_connection::ClientConnection as BlockingClientConnection,
            nonblocking::client_connection::ClientConnection as NonblockingClientConnection,
        },
        async_trait::async_trait,
        rand::{Rng, SeedableRng},
        rand_chacha::ChaChaRng,
        solana_net_utils::SocketConfig,
        solana_transaction_error::TransportResult,
        std::{
            net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
            sync::Arc,
        },
    };

    /// Test pool backed by a plain vector of mock UDP connections.
    struct MockUdpPool {
        connections: Vec<Arc<MockUdp>>,
    }

    impl ConnectionPool for MockUdpPool {
        type NewConnectionConfig = MockUdpConfig;
        type BaseClientConnection = MockUdp;

        /// Appends a new connection and returns its index.
        fn add_connection(
            &mut self,
            config: &Self::NewConnectionConfig,
            addr: &SocketAddr,
        ) -> usize {
            let connection = self.create_pool_entry(config, addr);
            let idx = self.connections.len();
            self.connections.push(connection);
            idx
        }

        fn num_connections(&self) -> usize {
            self.connections.len()
        }

        fn get(
            &self,
            index: usize,
        ) -> Result<Arc<Self::BaseClientConnection>, ConnectionPoolError> {
            self.connections
                .get(index)
                .cloned()
                .ok_or(ConnectionPoolError::IndexOutOfRange)
        }

        fn create_pool_entry(
            &self,
            config: &Self::NewConnectionConfig,
            _addr: &SocketAddr,
        ) -> Arc<Self::BaseClientConnection> {
            Arc::new(MockUdp(config.udp_socket.clone()))
        }
    }

    /// Test configuration holding one UDP socket shared by all connections.
    struct MockUdpConfig {
        udp_socket: Arc<UdpSocket>,
    }

    impl Default for MockUdpConfig {
        fn default() -> Self {
            // Same setup as `NewConnectionConfig::new`, but panics on failure.
            Self::new().expect("Unable to bind to UDP socket")
        }
    }

    impl NewConnectionConfig for MockUdpConfig {
        fn new() -> Result<Self, ClientError> {
            Ok(Self {
                udp_socket: Arc::new(
                    solana_net_utils::bind_with_any_port_with_config(
                        IpAddr::V4(Ipv4Addr::UNSPECIFIED),
                        SocketConfig::default(),
                    )
                    .map_err(Into::<ClientError>::into)?,
                ),
            })
        }
    }

    /// Pooled mock connection wrapping the shared socket.
    struct MockUdp(Arc<UdpSocket>);

    impl BaseClientConnection for MockUdp {
        type BlockingClientConnection = MockUdpConnection;
        type NonblockingClientConnection = MockUdpConnection;

        fn new_blocking_connection(
            &self,
            addr: SocketAddr,
            _stats: Arc<ConnectionCacheStats>,
        ) -> Arc<Self::BlockingClientConnection> {
            Arc::new(MockUdpConnection {
                _socket: self.0.clone(),
                addr,
            })
        }

        fn new_nonblocking_connection(
            &self,
            addr: SocketAddr,
            _stats: Arc<ConnectionCacheStats>,
        ) -> Arc<Self::NonblockingClientConnection> {
            Arc::new(MockUdpConnection {
                _socket: self.0.clone(),
                addr,
            })
        }
    }

    /// Client handle that only remembers its server address; sending is unsupported.
    struct MockUdpConnection {
        _socket: Arc<UdpSocket>,
        addr: SocketAddr,
    }

    #[derive(Default)]
    struct MockConnectionManager {}

    impl ConnectionManager for MockConnectionManager {
        type ConnectionPool = MockUdpPool;
        type NewConnectionConfig = MockUdpConfig;

        const PROTOCOL: Protocol = Protocol::QUIC;

        fn new_connection_pool(&self) -> Self::ConnectionPool {
            MockUdpPool {
                connections: Vec::default(),
            }
        }

        fn new_connection_config(&self) -> Self::NewConnectionConfig {
            MockUdpConfig::new().unwrap()
        }

        fn update_key(&self, _key: &Keypair) -> Result<(), Box<dyn std::error::Error>> {
            Ok(())
        }
    }

    impl BlockingClientConnection for MockUdpConnection {
        fn server_addr(&self) -> &SocketAddr {
            &self.addr
        }
        fn send_data(&self, _buffer: &[u8]) -> TransportResult<()> {
            unimplemented!()
        }
        fn send_data_async(&self, _data: Vec<u8>) -> TransportResult<()> {
            unimplemented!()
        }
        fn send_data_batch(&self, _buffers: &[Vec<u8>]) -> TransportResult<()> {
            unimplemented!()
        }
        fn send_data_batch_async(&self, _buffers: Vec<Vec<u8>>) -> TransportResult<()> {
            unimplemented!()
        }
    }

    #[async_trait]
    impl NonblockingClientConnection for MockUdpConnection {
        fn server_addr(&self) -> &SocketAddr {
            &self.addr
        }
        async fn send_data(&self, _data: &[u8]) -> TransportResult<()> {
            unimplemented!()
        }
        async fn send_data_batch(&self, _buffers: &[Vec<u8>]) -> TransportResult<()> {
            unimplemented!()
        }
    }

    /// Returns a pseudo-random IPv4 address on port 80.
    fn get_addr(rng: &mut ChaChaRng) -> SocketAddr {
        let a = rng.gen_range(1..255);
        let b = rng.gen_range(1..255);
        let c = rng.gen_range(1..255);
        let d = rng.gen_range(1..255);

        let addr_str = format!("{a}.{b}.{c}.{d}:80");

        addr_str.parse().expect("Invalid address")
    }

    #[test]
    fn test_connection_cache() {
        solana_logger::setup();
        let mut rng = ChaChaRng::seed_from_u64(42);

        let connection_manager = MockConnectionManager::default();
        let connection_cache = ConnectionCache::new(
            "connection_cache_test",
            connection_manager,
            DEFAULT_CONNECTION_POOL_SIZE,
        )
        .unwrap();
        let addrs = (0..MAX_CONNECTIONS)
            .map(|_| {
                let addr = get_addr(&mut rng);
                connection_cache.get_connection(&addr);
                addr
            })
            .collect::<Vec<_>>();
        {
            let map = connection_cache.map.read().unwrap();
            assert_eq!(map.len(), MAX_CONNECTIONS);
            addrs.iter().for_each(|addr| {
                let conn = map.get(addr).expect("Address not found").get(0).unwrap();
                let conn = conn.new_blocking_connection(*addr, connection_cache.stats.clone());
                assert_eq!(
                    BlockingClientConnection::server_addr(&*conn).ip(),
                    addr.ip(),
                );
                assert_eq!(
                    NonblockingClientConnection::server_addr(&*conn).ip(),
                    addr.ip(),
                );
            });
        }

        // One address past the cap: the cache stays at its bound and the new
        // address is present, so another entry was evicted.
        let addr = get_addr(&mut rng);
        connection_cache.get_connection(&addr);

        let map = connection_cache.map.read().unwrap();
        assert_eq!(map.len(), MAX_CONNECTIONS);
        assert!(map.contains_key(&addr), "Address not found");
    }

    #[test]
    fn test_overflow_address() {
        let port = u16::MAX;
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
        let connection_manager = MockConnectionManager::default();
        let connection_cache =
            ConnectionCache::new("connection_cache_test", connection_manager, 1).unwrap();

        let conn = connection_cache.get_connection(&addr);
        assert_ne!(port, 0u16);
        assert_eq!(BlockingClientConnection::server_addr(&*conn).port(), port);
        assert_eq!(
            NonblockingClientConnection::server_addr(&*conn).port(),
            port
        );
    }
}